summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--WORKSPACE21
-rw-r--r--benchmarks/BUILD1
-rw-r--r--benchmarks/README.md126
-rw-r--r--benchmarks/harness/BUILD5
-rw-r--r--benchmarks/harness/__init__.py9
-rw-r--r--benchmarks/harness/machine.py11
-rw-r--r--benchmarks/harness/machine_producers/BUILD40
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer.py268
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer_test.py48
-rw-r--r--benchmarks/harness/machine_producers/machine_producer.py21
-rw-r--r--benchmarks/harness/machine_producers/mock_producer.py23
-rw-r--r--benchmarks/harness/machine_producers/testdata/get_five.json211
-rw-r--r--benchmarks/harness/machine_producers/testdata/get_one.json145
-rw-r--r--benchmarks/harness/ssh_connection.py9
-rw-r--r--benchmarks/runner/BUILD10
-rw-r--r--benchmarks/runner/__init__.py119
-rw-r--r--benchmarks/runner/commands.py135
-rw-r--r--benchmarks/runner/runner_test.py2
-rw-r--r--benchmarks/suites/http.py2
-rwxr-xr-xbenchmarks/tcp/tcp_benchmark.sh21
-rw-r--r--benchmarks/tcp/tcp_proxy.go66
-rw-r--r--benchmarks/workloads/BUILD40
-rw-r--r--benchmarks/workloads/ab/BUILD5
-rw-r--r--benchmarks/workloads/absl/BUILD5
-rw-r--r--benchmarks/workloads/curl/BUILD6
-rw-r--r--benchmarks/workloads/ffmpeg/BUILD6
-rw-r--r--benchmarks/workloads/fio/BUILD5
-rw-r--r--benchmarks/workloads/httpd/BUILD6
-rw-r--r--benchmarks/workloads/iperf/BUILD5
-rw-r--r--benchmarks/workloads/netcat/BUILD6
-rw-r--r--benchmarks/workloads/nginx/BUILD6
-rw-r--r--benchmarks/workloads/node/BUILD6
-rw-r--r--benchmarks/workloads/node_template/BUILD6
-rw-r--r--benchmarks/workloads/redis/BUILD6
-rw-r--r--benchmarks/workloads/redisbenchmark/BUILD5
-rw-r--r--benchmarks/workloads/ruby/BUILD13
-rw-r--r--benchmarks/workloads/ruby_template/BUILD7
-rw-r--r--benchmarks/workloads/sleep/BUILD6
-rw-r--r--benchmarks/workloads/sysbench/BUILD5
-rw-r--r--benchmarks/workloads/syscall/BUILD5
-rw-r--r--benchmarks/workloads/tensorflow/BUILD6
-rw-r--r--benchmarks/workloads/true/BUILD7
-rw-r--r--go.mod2
-rw-r--r--go.sum2
-rw-r--r--kokoro/issue_reviver.cfg15
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/netfilter.go89
-rw-r--r--pkg/abi/linux/rseq.go130
-rw-r--r--pkg/abi/linux/time.go13
-rw-r--r--pkg/amutex/BUILD1
-rw-r--r--pkg/amutex/amutex_test.go3
-rw-r--r--pkg/atomicbitops/BUILD1
-rw-r--r--pkg/atomicbitops/atomic_bitops_test.go3
-rw-r--r--pkg/compressio/BUILD5
-rw-r--r--pkg/compressio/compressio.go2
-rw-r--r--pkg/control/server/BUILD1
-rw-r--r--pkg/control/server/server.go2
-rw-r--r--pkg/cpuid/cpuid.go44
-rw-r--r--pkg/eventchannel/BUILD2
-rw-r--r--pkg/eventchannel/event.go2
-rw-r--r--pkg/eventchannel/event_test.go2
-rw-r--r--pkg/fdchannel/BUILD1
-rw-r--r--pkg/fdchannel/fdchannel_test.go3
-rw-r--r--pkg/fdnotifier/BUILD1
-rw-r--r--pkg/fdnotifier/fdnotifier.go2
-rw-r--r--pkg/flipcall/BUILD3
-rw-r--r--pkg/flipcall/flipcall_example_test.go3
-rw-r--r--pkg/flipcall/flipcall_test.go3
-rw-r--r--pkg/flipcall/flipcall_unsafe.go10
-rw-r--r--pkg/gate/BUILD1
-rw-r--r--pkg/gate/gate_test.go2
-rw-r--r--pkg/goid/BUILD26
-rw-r--r--pkg/goid/empty_test.go (renamed from pkg/sentry/fsimpl/proc/proc.go)12
-rw-r--r--pkg/goid/goid.go (renamed from pkg/sentry/socket/rpcinet/rpcinet.go)14
-rw-r--r--pkg/goid/goid_amd64.s (renamed from pkg/sentry/socket/rpcinet/device.go)12
-rw-r--r--pkg/goid/goid_race.go (renamed from pkg/sentry/fsimpl/proc/filesystems.go)20
-rw-r--r--pkg/goid/goid_test.go74
-rw-r--r--pkg/goid/goid_unsafe.go64
-rw-r--r--pkg/linewriter/BUILD1
-rw-r--r--pkg/linewriter/linewriter.go3
-rw-r--r--pkg/log/BUILD5
-rw-r--r--pkg/log/log.go2
-rw-r--r--pkg/metric/BUILD1
-rw-r--r--pkg/metric/metric.go2
-rw-r--r--pkg/p9/BUILD1
-rw-r--r--pkg/p9/client.go2
-rw-r--r--pkg/p9/client_file.go29
-rw-r--r--pkg/p9/file.go16
-rw-r--r--pkg/p9/handlers.go29
-rw-r--r--pkg/p9/messages.go129
-rw-r--r--pkg/p9/messages_test.go15
-rw-r--r--pkg/p9/p9.go4
-rw-r--r--pkg/p9/p9test/BUILD2
-rw-r--r--pkg/p9/p9test/client_test.go2
-rw-r--r--pkg/p9/p9test/p9test.go2
-rw-r--r--pkg/p9/path_tree.go3
-rw-r--r--pkg/p9/pool.go2
-rw-r--r--pkg/p9/server.go2
-rw-r--r--pkg/p9/transport.go2
-rw-r--r--pkg/p9/version.go8
-rw-r--r--pkg/procid/BUILD2
-rw-r--r--pkg/procid/procid_test.go3
-rw-r--r--pkg/rand/BUILD5
-rw-r--r--pkg/rand/rand_linux.go2
-rw-r--r--pkg/refs/BUILD2
-rw-r--r--pkg/refs/refcounter.go2
-rw-r--r--pkg/refs/refcounter_test.go3
-rw-r--r--pkg/sentry/arch/BUILD7
-rw-r--r--pkg/sentry/arch/arch.go6
-rw-r--r--pkg/sentry/arch/arch_aarch64.go293
-rw-r--r--pkg/sentry/arch/arch_amd64.go4
-rw-r--r--pkg/sentry/arch/arch_arm64.go266
-rw-r--r--pkg/sentry/arch/arch_state_aarch64.go38
-rw-r--r--pkg/sentry/arch/arch_state_x86.go2
-rw-r--r--pkg/sentry/arch/arch_x86.go2
-rw-r--r--pkg/sentry/arch/registers.proto37
-rw-r--r--pkg/sentry/arch/signal.go250
-rw-r--r--pkg/sentry/arch/signal_amd64.go230
-rw-r--r--pkg/sentry/arch/signal_arm64.go126
-rw-r--r--pkg/sentry/arch/signal_stack.go2
-rw-r--r--pkg/sentry/arch/syscalls_arm64.go62
-rw-r--r--pkg/sentry/control/BUILD1
-rw-r--r--pkg/sentry/control/pprof.go2
-rw-r--r--pkg/sentry/device/BUILD5
-rw-r--r--pkg/sentry/device/device.go2
-rw-r--r--pkg/sentry/fs/BUILD3
-rw-r--r--pkg/sentry/fs/copy_up.go11
-rw-r--r--pkg/sentry/fs/copy_up_test.go2
-rw-r--r--pkg/sentry/fs/dirent.go2
-rw-r--r--pkg/sentry/fs/dirent_cache.go3
-rw-r--r--pkg/sentry/fs/dirent_cache_limiter.go3
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD1
-rw-r--r--pkg/sentry/fs/fdpipe/pipe.go2
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_state.go2
-rw-r--r--pkg/sentry/fs/file.go9
-rw-r--r--pkg/sentry/fs/file_overlay.go4
-rw-r--r--pkg/sentry/fs/filesystems.go2
-rw-r--r--pkg/sentry/fs/fs.go3
-rw-r--r--pkg/sentry/fs/fsutil/BUILD3
-rw-r--r--pkg/sentry/fs/fsutil/file_range_set.go14
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go2
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go2
-rw-r--r--pkg/sentry/fs/fsutil/inode.go44
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go2
-rw-r--r--pkg/sentry/fs/gofer/BUILD1
-rw-r--r--pkg/sentry/fs/gofer/context_file.go14
-rw-r--r--pkg/sentry/fs/gofer/inode.go20
-rw-r--r--pkg/sentry/fs/gofer/session.go2
-rw-r--r--pkg/sentry/fs/host/BUILD1
-rw-r--r--pkg/sentry/fs/host/inode.go2
-rw-r--r--pkg/sentry/fs/host/socket.go2
-rw-r--r--pkg/sentry/fs/host/tty.go3
-rw-r--r--pkg/sentry/fs/inode.go27
-rw-r--r--pkg/sentry/fs/inode_inotify.go3
-rw-r--r--pkg/sentry/fs/inode_operations.go25
-rw-r--r--pkg/sentry/fs/inode_overlay.go41
-rw-r--r--pkg/sentry/fs/inode_overlay_test.go4
-rw-r--r--pkg/sentry/fs/inotify.go2
-rw-r--r--pkg/sentry/fs/inotify_watch.go2
-rw-r--r--pkg/sentry/fs/lock/BUILD1
-rw-r--r--pkg/sentry/fs/lock/lock.go5
-rw-r--r--pkg/sentry/fs/mounts.go2
-rw-r--r--pkg/sentry/fs/overlay.go5
-rw-r--r--pkg/sentry/fs/proc/BUILD3
-rw-r--r--pkg/sentry/fs/proc/cgroup.go4
-rw-r--r--pkg/sentry/fs/proc/cpuinfo.go12
-rw-r--r--pkg/sentry/fs/proc/exec_args.go4
-rw-r--r--pkg/sentry/fs/proc/fds.go4
-rw-r--r--pkg/sentry/fs/proc/filesystems.go4
-rw-r--r--pkg/sentry/fs/proc/fs.go4
-rw-r--r--pkg/sentry/fs/proc/inode.go4
-rw-r--r--pkg/sentry/fs/proc/loadavg.go4
-rw-r--r--pkg/sentry/fs/proc/meminfo.go4
-rw-r--r--pkg/sentry/fs/proc/mounts.go4
-rw-r--r--pkg/sentry/fs/proc/net.go4
-rw-r--r--pkg/sentry/fs/proc/proc.go13
-rw-r--r--pkg/sentry/fs/proc/rpcinet_proc.go217
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD1
-rw-r--r--pkg/sentry/fs/proc/seqfile/seqfile.go2
-rw-r--r--pkg/sentry/fs/proc/stat.go4
-rw-r--r--pkg/sentry/fs/proc/sys.go13
-rw-r--r--pkg/sentry/fs/proc/sys_net.go6
-rw-r--r--pkg/sentry/fs/proc/task.go4
-rw-r--r--pkg/sentry/fs/proc/uid_gid_map.go4
-rw-r--r--pkg/sentry/fs/proc/uptime.go4
-rw-r--r--pkg/sentry/fs/proc/version.go4
-rw-r--r--pkg/sentry/fs/ramfs/BUILD1
-rw-r--r--pkg/sentry/fs/ramfs/dir.go2
-rw-r--r--pkg/sentry/fs/restore.go2
-rw-r--r--pkg/sentry/fs/splice.go5
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD1
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go2
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go18
-rw-r--r--pkg/sentry/fs/tty/BUILD1
-rw-r--r--pkg/sentry/fs/tty/dir.go2
-rw-r--r--pkg/sentry/fs/tty/line_discipline.go2
-rw-r--r--pkg/sentry/fs/tty/queue.go3
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD1
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go4
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent.go10
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go4
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go8
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go2
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go6
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go22
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go7
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go63
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go68
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go13
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go8
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go54
-rw-r--r--pkg/sentry/fsimpl/memfs/regular_file.go154
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD45
-rw-r--r--pkg/sentry/fsimpl/proc/boot_test.go149
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go80
-rw-r--r--pkg/sentry/fsimpl/proc/loadavg.go40
-rw-r--r--pkg/sentry/fsimpl/proc/meminfo.go77
-rw-r--r--pkg/sentry/fsimpl/proc/mounts.go33
-rw-r--r--pkg/sentry/fsimpl/proc/stat.go127
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go126
-rw-r--r--pkg/sentry/fsimpl/proc/sys.go51
-rw-r--r--pkg/sentry/fsimpl/proc/task.go364
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go527
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go235
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go337
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_net.go (renamed from pkg/sentry/fsimpl/proc/net.go)3
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go143
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys_test.go (renamed from pkg/sentry/fsimpl/proc/net_test.go)0
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go566
-rw-r--r--pkg/sentry/fsimpl/proc/version.go68
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD (renamed from pkg/sentry/fsimpl/memfs/BUILD)35
-rw-r--r--pkg/sentry/fsimpl/tmpfs/benchmark_test.go (renamed from pkg/sentry/fsimpl/memfs/benchmark_test.go)16
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go (renamed from pkg/sentry/fsimpl/memfs/directory.go)2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go (renamed from pkg/sentry/fsimpl/memfs/filesystem.go)28
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go (renamed from pkg/sentry/fsimpl/memfs/named_pipe.go)4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/pipe_test.go (renamed from pkg/sentry/fsimpl/memfs/pipe_test.go)8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go392
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go436
-rw-r--r--pkg/sentry/fsimpl/tmpfs/stat_test.go232
-rw-r--r--pkg/sentry/fsimpl/tmpfs/symlink.go (renamed from pkg/sentry/fsimpl/memfs/symlink.go)2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go (renamed from pkg/sentry/fsimpl/memfs/memfs.go)162
-rw-r--r--pkg/sentry/kernel/BUILD5
-rw-r--r--pkg/sentry/kernel/abstract_socket_namespace.go2
-rw-r--r--pkg/sentry/kernel/auth/BUILD3
-rw-r--r--pkg/sentry/kernel/auth/user_namespace.go2
-rw-r--r--pkg/sentry/kernel/epoll/BUILD1
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go2
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD1
-rw-r--r--pkg/sentry/kernel/eventfd/eventfd.go2
-rw-r--r--pkg/sentry/kernel/fasync/BUILD1
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go3
-rw-r--r--pkg/sentry/kernel/fd_table.go2
-rw-r--r--pkg/sentry/kernel/fd_table_test.go2
-rw-r--r--pkg/sentry/kernel/fs_context.go2
-rw-r--r--pkg/sentry/kernel/futex/BUILD8
-rw-r--r--pkg/sentry/kernel/futex/futex.go3
-rw-r--r--pkg/sentry/kernel/futex/futex_test.go2
-rw-r--r--pkg/sentry/kernel/kernel.go9
-rw-r--r--pkg/sentry/kernel/memevent/BUILD1
-rw-r--r--pkg/sentry/kernel/memevent/memory_events.go2
-rw-r--r--pkg/sentry/kernel/pipe/BUILD1
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go2
-rw-r--r--pkg/sentry/kernel/pipe/node.go3
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go2
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go2
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go3
-rw-r--r--pkg/sentry/kernel/rseq.go383
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD1
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go2
-rw-r--r--pkg/sentry/kernel/shm/BUILD1
-rw-r--r--pkg/sentry/kernel/shm/shm.go87
-rw-r--r--pkg/sentry/kernel/signal_handlers.go3
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD1
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go3
-rw-r--r--pkg/sentry/kernel/syscalls.go2
-rw-r--r--pkg/sentry/kernel/syslog.go3
-rw-r--r--pkg/sentry/kernel/task.go48
-rw-r--r--pkg/sentry/kernel/task_clone.go9
-rw-r--r--pkg/sentry/kernel/task_exec.go6
-rw-r--r--pkg/sentry/kernel/task_run.go16
-rw-r--r--pkg/sentry/kernel/task_start.go10
-rw-r--r--pkg/sentry/kernel/thread_group.go28
-rw-r--r--pkg/sentry/kernel/threads.go2
-rw-r--r--pkg/sentry/kernel/time/BUILD1
-rw-r--r--pkg/sentry/kernel/time/time.go2
-rw-r--r--pkg/sentry/kernel/timekeeper.go2
-rw-r--r--pkg/sentry/kernel/tty.go2
-rw-r--r--pkg/sentry/kernel/uts_namespace.go3
-rw-r--r--pkg/sentry/limits/BUILD1
-rw-r--r--pkg/sentry/limits/limits.go3
-rw-r--r--pkg/sentry/mm/BUILD2
-rw-r--r--pkg/sentry/mm/aio_context.go3
-rw-r--r--pkg/sentry/mm/mm.go8
-rw-r--r--pkg/sentry/mm/procfs.go12
-rw-r--r--pkg/sentry/pgalloc/BUILD1
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go2
-rw-r--r--pkg/sentry/platform/interrupt/BUILD1
-rw-r--r--pkg/sentry/platform/interrupt/interrupt.go3
-rw-r--r--pkg/sentry/platform/kvm/BUILD1
-rw-r--r--pkg/sentry/platform/kvm/address_space.go2
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go2
-rw-r--r--pkg/sentry/platform/kvm/kvm.go2
-rw-r--r--pkg/sentry/platform/kvm/machine.go2
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go4
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go4
-rw-r--r--pkg/sentry/platform/ptrace/BUILD1
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go2
-rw-r--r--pkg/sentry/platform/ptrace/stub_amd64.s29
-rw-r--r--pkg/sentry/platform/ptrace/stub_arm64.s30
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go27
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go2
-rw-r--r--pkg/sentry/platform/ring0/defs.go2
-rw-r--r--pkg/sentry/platform/ring0/defs_amd64.go1
-rw-r--r--pkg/sentry/platform/ring0/defs_arm64.go1
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s81
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s7
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD5
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids_x86.go2
-rw-r--r--pkg/sentry/socket/control/control.go8
-rw-r--r--pkg/sentry/socket/netfilter/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go408
-rw-r--r--pkg/sentry/socket/netlink/BUILD2
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/port/port.go3
-rw-r--r--pkg/sentry/socket/netlink/socket.go31
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go173
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD69
-rw-r--r--pkg/sentry/socket/rpcinet/conn/BUILD17
-rw-r--r--pkg/sentry/socket/rpcinet/conn/conn.go187
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/BUILD16
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/notifier.go231
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go909
-rw-r--r--pkg/sentry/socket/rpcinet/stack.go177
-rw-r--r--pkg/sentry/socket/rpcinet/stack_unsafe.go193
-rw-r--r--pkg/sentry/socket/rpcinet/syscall_rpc.proto352
-rw-r--r--pkg/sentry/socket/unix/io.go13
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go3
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go3
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go26
-rw-r--r--pkg/sentry/socket/unix/unix.go28
-rw-r--r--pkg/sentry/strace/BUILD3
-rw-r--r--pkg/sentry/strace/linux64_amd64.go (renamed from pkg/sentry/strace/linux64.go)19
-rw-r--r--pkg/sentry/strace/linux64_arm64.go323
-rw-r--r--pkg/sentry/strace/socket.go2
-rw-r--r--pkg/sentry/strace/syscalls.go9
-rw-r--r--pkg/sentry/syscalls/linux/BUILD4
-rw-r--r--pkg/sentry/syscalls/linux/error.go2
-rw-r--r--pkg/sentry/syscalls/linux/linux64_amd64.go6
-rw-r--r--pkg/sentry/syscalls/linux/linux64_arm64.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_clone_amd64.go35
-rw-r--r--pkg/sentry/syscalls/linux/sys_clone_arm64.go35
-rw-r--r--pkg/sentry/syscalls/linux/sys_rseq.go48
-rw-r--r--pkg/sentry/syscalls/linux/sys_shm.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go13
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go48
-rw-r--r--pkg/sentry/time/BUILD4
-rw-r--r--pkg/sentry/time/calibrated_clock.go2
-rw-r--r--pkg/sentry/usage/BUILD1
-rw-r--r--pkg/sentry/usage/memory.go2
-rw-r--r--pkg/sentry/vfs/BUILD4
-rw-r--r--pkg/sentry/vfs/dentry.go2
-rw-r--r--pkg/sentry/vfs/device.go100
-rw-r--r--pkg/sentry/vfs/file_description.go101
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go28
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go2
-rw-r--r--pkg/sentry/vfs/filesystem.go21
-rw-r--r--pkg/sentry/vfs/filesystem_type.go55
-rw-r--r--pkg/sentry/vfs/mount.go15
-rw-r--r--pkg/sentry/vfs/mount_test.go3
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go4
-rw-r--r--pkg/sentry/vfs/options.go4
-rw-r--r--pkg/sentry/vfs/pathname.go3
-rw-r--r--pkg/sentry/vfs/permissions.go24
-rw-r--r--pkg/sentry/vfs/resolving_path.go2
-rw-r--r--pkg/sentry/vfs/vfs.go20
-rw-r--r--pkg/sentry/watchdog/BUILD1
-rw-r--r--pkg/sentry/watchdog/watchdog.go2
-rw-r--r--pkg/sleep/sleep_test.go31
-rw-r--r--pkg/sync/BUILD (renamed from pkg/syncutil/BUILD)9
-rw-r--r--pkg/sync/LICENSE (renamed from pkg/syncutil/LICENSE)0
-rw-r--r--pkg/sync/README.md (renamed from pkg/syncutil/README.md)0
-rw-r--r--pkg/sync/aliases.go37
-rw-r--r--pkg/sync/atomicptr_unsafe.go (renamed from pkg/syncutil/atomicptr_unsafe.go)0
-rw-r--r--pkg/sync/atomicptrtest/BUILD (renamed from pkg/syncutil/atomicptrtest/BUILD)4
-rw-r--r--pkg/sync/atomicptrtest/atomicptr_test.go (renamed from pkg/syncutil/atomicptrtest/atomicptr_test.go)0
-rw-r--r--pkg/sync/downgradable_rwmutex_test.go (renamed from pkg/syncutil/downgradable_rwmutex_test.go)2
-rw-r--r--pkg/sync/downgradable_rwmutex_unsafe.go (renamed from pkg/syncutil/downgradable_rwmutex_unsafe.go)2
-rw-r--r--pkg/sync/memmove_unsafe.go (renamed from pkg/syncutil/memmove_unsafe.go)2
-rw-r--r--pkg/sync/norace_unsafe.go (renamed from pkg/syncutil/norace_unsafe.go)2
-rw-r--r--pkg/sync/race_unsafe.go (renamed from pkg/syncutil/race_unsafe.go)2
-rw-r--r--pkg/sync/seqatomic_unsafe.go (renamed from pkg/syncutil/seqatomic_unsafe.go)16
-rw-r--r--pkg/sync/seqatomictest/BUILD (renamed from pkg/syncutil/seqatomictest/BUILD)10
-rw-r--r--pkg/sync/seqatomictest/seqatomic_test.go (renamed from pkg/syncutil/seqatomictest/seqatomic_test.go)18
-rw-r--r--pkg/sync/seqcount.go (renamed from pkg/syncutil/seqcount.go)2
-rw-r--r--pkg/sync/seqcount_test.go (renamed from pkg/syncutil/seqcount_test.go)2
-rw-r--r--pkg/sync/syncutil.go (renamed from pkg/syncutil/syncutil.go)4
-rw-r--r--pkg/tcpip/BUILD9
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go2
-rw-r--r--pkg/tcpip/checker/checker.go22
-rw-r--r--pkg/tcpip/header/BUILD2
-rw-r--r--pkg/tcpip/header/ipv6.go95
-rw-r--r--pkg/tcpip/header/ipv6_test.go259
-rw-r--r--pkg/tcpip/header/ndp_router_solicit.go36
-rw-r--r--pkg/tcpip/iptables/BUILD5
-rw-r--r--pkg/tcpip/iptables/iptables.go120
-rw-r--r--pkg/tcpip/iptables/targets.go16
-rw-r--r--pkg/tcpip/iptables/types.go49
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go2
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD2
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD1
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go3
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go2
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go2
-rw-r--r--pkg/tcpip/network/arp/arp.go6
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD1
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go2
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go2
-rw-r--r--pkg/tcpip/network/ip_test.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go18
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go14
-rw-r--r--pkg/tcpip/ports/BUILD1
-rw-r--r--pkg/tcpip/ports/ports.go2
-rw-r--r--pkg/tcpip/stack/BUILD5
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go2
-rw-r--r--pkg/tcpip/stack/ndp.go743
-rw-r--r--pkg/tcpip/stack/ndp_test.go1179
-rw-r--r--pkg/tcpip/stack/nic.go320
-rw-r--r--pkg/tcpip/stack/registration.go6
-rw-r--r--pkg/tcpip/stack/route.go6
-rw-r--r--pkg/tcpip/stack/stack.go217
-rw-r--r--pkg/tcpip/stack/stack_test.go401
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go56
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go92
-rw-r--r--pkg/tcpip/stack/transport_test.go14
-rw-r--r--pkg/tcpip/tcpip.go41
-rw-r--r--pkg/tcpip/timer.go161
-rw-r--r--pkg/tcpip/timer_test.go236
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go39
-rw-r--r--pkg/tcpip/transport/packet/BUILD1
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go25
-rw-r--r--pkg/tcpip/transport/raw/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go47
-rw-r--r--pkg/tcpip/transport/tcp/BUILD16
-rw-r--r--pkg/tcpip/transport/tcp/accept.go11
-rw-r--r--pkg/tcpip/transport/tcp/connect.go340
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go224
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go501
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go32
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go3
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go13
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go27
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go2
-rw-r--r--pkg/tcpip/transport/tcp/snd.go23
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go211
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go16
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go167
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go105
-rw-r--r--pkg/tmutex/BUILD1
-rw-r--r--pkg/tmutex/tmutex_test.go3
-rw-r--r--pkg/unet/BUILD1
-rw-r--r--pkg/unet/unet_test.go3
-rw-r--r--pkg/urpc/BUILD1
-rw-r--r--pkg/urpc/urpc.go2
-rw-r--r--pkg/waiter/BUILD1
-rw-r--r--pkg/waiter/waiter.go2
-rw-r--r--runsc/boot/BUILD2
-rw-r--r--runsc/boot/compat.go2
-rw-r--r--runsc/boot/limits.go2
-rw-r--r--runsc/boot/loader.go2
-rw-r--r--runsc/boot/loader_test.go2
-rw-r--r--runsc/boot/network.go17
-rw-r--r--runsc/cmd/BUILD1
-rw-r--r--runsc/cmd/create.go1
-rw-r--r--runsc/cmd/gofer.go2
-rw-r--r--runsc/cmd/start.go1
-rw-r--r--runsc/container/BUILD2
-rw-r--r--runsc/container/console_test.go2
-rw-r--r--runsc/container/container_test.go2
-rw-r--r--runsc/container/multi_container_test.go2
-rw-r--r--runsc/container/state_file.go2
-rw-r--r--runsc/fsgofer/BUILD1
-rw-r--r--runsc/fsgofer/fsgofer.go12
-rw-r--r--runsc/sandbox/BUILD1
-rw-r--r--runsc/sandbox/network.go23
-rw-r--r--runsc/sandbox/sandbox.go2
-rw-r--r--runsc/testutil/BUILD1
-rw-r--r--runsc/testutil/testutil.go2
-rwxr-xr-xscripts/common.sh2
-rwxr-xr-xscripts/issue_reviver.sh27
-rw-r--r--test/iptables/BUILD2
-rw-r--r--test/iptables/README.md2
-rw-r--r--test/iptables/filter_input.go64
-rw-r--r--test/iptables/filter_output.go87
-rw-r--r--test/iptables/iptables_test.go38
-rw-r--r--test/iptables/iptables_util.go53
-rw-r--r--test/iptables/nat.go80
-rw-r--r--test/syscalls/BUILD5
-rw-r--r--test/syscalls/linux/BUILD2
-rw-r--r--test/syscalls/linux/inotify.cc32
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc10
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h6
-rw-r--r--test/syscalls/linux/partial_bad_buffer.cc138
-rw-r--r--test/syscalls/linux/poll.cc3
-rw-r--r--test/syscalls/linux/preadv2.cc2
-rw-r--r--test/syscalls/linux/readv_common.cc2
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc25
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc185
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc40
-rw-r--r--test/syscalls/linux/socket_ip_unbound.cc33
-rw-r--r--test/syscalls/linux/socket_non_stream.cc113
-rw-r--r--test/syscalls/linux/socket_non_stream_blocking.cc37
-rw-r--r--test/syscalls/linux/socket_stream.cc55
-rw-r--r--test/syscalls/linux/tcp_socket.cc14
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.cc8
-rw-r--r--test/syscalls/linux/xattr.cc152
-rw-r--r--tools/issue_reviver/BUILD12
-rw-r--r--tools/issue_reviver/github/BUILD17
-rw-r--r--tools/issue_reviver/github/github.go164
-rw-r--r--tools/issue_reviver/main.go89
-rw-r--r--tools/issue_reviver/reviver/BUILD19
-rw-r--r--tools/issue_reviver/reviver/reviver.go192
-rw-r--r--tools/issue_reviver/reviver/reviver_test.go88
538 files changed, 15775 insertions, 5995 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 4b5a3bfe2..e2afc073c 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -290,6 +290,27 @@ go_repository(
version = "v1.3.1",
)
+go_repository(
+ name = "com_github_google_go-github",
+ importpath = "github.com/google/go-github",
+ sum = "h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=",
+ version = "v17.0.0",
+)
+
+go_repository(
+ name = "org_golang_x_oauth2",
+ importpath = "golang.org/x/oauth2",
+ sum = "h1:pE8b58s1HRDMi8RDc79m0HISf9D4TzseP40cEA6IGfs=",
+ version = "v0.0.0-20191202225959-858c2ad4c8b6",
+)
+
+go_repository(
+ name = "com_github_google_go-querystring",
+ importpath = "github.com/google/go-querystring",
+ sum = "h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=",
+ version = "v1.0.0",
+)
+
# System Call test dependencies.
http_archive(
name = "com_google_absl",
diff --git a/benchmarks/BUILD b/benchmarks/BUILD
index dbadeeaf2..1455c6c5b 100644
--- a/benchmarks/BUILD
+++ b/benchmarks/BUILD
@@ -5,5 +5,6 @@ py_binary(
srcs = ["run.py"],
main = "run.py",
python_version = "PY3",
+ srcs_version = "PY3",
deps = ["//benchmarks/runner"],
)
diff --git a/benchmarks/README.md b/benchmarks/README.md
index ad44cd6ac..ff21614c5 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -6,66 +6,55 @@ These scripts are tools for collecting performance data for Docker-based tests.
The scripts assume the following:
-* You have a local machine with bazel installed.
-* You have some machine(s) with docker installed. These machines will be
- refered to as the "Environment".
-* Environment machines have the runtime(s) under test installed, such that you
- can run docker with a command like: `docker run --runtime=$RUNTIME
- your/image`.
-* You are able to login to machines in the environment with the local machine
- via ssh and the user for ssh can run docker commands without using `sudo`.
+* There are two sets of machines: one where the scripts will be run
+ (controller) and one or more machines on which docker containers will be run
+ (environment).
+* The controller machine must have bazel installed along with this source
+ code. You should be able to run a command like `bazel run :benchmarks --
+ --list`
+* Environment machines must have docker and the required runtimes installed.
+ More specifically, you should be able to run a command like: `docker run
+ --runtime=$RUNTIME your/image`.
+* The controller has ssh private key which can be used to login to environment
+ machines and run docker commands without using `sudo`. This is not required
+ if running locally via the `run-local` command.
* The docker daemon on each of your environment machines is listening on
`unix:///var/run/docker.sock` (docker's default).
For configuring the environment manually, consult the
[dockerd documentation][dockerd].
-## Environment
-
-All benchmarks require a user defined yaml file describe the environment. These
-files are of the form:
-
-```yaml
-machine1: local
-machine2:
- hostname: 100.100.100.100
- username: username
- key_path: ~/private_keyfile
- key_password: passphrase
-machine3:
- hostname: 100.100.100.101
- username: username
- key_path: ~/private_keyfile
- key_password: passphrase
-```
+## Running benchmarks
-The yaml file defines an environment with three machines named `machine1`,
-`machine2` and `machine3`. `machine1` is the local machine, `machine2` and
-`machine3` are remote machines. Both `machine2` and `machine3` should be
-reachable by `ssh`. For example, the command `ssh -i ~/private_keyfile
-username@100.100.100.100` (using the passphrase `passphrase`) should connect to
-`machine2`.
+Run the following from the benchmarks directory:
-The above is an example only. Machines should be uniform, since they are treated
-as such by the tests. Machines must also be accessible to each other via their
-default routes. Furthermore, some benchmarks will meaningless if running on the
-local machine, such as density.
+```bash
+bazel run :benchmarks -- run-local startup
-For remote machines, `hostname`, `key_path`, and `username` are required and
-others are optional. In addition key files must be generated
-[using the instrcutions below](#generating-ssh-keys).
+...
+method,metric,result
+startup.empty,startup_time_ms,652.5772
+startup.node,startup_time_ms,1654.4042000000002
+startup.ruby,startup_time_ms,1429.835
+```
-The above yaml file can be checked for correctness with the `validate` command
-in the top level perf.py script:
+The above command ran the startup benchmark locally, which consists of three
+benchmarks (empty, node, and ruby). Benchmark tools ran it on the default
+runtime, runc. Running on another installed runtime, like say runsc, is as
+simple as:
-`bazel run :benchmarks -- validate $PWD/examples/localhost.yaml`
+```bash
+bazel run :benchmakrs -- run-local startup --runtime=runsc
+```
-## Running benchmarks
+There is help: ``bash bash bazel run :benchmarks -- --help bazel
+run :benchmarks -- run-local --help` ``
To list available benchmarks, use the `list` commmand:
```bash
bazel run :benchmarks -- list
+ls
...
Benchmark: sysbench.cpu
@@ -75,24 +64,44 @@ Metrics: events_per_second
:param max_prime: The maximum prime number to search.
```
-To run benchmarks, use the `run` command. For example, to run the sysbench
-benchmark above:
+You can choose benchmarks by name or regex like:
```bash
-bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml sysbench.cpu
+bazel run :benchmarks -- run-local startup.node
+...
+metric,result
+startup_time_ms,1671.7178000000001
+
+```
+
+or
+
+```bash
+bazel run :benchmarks -- run-local s
+...
+method,metric,result
+startup.empty,startup_time_ms,1792.8292
+startup.node,startup_time_ms,3113.5274
+startup.ruby,startup_time_ms,3025.2424
+sysbench.cpu,cpu_events_per_second,12661.47
+sysbench.memory,memory_ops_per_second,7228268.44
+sysbench.mutex,mutex_time,17.4835
+sysbench.mutex,mutex_latency,3496.7
+sysbench.mutex,mutex_deviation,0.04
+syscall.syscall,syscall_time_ns,2065.0
```
You can run parameterized benchmarks, for example to run with different
runtimes:
```bash
-bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --runtime=runc --runtime=runsc sysbench.cpu
+bazel run :benchmarks -- run-local --runtime=runc --runtime=runsc sysbench.cpu
```
Or with different parameters:
```bash
-bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --max_prime=10 --max_prime=100 sysbench.cpu
+bazel run :benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu
```
## Writing benchmarks
@@ -121,7 +130,7 @@ The harness requires workloads to run. These are all available in the
In general, a workload consists of a Dockerfile to build it (while these are not
hermetic, in general they should be as fixed and isolated as possible), some
-parses for output if required, parser tests and sample data. Provided the test
+parsers for output if required, parser tests and sample data. Provided the test
is named after the workload package and contains a function named `sample`, this
variable will be used to automatically mock workload output when the `--mock`
flag is provided to the main tool.
@@ -149,24 +158,5 @@ To write a new benchmark, open a module in the `suites` directory and use the
above signature. You should add a descriptive doc string to describe what your
benchmark is and any test centric arguments.
-## Generating SSH Keys
-
-The scripts only support RSA Keys, and ssh library used in paramiko. Paramiko
-only supports RSA keys that look like the following (PEM format):
-
-```bash
-$ cat /path/to/ssh/key
-
------BEGIN RSA PRIVATE KEY-----
-...private key text...
------END RSA PRIVATE KEY-----
-
-```
-
-To generate ssh keys in PEM format, use the [`-t rsa -m PEM -b 4096`][RSA-keys].
-option.
-
[dockerd]: https://docs.docker.com/engine/reference/commandline/dockerd/
[docker-py]: https://docker-py.readthedocs.io/en/stable/
-[paramiko]: http://docs.paramiko.org/en/2.4/api/client.html
-[RSA-keys]: https://serverfault.com/questions/939909/ssh-keygen-does-not-create-rsa-private-key
diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD
index 9546220c4..081a74243 100644
--- a/benchmarks/harness/BUILD
+++ b/benchmarks/harness/BUILD
@@ -24,6 +24,7 @@ py_library(
name = "container",
srcs = ["container.py"],
deps = [
+ "//benchmarks/workloads",
requirement("asn1crypto", False),
requirement("chardet", False),
requirement("certifi", False),
@@ -45,6 +46,7 @@ py_library(
"//benchmarks/harness:container",
"//benchmarks/harness:ssh_connection",
"//benchmarks/harness:tunnel_dispatcher",
+ "//benchmarks/harness/machine_mocks",
requirement("asn1crypto", False),
requirement("chardet", False),
requirement("certifi", False),
@@ -53,6 +55,7 @@ py_library(
requirement("idna", False),
requirement("ptyprocess", False),
requirement("requests", False),
+ requirement("six", False),
requirement("urllib3", False),
requirement("websocket-client", False),
],
@@ -64,7 +67,7 @@ py_library(
deps = [
"//benchmarks/harness",
requirement("bcrypt", False),
- requirement("cffi", False),
+ requirement("cffi", True),
requirement("paramiko", True),
requirement("cryptography", False),
],
diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py
index a7f34da9e..61fd25f73 100644
--- a/benchmarks/harness/__init__.py
+++ b/benchmarks/harness/__init__.py
@@ -13,13 +13,20 @@
# limitations under the License.
"""Core benchmark utilities."""
+import getpass
import os
# LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a
# format string that accepts a single string parameter.
LOCAL_WORKLOADS_PATH = os.path.join(
- os.path.dirname(__file__), "../workloads/{}")
+ os.path.dirname(__file__), "../workloads/{}/tar.tar")
# REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the
# remote host. This is a format string that accepts a single string parameter.
REMOTE_WORKLOADS_PATH = "workloads/{}"
+
+# DEFAULT_USER is the default user running this script.
+DEFAULT_USER = getpass.getuser()
+
+# DEFAULT_USER_HOME is the home directory of the user running the script.
+DEFAULT_USER_HOME = os.environ["HOME"] if "HOME" in os.environ else ""
diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py
index 66b719b63..2df4c9e31 100644
--- a/benchmarks/harness/machine.py
+++ b/benchmarks/harness/machine.py
@@ -160,15 +160,17 @@ class LocalMachine(Machine):
stdout, stderr = process.communicate()
return stdout.decode("utf-8"), stderr.decode("utf-8")
- def read(self, path: str) -> str:
+ def read(self, path: str) -> bytes:
# Read the exact path locally.
return open(path, "r").read()
def pull(self, workload: str) -> str:
# Run the docker build command locally.
logging.info("Building %s@%s locally...", workload, self._name)
- self.run("docker build --tag={} {}".format(
- workload, harness.LOCAL_WORKLOADS_PATH.format(workload)))
+ with open(harness.LOCAL_WORKLOADS_PATH.format(workload),
+ "rb") as dockerfile:
+ self._docker_client.images.build(
+ fileobj=dockerfile, tag=workload, custom_context=True)
return workload # Workload is the tag.
def container(self, image: str, **kwargs) -> container.Container:
@@ -212,6 +214,9 @@ class RemoteMachine(Machine):
# Push to the remote machine and build.
logging.info("Building %s@%s remotely...", workload, self._name)
remote_path = self._ssh_connection.send_workload(workload)
+ # Workloads are all tarballs.
+ self.run("tar -xvf {remote_path}/tar.tar -C {remote_path}".format(
+ remote_path=remote_path))
self.run("docker build --tag={} {}".format(workload, remote_path))
return workload # Workload is the tag.
diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD
index a48da02a1..c4e943882 100644
--- a/benchmarks/harness/machine_producers/BUILD
+++ b/benchmarks/harness/machine_producers/BUILD
@@ -20,6 +20,7 @@ py_library(
srcs = ["mock_producer.py"],
deps = [
"//benchmarks/harness:machine",
+ "//benchmarks/harness/machine_producers:gcloud_producer",
"//benchmarks/harness/machine_producers:machine_producer",
],
)
@@ -38,3 +39,42 @@ py_library(
name = "gcloud_mock_recorder",
srcs = ["gcloud_mock_recorder.py"],
)
+
+py_library(
+ name = "gcloud_producer",
+ srcs = ["gcloud_producer.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/harness/machine_producers:gcloud_mock_recorder",
+ "//benchmarks/harness/machine_producers:machine_producer",
+ ],
+)
+
+filegroup(
+ name = "test_data",
+ srcs = [
+ "testdata/get_five.json",
+ "testdata/get_one.json",
+ ],
+)
+
+py_library(
+ name = "gcloud_producer_test_lib",
+ srcs = ["gcloud_producer_test.py"],
+ deps = [
+ "//benchmarks/harness/machine_producers:machine_producer",
+ "//benchmarks/harness/machine_producers:mock_producer",
+ ],
+)
+
+py_test(
+ name = "gcloud_producer_test",
+ srcs = [":gcloud_producer_test_lib"],
+ data = [
+ ":test_data",
+ ],
+ python_version = "PY3",
+ tags = [
+ "local",
+ ],
+)
diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py
new file mode 100644
index 000000000..e0b77d52b
--- /dev/null
+++ b/benchmarks/harness/machine_producers/gcloud_producer.py
@@ -0,0 +1,268 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A machine producer which produces machine objects using `gcloud`.
+
+Machine producers produce valid harness.Machine objects which are backed by
+real machines. This producer produces those machines on the given user's GCP
+account using the `gcloud` tool.
+
+GCloudProducer creates instances on the given GCP account named like:
+`machine-XXXXXXX-XXXX-XXXX-XXXXXXXXXXXX` in a randomized fashion such that name
+collisions with user instances shouldn't happen.
+
+ Typical usage example:
+
+ producer = GCloudProducer(args)
+ machines = producer.get_machines(NUM_MACHINES)
+ # run stuff on machines with machines[i].run(CMD)
+ producer.release_machines(NUM_MACHINES)
+"""
+import datetime
+import json
+import subprocess
+import threading
+from typing import List, Dict, Any
+import uuid
+
+from benchmarks.harness import machine
+from benchmarks.harness.machine_producers import gcloud_mock_recorder
+from benchmarks.harness.machine_producers import machine_producer
+
+
+class GCloudProducer(machine_producer.MachineProducer):
+ """Implementation of MachineProducer backed by GCP.
+
+ Produces Machine objects backed by GCP instances.
+
+ Attributes:
+ project: The GCP project name under which to create the machines.
+ ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys.
+ image: image name as a string.
+ image_project: image project as a string.
+ machine_type: type of GCP to create. e.g. n1-standard-4
+ zone: string to a valid GCP zone.
+ ssh_user: string of user name for ssh_key
+ ssh_password: string of password for ssh key
+ mock: a mock printer which will print mock data if required. Mock data is
+ recorded output from subprocess calls (returncode, stdout, args).
+ condition: mutex for this class around machine creation and deleteion.
+ """
+
+ def __init__(self,
+ project: str,
+ ssh_key_file: str,
+ image: str,
+ image_project: str,
+ machine_type: str,
+ zone: str,
+ ssh_user: str,
+ ssh_password: str,
+ mock: gcloud_mock_recorder.MockPrinter = None):
+ self.project = project
+ self.ssh_key_file = ssh_key_file
+ self.image = image
+ self.image_project = image_project
+ self.machine_type = machine_type
+ self.zone = zone
+ self.ssh_user = ssh_user
+ self.ssh_password = ssh_password
+ self.mock = mock
+ self.condition = threading.Condition()
+
+ def get_machines(self, num_machines: int) -> List[machine.Machine]:
+ """Returns requested number of machines backed by GCP instances."""
+ if num_machines <= 0:
+ raise ValueError(
+ "Cannot ask for {num} machines!".format(num=num_machines))
+ with self.condition:
+ names = self._get_unique_names(num_machines)
+ self._build_instances(names)
+ instances = self._start_command(names)
+ self._add_ssh_key_to_instances(names)
+ return self._machines_from_instances(instances)
+
+ def release_machines(self, machine_list: List[machine.Machine]):
+ """Releases the requested number of machines, deleting the instances."""
+ if not machine_list:
+ return
+ cmd = "gcloud compute instances delete --quiet".split(" ")
+ names = [str(m) for m in machine_list]
+ cmd.extend(names)
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ self._run_command(cmd, detach=True)
+
+ def _machines_from_instances(
+ self, instances: List[Dict[str, Any]]) -> List[machine.Machine]:
+ """Creates Machine Objects from json data describing created instances."""
+ machines = []
+ for instance in instances:
+ name = instance["name"]
+ kwargs = {
+ "hostname":
+ instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"],
+ "key_path":
+ self.ssh_key_file,
+ "username":
+ self.ssh_user,
+ "key_password":
+ self.ssh_password
+ }
+ machines.append(machine.RemoteMachine(name=name, **kwargs))
+ return machines
+
+ def _get_unique_names(self, num_names) -> List[str]:
+ """Returns num_names unique names based on data from the GCP project."""
+ curr_machines = self._list_machines()
+ curr_names = set([machine["name"] for machine in curr_machines])
+ ret = []
+ while len(ret) < num_names:
+ new_name = "machine-" + str(uuid.uuid4())
+ if new_name not in curr_names:
+ ret.append(new_name)
+ curr_names.update(new_name)
+ return ret
+
+ def _build_instances(self, names: List[str]) -> List[Dict[str, Any]]:
+ """Creates instances using gcloud command.
+
+ Runs the command `gcloud compute instances create` and returns json data
+ on created instances on success. Creates len(names) instances, one for each
+ name.
+
+ Args:
+ names: list of names of instances to create.
+
+ Returns:
+ List of json data describing created machines.
+ """
+ if not names:
+ raise ValueError(
+ "_build_instances cannot create instances without names.")
+ cmd = "gcloud compute instances create".split(" ")
+ cmd.extend(names)
+ cmd.extend(
+ "--preemptible --image={image} --zone={zone} --machine-type={machine_type}"
+ .format(
+ image=self.image, zone=self.zone,
+ machine_type=self.machine_type).split(" "))
+ if self.image_project:
+ cmd.append("--image-project={project}".format(project=self.image_project))
+ res = self._run_command(cmd)
+ return json.loads(res.stdout)
+
+ def _start_command(self, names):
+ """Starts instances using gcloud command.
+
+ Runs the command `gcloud compute instances start` on list of instances by
+ name and returns json data on started instances on success.
+
+ Args:
+ names: list of names of instances to start.
+
+ Returns:
+ List of json data describing started machines.
+ """
+ if not names:
+ raise ValueError("_start_command cannot start empty instance list.")
+ cmd = "gcloud compute instances start".split(" ")
+ cmd.extend(names)
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ cmd.append("--project={project}".format(project=self.project))
+ res = self._run_command(cmd)
+ return json.loads(res.stdout)
+
+ def _add_ssh_key_to_instances(self, names: List[str]) -> None:
+ """Adds ssh key to instances by calling gcloud ssh command.
+
+ Runs the command `gcloud compute ssh instance_name` on list of images by
+ name. Tries to ssh into given instance
+
+ Args:
+ names: list of machine names to which to add the ssh-key
+ self.ssh_key_file.
+
+ Raises:
+ subprocess.CalledProcessError: when underlying subprocess call returns an
+ error other than 255 (Connection closed by remote host).
+ TimeoutError: when 3 unsuccessful tries to ssh into the host return 255.
+ """
+ for name in names:
+ cmd = "gcloud compute ssh {name}".format(name=name).split(" ")
+ cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file))
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ cmd.append("--command=uname")
+ timeout = datetime.timedelta(seconds=5 * 60)
+ start = datetime.datetime.now()
+ while datetime.datetime.now() <= timeout + start:
+ try:
+ self._run_command(cmd)
+ break
+ except subprocess.CalledProcessError as e:
+ if datetime.datetime.now() > timeout + start:
+ raise TimeoutError(
+ "Could not SSH into instance after 5 min: {name}".format(
+ name=name))
+ # 255 is the returncode for ssh connection refused.
+ elif e.returncode == 255:
+
+ continue
+ else:
+ raise e
+
+ def _list_machines(self) -> List[Dict[str, Any]]:
+ """Runs `list` gcloud command and returns list of Machine data."""
+ cmd = "gcloud compute instances list --project {project}".format(
+ project=self.project).split(" ")
+ res = self._run_command(cmd)
+ return json.loads(res.stdout)
+
+ def _run_command(self,
+ cmd: List[str],
+ detach: bool = False) -> [None, subprocess.CompletedProcess]:
+ """Runs command as a subprocess.
+
+ Runs command as subprocess and returns the result.
+ If this has a mock recorder, use the record method to record the subprocess
+ call.
+
+ Args:
+ cmd: command to be run as a list of strings.
+ detach: if True, run the child process and don't wait for it to return.
+
+ Returns:
+ Completed process object to be parsed by caller or None if detach=True.
+
+ Raises:
+ CalledProcessError: if subprocess.run returns an error.
+ """
+ cmd = cmd + ["--format=json"]
+ if detach:
+ p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if self.mock:
+ out, _ = p.communicate()
+ self.mock.record(
+ subprocess.CompletedProcess(
+ returncode=p.returncode, stdout=out, args=p.args))
+ return
+
+ res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if self.mock:
+ self.mock.record(res)
+ if res.returncode != 0:
+ raise subprocess.CalledProcessError(
+ cmd=res.args,
+ output=res.stdout,
+ stderr=res.stderr,
+ returncode=res.returncode)
+ return res
diff --git a/benchmarks/harness/machine_producers/gcloud_producer_test.py b/benchmarks/harness/machine_producers/gcloud_producer_test.py
new file mode 100644
index 000000000..c8adb2bdc
--- /dev/null
+++ b/benchmarks/harness/machine_producers/gcloud_producer_test.py
@@ -0,0 +1,48 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests GCloudProducer using mock data.
+
+GCloudProducer produces machines using 'get_machines' and 'release_machines'
+methods. The tests check recorded data (jsonified subprocess.CompletedProcess
+objects) of the producer producing one and five machines.
+"""
+import os
+import types
+
+from benchmarks.harness.machine_producers import machine_producer
+from benchmarks.harness.machine_producers import mock_producer
+
+TEST_DIR = os.path.dirname(__file__)
+
+
+def run_get_release(producer: machine_producer.MachineProducer,
+ num_machines: int,
+ validator: types.FunctionType = None):
+ machines = producer.get_machines(num_machines)
+ assert len(machines) == num_machines
+ if validator:
+ validator(machines=machines, cmd="uname -a", workload=None)
+ producer.release_machines(machines)
+
+
+def test_run_one():
+ mock = mock_producer.MockReader(TEST_DIR + "get_one.json")
+ producer = mock_producer.MockGCloudProducer(mock)
+ run_get_release(producer, 1)
+
+
+def test_run_five():
+ mock = mock_producer.MockReader(TEST_DIR + "get_five.json")
+ producer = mock_producer.MockGCloudProducer(mock)
+ run_get_release(producer, 5)
diff --git a/benchmarks/harness/machine_producers/machine_producer.py b/benchmarks/harness/machine_producers/machine_producer.py
index 124ee14cc..f5591c026 100644
--- a/benchmarks/harness/machine_producers/machine_producer.py
+++ b/benchmarks/harness/machine_producers/machine_producer.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Abstract types."""
+import threading
from typing import List
from benchmarks.harness import machine
@@ -28,3 +29,23 @@ class MachineProducer:
def release_machines(self, machine_list: List[machine.Machine]):
"""Releases the given set of machines."""
raise NotImplementedError
+
+
+class LocalMachineProducer(MachineProducer):
+ """Produces Local Machines."""
+
+ def __init__(self, limit: int):
+ self.limit_sem = threading.Semaphore(value=limit)
+
+ def get_machines(self, num_machines: int) -> List[machine.Machine]:
+ """Returns the request number of MockMachines."""
+
+ self.limit_sem.acquire()
+ return [machine.LocalMachine("local") for _ in range(num_machines)]
+
+ def release_machines(self, machine_list: List[machine.MockMachine]):
+ """No-op."""
+ if not machine_list:
+ raise ValueError("Cannot release an empty list!")
+ self.limit_sem.release()
+ machine_list.clear()
diff --git a/benchmarks/harness/machine_producers/mock_producer.py b/benchmarks/harness/machine_producers/mock_producer.py
index 4f29ad53f..37e9cb4b7 100644
--- a/benchmarks/harness/machine_producers/mock_producer.py
+++ b/benchmarks/harness/machine_producers/mock_producer.py
@@ -13,9 +13,11 @@
# limitations under the License.
"""Producers of mocks."""
-from typing import List
+from typing import List, Any
from benchmarks.harness import machine
+from benchmarks.harness.machine_producers import gcloud_mock_recorder
+from benchmarks.harness.machine_producers import gcloud_producer
from benchmarks.harness.machine_producers import machine_producer
@@ -29,3 +31,22 @@ class MockMachineProducer(machine_producer.MachineProducer):
def release_machines(self, machine_list: List[machine.MockMachine]):
"""No-op."""
return
+
+
+class MockGCloudProducer(gcloud_producer.GCloudProducer):
+ """Mocks GCloudProducer for testing purposes."""
+
+ def __init__(self, mock: gcloud_mock_recorder.MockReader, **kwargs):
+ gcloud_producer.GCloudProducer.__init__(
+ self, project="mock", ssh_private_key_path="mock", **kwargs)
+ self.mock = mock
+
+ def _validate_ssh_file(self):
+ pass
+
+ def _run_command(self, cmd):
+ return self.mock.pop(cmd)
+
+ def _machines_from_instances(
+ self, instances: List[Any]) -> List[machine.MockMachine]:
+ return [machine.MockMachine() for _ in instances]
diff --git a/benchmarks/harness/machine_producers/testdata/get_five.json b/benchmarks/harness/machine_producers/testdata/get_five.json
new file mode 100644
index 000000000..32bad1b06
--- /dev/null
+++ b/benchmarks/harness/machine_producers/testdata/get_five.json
@@ -0,0 +1,211 @@
+[
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "list",
+ "--project",
+ "project",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "create",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--preemptible",
+ "--image=ubuntu-1910-eoan-v20191204",
+ "--zone=us-west1-b",
+ "--image-project=ubuntu-os-cloud",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "start",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--zone=us-west1-b",
+ "--project=project",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "delete",
+ "--quiet",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--zone=us-west1-b",
+ "--format=json"
+ ],
+ "stdout": "[]\n",
+ "returncode": "0"
+ }
+]
diff --git a/benchmarks/harness/machine_producers/testdata/get_one.json b/benchmarks/harness/machine_producers/testdata/get_one.json
new file mode 100644
index 000000000..c359c19c8
--- /dev/null
+++ b/benchmarks/harness/machine_producers/testdata/get_one.json
@@ -0,0 +1,145 @@
+[
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "list",
+ "--project",
+ "linux-testing-user",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "create",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--preemptible",
+ "--image=ubuntu-1910-eoan-v20191204",
+ "--zone=us-west1-b",
+ "--image-project=ubuntu-os-cloud",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "start",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--zone=us-west1-b",
+ "--project=linux-testing-user",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "delete",
+ "--quiet",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--zone=us-west1-b",
+ "--format=json"
+ ],
+ "stdout": "[]\n",
+ "returncode": "0"
+ }
+]
diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py
index fcbfbcdb2..e0bf258f1 100644
--- a/benchmarks/harness/ssh_connection.py
+++ b/benchmarks/harness/ssh_connection.py
@@ -94,7 +94,7 @@ class SSHConnection:
return stdout, stderr
def send_workload(self, name: str) -> str:
- """Sends a workload to the remote machine.
+ """Sends a workload tarball to the remote machine.
Args:
name: The workload name.
@@ -103,9 +103,6 @@ class SSHConnection:
The remote path.
"""
with self._client() as client:
- for dirpath, _, filenames in os.walk(
- harness.LOCAL_WORKLOADS_PATH.format(name)):
- for filename in filenames:
- send_one_file(client, os.path.join(dirpath, filename),
- harness.REMOTE_WORKLOADS_PATH.format(name))
+ send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name),
+ harness.REMOTE_WORKLOADS_PATH.format(name))
return harness.REMOTE_WORKLOADS_PATH.format(name)
diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD
index de24824cc..e1b2ea550 100644
--- a/benchmarks/runner/BUILD
+++ b/benchmarks/runner/BUILD
@@ -10,7 +10,9 @@ py_library(
],
visibility = ["//benchmarks:__pkg__"],
deps = [
+ ":commands",
"//benchmarks/harness:benchmark_driver",
+ "//benchmarks/harness/machine_producers:machine_producer",
"//benchmarks/harness/machine_producers:mock_producer",
"//benchmarks/harness/machine_producers:yaml_producer",
"//benchmarks/suites",
@@ -30,6 +32,14 @@ py_library(
],
)
+py_library(
+ name = "commands",
+ srcs = ["commands.py"],
+ deps = [
+ requirement("click", True),
+ ],
+)
+
py_test(
name = "runner_test",
srcs = ["runner_test.py"],
diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py
index 9bf9cfd65..ba80d83d7 100644
--- a/benchmarks/runner/__init__.py
+++ b/benchmarks/runner/__init__.py
@@ -15,10 +15,13 @@
import copy
import csv
+import json
import logging
+import os
import pkgutil
import pydoc
import re
+import subprocess
import sys
import types
from typing import List
@@ -26,10 +29,14 @@ from typing import Tuple
import click
+from benchmarks import harness
from benchmarks import suites
from benchmarks.harness import benchmark_driver
+from benchmarks.harness.machine_producers import gcloud_producer
+from benchmarks.harness.machine_producers import machine_producer
from benchmarks.harness.machine_producers import mock_producer
from benchmarks.harness.machine_producers import yaml_producer
+from benchmarks.runner import commands
@click.group()
@@ -100,30 +107,77 @@ def list_all(method):
print("\n")
-# pylint: disable=too-many-arguments
-# pylint: disable=too-many-branches
-# pylint: disable=too-many-locals
-@runner.command(
- context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+@runner.command("run-local", commands.LocalCommand)
@click.pass_context
-@click.argument("method")
-@click.option("--mock/--no-mock", default=False, help="Mock the machines.")
-@click.option("--env", default=None, help="Specify a yaml file with machines.")
-@click.option(
- "--runtime", default=["runc"], help="The runtime to use.", multiple=True)
-@click.option("--metric", help="The metric to extract.", multiple=True)
-@click.option(
- "--runs", default=1, help="The number of times to run each benchmark.")
-@click.option(
- "--stat",
- default="median",
- help="How to aggregate the data from all runs."
- "\nmedian - returns the median of all runs (default)"
- "\nall - returns all results comma separated"
- "\nmeanstd - returns result as mean,std")
-# pylint: disable=too-many-statements
-def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str],
- metric: List[str], stat: str, **kwargs):
+def run_local(ctx, limit: float, **kwargs):
+ """Runs benchmarks locally."""
+ run(ctx, machine_producer.LocalMachineProducer(limit=limit), **kwargs)
+
+
+@runner.command("run-mock", commands.RunCommand)
+@click.pass_context
+def run_mock(ctx, **kwargs):
+ """Runs benchmarks on Mock machines. Used for testing."""
+ run(ctx, mock_producer.MockMachineProducer(), **kwargs)
+
+
+@runner.command("run-gcp", commands.GCPCommand)
+@click.pass_context
+def run_gcp(ctx, project: str, ssh_key_file: str, image: str,
+ image_project: str, machine_type: str, zone: str, ssh_user: str,
+ ssh_password: str, **kwargs):
+ """Runs all benchmarks on GCP instances."""
+
+ if not ssh_user:
+ ssh_user = harness.DEFAULT_USER
+
+ # Get the default project if one was not provided.
+ if not project:
+ sub = subprocess.run(
+ "gcloud config get-value project".split(" "), stdout=subprocess.PIPE)
+ if sub.returncode:
+ raise ValueError(
+ "Cannot get default project from gcloud. Is it configured>")
+ project = sub.stdout.decode("utf-8").strip("\n")
+
+ if not image_project:
+ image_project = project
+
+ # Check that the ssh-key exists and is readable.
+ if not os.access(ssh_key_file, os.R_OK):
+ raise ValueError(
+ "ssh key given `{ssh_key}` is does not exist or is not readable."
+ .format(ssh_key=ssh_key_file))
+
+ # Check that the image exists.
+ sub = subprocess.run(
+ "gcloud compute images describe {image} --project {image_project} --format=json"
+ .format(image=image, image_project=image_project).split(" "),
+ stdout=subprocess.PIPE)
+ if sub.returncode or "READY" not in json.loads(sub.stdout)["status"]:
+ raise ValueError(
+ "given image was not found or is not ready: {image} {image_project}."
+ .format(image=image, image_project=image_project))
+
+ # Check and set zone to default.
+ if not zone:
+ sub = subprocess.run(
+ "gcloud config get-value compute/zone".split(" "),
+ stdout=subprocess.PIPE)
+ if sub.returncode:
+ raise ValueError(
+ "Default zone is not set in gcloud. Set one or pass a zone with the --zone flag."
+ )
+ zone = sub.stdout.decode("utf-8").strip("\n")
+
+ producer = gcloud_producer.GCloudProducer(project, ssh_key_file, image,
+ image_project, machine_type, zone,
+ ssh_user, ssh_password)
+ run(ctx, producer, **kwargs)
+
+
+def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int,
+ runtime: List[str], metric: List[str], stat: str, **kwargs):
"""Runs arbitrary benchmarks.
All unknown command line flags are passed through to the underlying benchmark
@@ -139,16 +193,13 @@ def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str],
All benchmarks are run in parallel where possible, but have exclusive
ownership over the individual machines.
- Exactly one of the --mock and --env flag must be specified.
-
Every benchmark method will be run the times indicated by --runs.
Args:
ctx: Click context.
+ producer: A Machine Producer from which to get Machines.
method: A regular expression for methods to be run.
runs: Number of runs.
- env: Environment to use.
- mock: If true, use mocked environment (supercedes env).
runtime: A list of runtimes to test.
metric: A list of metrics to extract.
stat: The class of statistics to extract.
@@ -218,20 +269,6 @@ def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str],
sys.exit(1)
fold("method", list(methods.keys()), allow_flatten=True)
- # Construct the environment.
- if mock and env:
- # You can't provide both.
- logging.error("both --mock and --env are set: which one is it?")
- sys.exit(1)
- elif mock:
- producer = mock_producer.MockMachineProducer()
- elif env:
- producer = yaml_producer.YamlMachineProducer(env)
- else:
- # You must provide one of mock or env.
- logging.error("no enviroment provided: use --mock or --env.")
- sys.exit(1)
-
# Spin up the drivers.
#
# We ensure that metric is the last entry, because we have special behavior.
diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py
new file mode 100644
index 000000000..7ab12fac6
--- /dev/null
+++ b/benchmarks/runner/commands.py
@@ -0,0 +1,135 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module with the guts of `click` commands.
+
+Overrides of the click.core.Command. This is done so flags are inherited between
+similar commands (the run command). The classes below are meant to be used in
+click templates like so.
+
+@runner.command("run-mock", RunCommand)
+def run_mock(**kwargs):
+ # mock implementation
+
+"""
+import click
+
+from benchmarks import harness
+
+
+class RunCommand(click.core.Command):
+ """Base Run Command with flags.
+
+ Attributes:
+ method: regex of which suite to choose (e.g. sysbench would run
+ sysbench.cpu, sysbench.memory, and sysbench.mutex) See list command for
+ details.
+ metric: metric(s) to extract. See list command for details.
+ runtime: the runtime(s) on which to run.
+ runs: the number of runs to do of each method.
+ stat: how to compile results in the case of multiple run (e.g. median).
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ method = click.core.Argument(("method",))
+
+ metric = click.core.Option(("--metric",),
+ help="The metric to extract.",
+ multiple=True)
+
+ runtime = click.core.Option(("--runtime",),
+ default=["runc"],
+ help="The runtime to use.",
+ multiple=True)
+ runs = click.core.Option(("--runs",),
+ default=1,
+ help="The number of times to run each benchmark.")
+ stat = click.core.Option(
+ ("--stat",),
+ default="median",
+ help="How to aggregate the data from all runs."
+ "\nmedian - returns the median of all runs (default)"
+ "\nall - returns all results comma separated"
+ "\nmeanstd - returns result as mean,std")
+ self.params.extend([method, runtime, runs, stat, metric])
+ self.ignore_unknown_options = True
+ self.allow_extra_args = True
+
+
+class LocalCommand(RunCommand):
+ """LocalCommand inherits all flags from RunCommand.
+
+ Attributes:
+ limit: limits the number of machines on which to run benchmarks. This limits
+ for local how many benchmarks may run at a time. e.g. "startup" requires
+ one machine -- passing two machines would limit two startup jobs at a
+ time. Default is infinity.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.params.append(
+ click.core.Option(
+ ("--limit",),
+ default=1,
+ help="Limit of number of benchmarks that can run at a given time."))
+
+
+class GCPCommand(RunCommand):
+ """GCPCommand inherits all flags from RunCommand and adds flags for run_gcp method.
+
+ Attributes:
+ project: GCP project
+ ssh_key_path: path to the ssh-key to use for the run
+ image: name of the image to build machines from
+ image_project: GCP project under which to find image
+ zone: a GCP zone (e.g. us-west1-b)
+ ssh_user: username to use for the ssh-key
+ ssh_password: password to use for the ssh-key
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ project = click.core.Option(
+ ("--project",),
+ help="Project to run on if not default value given by 'gcloud config get-value project'."
+ )
+ ssh_key_path = click.core.Option(
+ ("--ssh-key-file",),
+ help="Path to a valid ssh private key to use. See README on generating a valid ssh key. Set to ~/.ssh/benchmark-tools by default.",
+ default=harness.DEFAULT_USER_HOME + "/.ssh/benchmark-tools")
+ image = click.core.Option(("--image",),
+ help="The image on which to build VMs.",
+ default="bm-tools-testing")
+ image_project = click.core.Option(
+ ("--image_project",),
+ help="The project under which the image to be used is listed.",
+ default="")
+ machine_type = click.core.Option(("--machine_type",),
+ help="Type to make all machines.",
+ default="n1-standard-4")
+ zone = click.core.Option(("--zone",),
+ help="The GCP zone to run on.",
+ default="")
+ ssh_user = click.core.Option(("--ssh-user",),
+ help="User for the ssh key.",
+ default=harness.DEFAULT_USER)
+ ssh_password = click.core.Option(("--ssh-password",),
+ help="Password for the ssh key.",
+ default="")
+ self.params.extend([
+ project, ssh_key_path, image, image_project, machine_type, zone,
+ ssh_user, ssh_password
+ ])
diff --git a/benchmarks/runner/runner_test.py b/benchmarks/runner/runner_test.py
index 5719c2838..7818d631a 100644
--- a/benchmarks/runner/runner_test.py
+++ b/benchmarks/runner/runner_test.py
@@ -49,7 +49,7 @@ def test_list():
def test_run():
cli_runner = testing.CliRunner()
- result = cli_runner.invoke(runner.runner, ["run", "--mock", "."])
+ result = cli_runner.invoke(runner.runner, ["run-mock", "."])
print(result.output)
assert result.exit_code == 0
diff --git a/benchmarks/suites/http.py b/benchmarks/suites/http.py
index ea9024e43..6efea938c 100644
--- a/benchmarks/suites/http.py
+++ b/benchmarks/suites/http.py
@@ -92,7 +92,7 @@ def http_app(server: machine.Machine,
redis = server.pull("redis")
image = server.pull(workload)
redis_port = 6379
- redis_name = "redis_server"
+ redis_name = "{workload}_redis_server".format(workload=workload)
with server.container(redis, name=redis_name).detach():
server.container(server_netcat, links={redis_name: redis_name})\
diff --git a/benchmarks/tcp/tcp_benchmark.sh b/benchmarks/tcp/tcp_benchmark.sh
index 69344c9c3..e65801a7b 100755
--- a/benchmarks/tcp/tcp_benchmark.sh
+++ b/benchmarks/tcp/tcp_benchmark.sh
@@ -41,6 +41,8 @@ duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses.
duration=30 # 30s is enough time to consistent results (experimentally).
helper_dir=$(dirname $0)
netstack_opts=
+disable_linux_gso=
+num_client_threads=1
# Check for netem support.
lsmod_output=$(lsmod | grep sch_netem)
@@ -125,6 +127,13 @@ while [ $# -gt 0 ]; do
shift
netstack_opts="${netstack_opts} -memprofile=$1"
;;
+ --disable-linux-gso)
+ disable_linux_gso=1
+ ;;
+ --num-client-threads)
+ shift
+ num_client_threads=$1
+ ;;
--helpers)
shift
[ "$#" -le 0 ] && echo "no helper dir provided" && exit 1
@@ -147,6 +156,8 @@ while [ $# -gt 0 ]; do
echo " --loss set the loss probability (%)"
echo " --duplicate set the duplicate probability (%)"
echo " --helpers set the helper directory"
+ echo " --num-client-threads number of parallel client threads to run"
+ echo " --disable-linux-gso disable segmentation offload in the Linux network stack"
echo ""
echo "The output will of the script will be:"
echo " <throughput> <client-cpu-usage> <server-cpu-usage>"
@@ -301,6 +312,14 @@ fi
# Add client and server addresses, and bring everything up.
${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0
${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0
+if [ "${disable_linux_gso}" == "1" ]; then
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 tso off
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gro off
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 tso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gro off
+fi
${nsjoin_binary} /tmp/client.netns ip link set client.0 up
${nsjoin_binary} /tmp/client.netns ip link set lo up
${nsjoin_binary} /tmp/server.netns ip link set server.0 up
@@ -338,7 +357,7 @@ trap cleanup EXIT
# Run the benchmark, recording the results file.
while ${nsjoin_binary} /tmp/client.netns iperf \\
- -p ${proxy_port} -c ${client_addr} -t ${duration} -f m 2>&1 \\
+ -p ${proxy_port} -c ${client_addr} -t ${duration} -f m -P ${num_client_threads} 2>&1 \\
| tee \$results_file \\
| grep "connect failed" >/dev/null; do
sleep 0.1 # Wait for all services.
diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go
index 361a56755..72ada5700 100644
--- a/benchmarks/tcp/tcp_proxy.go
+++ b/benchmarks/tcp/tcp_proxy.go
@@ -84,8 +84,8 @@ func (netImpl) printStats() {
}
const (
- nicID = 1 // Fixed.
- rcvBufSize = 1 << 20 // 1MB.
+ nicID = 1 // Fixed.
+ bufSize = 4 << 20 // 4MB.
)
type netstackImpl struct {
@@ -94,11 +94,11 @@ type netstackImpl struct {
mode string
}
-func setupNetwork(ifaceName string) (fd int, err error) {
+func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) {
// Get all interfaces in the namespace.
ifaces, err := net.Interfaces()
if err != nil {
- return -1, fmt.Errorf("querying interfaces: %v", err)
+ return nil, fmt.Errorf("querying interfaces: %v", err)
}
for _, iface := range ifaces {
@@ -107,39 +107,47 @@ func setupNetwork(ifaceName string) (fd int, err error) {
}
// Create the socket.
const protocol = 0x0300 // htons(ETH_P_ALL)
- fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
- if err != nil {
- return -1, fmt.Errorf("unable to create raw socket: %v", err)
- }
+ fds := make([]int, numChannels)
+ for i := range fds {
+ fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create raw socket: %v", err)
+ }
- // Bind to the appropriate device.
- ll := syscall.SockaddrLinklayer{
- Protocol: protocol,
- Ifindex: iface.Index,
- Pkttype: syscall.PACKET_HOST,
- }
- if err := syscall.Bind(fd, &ll); err != nil {
- return -1, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
- }
+ // Bind to the appropriate device.
+ ll := syscall.SockaddrLinklayer{
+ Protocol: protocol,
+ Ifindex: iface.Index,
+ Pkttype: syscall.PACKET_HOST,
+ }
+ if err := syscall.Bind(fd, &ll); err != nil {
+ return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
+ }
- // RAW Sockets by default have a very small SO_RCVBUF of 256KB,
- // up it to at least 1MB to reduce packet drops.
- if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, rcvBufSize); err != nil {
- return -1, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err)
- }
+ // RAW Sockets by default have a very small SO_RCVBUF of 256KB,
+ // up it to at least 4MB to reduce packet drops.
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil {
+ return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err)
+ }
- if !*swgso && *gso != 0 {
- if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
- return -1, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil {
+ return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err)
+ }
+
+ if !*swgso && *gso != 0 {
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
+ return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
+ }
}
+ fds[i] = fd
}
- return fd, nil
+ return fds, nil
}
- return -1, fmt.Errorf("failed to find interface: %v", ifaceName)
+ return nil, fmt.Errorf("failed to find interface: %v", ifaceName)
}
func newNetstackImpl(mode string) (impl, error) {
- fd, err := setupNetwork(*iface)
+ fds, err := setupNetwork(*iface, runtime.GOMAXPROCS(-1))
if err != nil {
return nil, err
}
@@ -177,7 +185,7 @@ func newNetstackImpl(mode string) (impl, error) {
mac[0] &^= 0x1 // Clear multicast bit.
mac[0] |= 0x2 // Set local assignment bit (IEEE802).
ep, err := fdbased.New(&fdbased.Options{
- FDs: []int{fd},
+ FDs: fds,
MTU: uint32(*mtu),
EthernetHeader: true,
Address: tcpip.LinkAddress(mac),
diff --git a/benchmarks/workloads/BUILD b/benchmarks/workloads/BUILD
index 643806105..ccb86af5b 100644
--- a/benchmarks/workloads/BUILD
+++ b/benchmarks/workloads/BUILD
@@ -11,25 +11,25 @@ py_library(
filegroup(
name = "files",
srcs = [
- "//benchmarks/workloads/ab:files",
- "//benchmarks/workloads/absl:files",
- "//benchmarks/workloads/curl:files",
- "//benchmarks/workloads/ffmpeg:files",
- "//benchmarks/workloads/fio:files",
- "//benchmarks/workloads/httpd:files",
- "//benchmarks/workloads/iperf:files",
- "//benchmarks/workloads/netcat:files",
- "//benchmarks/workloads/nginx:files",
- "//benchmarks/workloads/node:files",
- "//benchmarks/workloads/node_template:files",
- "//benchmarks/workloads/redis:files",
- "//benchmarks/workloads/redisbenchmark:files",
- "//benchmarks/workloads/ruby:files",
- "//benchmarks/workloads/ruby_template:files",
- "//benchmarks/workloads/sleep:files",
- "//benchmarks/workloads/sysbench:files",
- "//benchmarks/workloads/syscall:files",
- "//benchmarks/workloads/tensorflow:files",
- "//benchmarks/workloads/true:files",
+ "//benchmarks/workloads/ab:tar",
+ "//benchmarks/workloads/absl:tar",
+ "//benchmarks/workloads/curl:tar",
+ "//benchmarks/workloads/ffmpeg:tar",
+ "//benchmarks/workloads/fio:tar",
+ "//benchmarks/workloads/httpd:tar",
+ "//benchmarks/workloads/iperf:tar",
+ "//benchmarks/workloads/netcat:tar",
+ "//benchmarks/workloads/nginx:tar",
+ "//benchmarks/workloads/node:tar",
+ "//benchmarks/workloads/node_template:tar",
+ "//benchmarks/workloads/redis:tar",
+ "//benchmarks/workloads/redisbenchmark:tar",
+ "//benchmarks/workloads/ruby:tar",
+ "//benchmarks/workloads/ruby_template:tar",
+ "//benchmarks/workloads/sleep:tar",
+ "//benchmarks/workloads/sysbench:tar",
+ "//benchmarks/workloads/syscall:tar",
+ "//benchmarks/workloads/tensorflow:tar",
+ "//benchmarks/workloads/true:tar",
],
)
diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD
index e99a8d674..4fc0ab735 100644
--- a/benchmarks/workloads/ab/BUILD
+++ b/benchmarks/workloads/ab/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD
index bb499620e..61e010096 100644
--- a/benchmarks/workloads/absl/BUILD
+++ b/benchmarks/workloads/absl/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/curl/BUILD
+++ b/benchmarks/workloads/curl/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD
index c1f2afc40..be472dfb2 100644
--- a/benchmarks/workloads/ffmpeg/BUILD
+++ b/benchmarks/workloads/ffmpeg/BUILD
@@ -1,3 +1,5 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
@@ -8,8 +10,8 @@ py_library(
srcs = ["__init__.py"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD
index 7fc96cfa5..de257adad 100644
--- a/benchmarks/workloads/fio/BUILD
+++ b/benchmarks/workloads/fio/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/httpd/BUILD
+++ b/benchmarks/workloads/httpd/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD
index fe0acbfce..8832a996c 100644
--- a/benchmarks/workloads/iperf/BUILD
+++ b/benchmarks/workloads/iperf/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/netcat/BUILD
+++ b/benchmarks/workloads/netcat/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/nginx/BUILD
+++ b/benchmarks/workloads/nginx/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD
index 59460d02f..71cd9f519 100644
--- a/benchmarks/workloads/node/BUILD
+++ b/benchmarks/workloads/node/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
"index.js",
diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD
index ae7f121d3..ca996f068 100644
--- a/benchmarks/workloads/node_template/BUILD
+++ b/benchmarks/workloads/node_template/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
"index.hbs",
diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/redis/BUILD
+++ b/benchmarks/workloads/redis/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD
index d40e75a3a..f5994a815 100644
--- a/benchmarks/workloads/redisbenchmark/BUILD
+++ b/benchmarks/workloads/redisbenchmark/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD
index 9846c7e70..e37d77804 100644
--- a/benchmarks/workloads/ruby/BUILD
+++ b/benchmarks/workloads/ruby/BUILD
@@ -1,3 +1,5 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
@@ -13,3 +15,14 @@ filegroup(
"index.rb",
],
)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ "Gemfile",
+ "Gemfile.lock",
+ "config.ru",
+ "index.rb",
+ ],
+)
diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD
index 2b99892af..27f7c0c46 100644
--- a/benchmarks/workloads/ruby_template/BUILD
+++ b/benchmarks/workloads/ruby_template/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
"Gemfile",
@@ -13,4 +15,5 @@ filegroup(
"index.erb",
"main.rb",
],
+ strip_prefix = "third_party/gvisor/benchmarks/workloads/ruby_template",
)
diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/sleep/BUILD
+++ b/benchmarks/workloads/sleep/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD
index 35f4d460b..fd2f8f03d 100644
--- a/benchmarks/workloads/sysbench/BUILD
+++ b/benchmarks/workloads/sysbench/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD
index e1ff3059b..5100cbb21 100644
--- a/benchmarks/workloads/syscall/BUILD
+++ b/benchmarks/workloads/syscall/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
"syscall.c",
diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD
index 17f1f8ebb..026c3b316 100644
--- a/benchmarks/workloads/tensorflow/BUILD
+++ b/benchmarks/workloads/tensorflow/BUILD
@@ -1,3 +1,5 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
@@ -8,8 +10,8 @@ py_library(
srcs = ["__init__.py"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD
index 83f3c71a0..221c4b9a7 100644
--- a/benchmarks/workloads/true/BUILD
+++ b/benchmarks/workloads/true/BUILD
@@ -1,11 +1,14 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
+ extension = "tar",
)
diff --git a/go.mod b/go.mod
index 304b8bf13..c4687ed02 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
github.com/golang/protobuf v1.3.1
github.com/google/btree v1.0.0
github.com/google/go-cmp v0.2.0
+ github.com/google/go-github/v28 v28.1.1
github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8
github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d
github.com/kr/pty v1.1.1
@@ -17,5 +18,6 @@ require (
github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e
github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936
golang.org/x/net v0.0.0-20190311183353-d8887717615a
+ golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a
)
diff --git a/go.sum b/go.sum
index 7a0bc175a..434770beb 100644
--- a/go.sum
+++ b/go.sum
@@ -4,6 +4,7 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
+github.com/google/go-github/v28 v28.1.1/go.mod h1:bsqJWQX05omyWVmc00nEUql9mhQyv38lDZ8kPZcQVoM=
github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
@@ -13,6 +14,7 @@ github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+S
github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
diff --git a/kokoro/issue_reviver.cfg b/kokoro/issue_reviver.cfg
new file mode 100644
index 000000000..2370d9250
--- /dev/null
+++ b/kokoro/issue_reviver.cfg
@@ -0,0 +1,15 @@
+build_file: "repo/scripts/issue_reviver.sh"
+
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id: 73898
+ keyname: "kokoro-github-access-token"
+ }
+ }
+}
+
+env_vars {
+ key: "KOKORO_GITHUB_ACCESS_TOKEN"
+ value: "73898_kokoro-github-access-token"
+}
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 9553f164d..716ff22d2 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -41,6 +41,7 @@ go_library(
"poll.go",
"prctl.go",
"ptrace.go",
+ "rseq.go",
"rusage.go",
"sched.go",
"seccomp.go",
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 269ba5567..33fcc6c95 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -42,6 +42,15 @@ const (
NF_RETURN = -NF_REPEAT - 1
)
+// VerdictStrings maps int verdicts to the strings they represent. It is used
+// for debugging.
+var VerdictStrings = map[int32]string{
+ -NF_DROP - 1: "DROP",
+ -NF_ACCEPT - 1: "ACCEPT",
+ -NF_QUEUE - 1: "QUEUE",
+ NF_RETURN: "RETURN",
+}
+
// Socket options. These correspond to values in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
const (
@@ -179,7 +188,7 @@ const SizeOfXTCounters = 16
// the user data.
type XTEntryMatch struct {
MatchSize uint16
- Name [XT_EXTENSION_MAXNAMELEN]byte
+ Name ExtensionName
Revision uint8
// Data is omitted here because it would cause XTEntryMatch to be an
// extra byte larger (see http://www.catb.org/esr/structure-packing/).
@@ -199,7 +208,7 @@ const SizeOfXTEntryMatch = 32
// the user data.
type XTEntryTarget struct {
TargetSize uint16
- Name [XT_EXTENSION_MAXNAMELEN]byte
+ Name ExtensionName
Revision uint8
// Data is omitted here because it would cause XTEntryTarget to be an
// extra byte larger (see http://www.catb.org/esr/structure-packing/).
@@ -226,9 +235,9 @@ const SizeOfXTStandardTarget = 40
// ErrorName. It corresponds to struct xt_error_target in
// include/uapi/linux/netfilter/x_tables.h.
type XTErrorTarget struct {
- Target XTEntryTarget
- ErrorName [XT_FUNCTION_MAXNAMELEN]byte
- _ [2]byte
+ Target XTEntryTarget
+ Name ErrorName
+ _ [2]byte
}
// SizeOfXTErrorTarget is the size of an XTErrorTarget.
@@ -237,7 +246,7 @@ const SizeOfXTErrorTarget = 64
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
type IPTGetinfo struct {
- Name [XT_TABLE_MAXNAMELEN]byte
+ Name TableName
ValidHooks uint32
HookEntry [NF_INET_NUMHOOKS]uint32
Underflow [NF_INET_NUMHOOKS]uint32
@@ -248,16 +257,11 @@ type IPTGetinfo struct {
// SizeOfIPTGetinfo is the size of an IPTGetinfo.
const SizeOfIPTGetinfo = 84
-// TableName returns the table name.
-func (info *IPTGetinfo) TableName() string {
- return tableName(info.Name[:])
-}
-
// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
// corresponds to struct ipt_get_entries in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
type IPTGetEntries struct {
- Name [XT_TABLE_MAXNAMELEN]byte
+ Name TableName
Size uint32
_ [4]byte
// Entrytable is omitted here because it would cause IPTGetEntries to
@@ -266,34 +270,22 @@ type IPTGetEntries struct {
// Entrytable [0]IPTEntry
}
-// TableName returns the entries' table name.
-func (entries *IPTGetEntries) TableName() string {
- return tableName(entries.Name[:])
-}
-
// SizeOfIPTGetEntries is the size of an IPTGetEntries.
const SizeOfIPTGetEntries = 40
-// KernelIPTGetEntries is identical to IPTEntry, but includes the Elems field.
-// This struct marshaled via the binary package to write an KernelIPTGetEntries
-// to userspace.
+// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
+// Entrytable field. This struct marshaled via the binary package to write an
+// KernelIPTGetEntries to userspace.
type KernelIPTGetEntries struct {
- Name [XT_TABLE_MAXNAMELEN]byte
- Size uint32
- _ [4]byte
+ IPTGetEntries
Entrytable []KernelIPTEntry
}
-// TableName returns the entries' table name.
-func (entries *KernelIPTGetEntries) TableName() string {
- return tableName(entries.Name[:])
-}
-
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
type IPTReplace struct {
- Name [XT_TABLE_MAXNAMELEN]byte
+ Name TableName
ValidHooks uint32
NumEntries uint32
Size uint32
@@ -306,14 +298,45 @@ type IPTReplace struct {
// Entries [0]IPTEntry
}
+// KernelIPTReplace is identical to IPTReplace, but includes the Entries field.
+type KernelIPTReplace struct {
+ IPTReplace
+ Entries [0]IPTEntry
+}
+
// SizeOfIPTReplace is the size of an IPTReplace.
const SizeOfIPTReplace = 96
-func tableName(name []byte) string {
- for i, c := range name {
+// ExtensionName holds the name of a netfilter extension.
+type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (en ExtensionName) String() string {
+ return goString(en[:])
+}
+
+// TableName holds the name of a netfilter table.
+type TableName [XT_TABLE_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (tn TableName) String() string {
+ return goString(tn[:])
+}
+
+// ErrorName holds the name of a netfilter error. These can also hold
+// user-defined chains.
+type ErrorName [XT_FUNCTION_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (en ErrorName) String() string {
+ return goString(en[:])
+}
+
+func goString(cstring []byte) string {
+ for i, c := range cstring {
if c == 0 {
- return string(name[:i])
+ return string(cstring[:i])
}
}
- return string(name)
+ return string(cstring)
}
diff --git a/pkg/abi/linux/rseq.go b/pkg/abi/linux/rseq.go
new file mode 100644
index 000000000..76253ba30
--- /dev/null
+++ b/pkg/abi/linux/rseq.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Flags passed to rseq(2).
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_FLAG_UNREGISTER unregisters the current thread.
+ RSEQ_FLAG_UNREGISTER = 1 << 0
+)
+
+// Critical section flags used in RSeqCriticalSection.Flags and RSeq.Flags.
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT inhibits restart on preemption.
+ RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT = 1 << 0
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL inhibits restart on signal
+ // delivery.
+ RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL = 1 << 1
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE inhibits restart on CPU
+ // migration.
+ RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE = 1 << 2
+)
+
+// RSeqCriticalSection describes a restartable sequences critical section. It
+// is equivalent to struct rseq_cs, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+//
+// +marshal
+type RSeqCriticalSection struct {
+ // Version is the version of this structure. Version 0 is defined here.
+ Version uint32
+
+ // Flags are the critical section flags, defined above.
+ Flags uint32
+
+ // Start is the start address of the critical section.
+ Start uint64
+
+ // PostCommitOffset is the offset from Start of the first instruction
+ // outside of the critical section.
+ PostCommitOffset uint64
+
+ // Abort is the abort address. It must be outside the critical section,
+ // and the 4 bytes prior must match the abort signature.
+ Abort uint64
+}
+
+const (
+ // SizeOfRSeqCriticalSection is the size of RSeqCriticalSection.
+ SizeOfRSeqCriticalSection = 32
+
+ // SizeOfRSeqSignature is the size of the signature immediately
+ // preceding RSeqCriticalSection.Abort.
+ SizeOfRSeqSignature = 4
+)
+
+// Special values for RSeq.CPUID, defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CPU_ID_UNINITIALIZED indicates that this thread has not
+ // performed rseq initialization.
+ RSEQ_CPU_ID_UNINITIALIZED = ^uint32(0) // -1
+
+ // RSEQ_CPU_ID_REGISTRATION_FAILED indicates that rseq initialization
+ // failed.
+ RSEQ_CPU_ID_REGISTRATION_FAILED = ^uint32(1) // -2
+)
+
+// RSeq is the thread-local restartable sequences config/status. It
+// is equivalent to struct rseq, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+type RSeq struct {
+ // CPUIDStart contains the current CPU ID if rseq is initialized.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUIDStart uint32
+
+ // CPUID contains the current CPU ID or one of the CPU ID special
+ // values defined above.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUID uint32
+
+ // RSeqCriticalSection is a pointer to the current RSeqCriticalSection
+ // block, or NULL. It is reset to NULL by the kernel on restart or
+ // non-restarting preempt/signal.
+ //
+ // This field should only be written by the thread which registered
+ // this structure, and must be written atomically.
+ RSeqCriticalSection uint64
+
+ // Flags are the critical section flags that apply to all critical
+ // sections on this thread, defined above.
+ Flags uint32
+}
+
+const (
+ // SizeOfRSeq is the size of RSeq.
+ //
+ // Note that RSeq is naively 24 bytes. However, it has 32-byte
+ // alignment, which in C increases sizeof to 32. That is the size that
+ // the Linux kernel uses.
+ SizeOfRSeq = 32
+
+ // AlignOfRSeq is the standard alignment of RSeq.
+ AlignOfRSeq = 32
+
+ // OffsetOfRSeqCriticalSection is the offset of RSeqCriticalSection in RSeq.
+ OffsetOfRSeqCriticalSection = 8
+)
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index 546668bca..5c5a58cd4 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -234,6 +234,19 @@ type StatxTimestamp struct {
_ int32
}
+// ToNsec returns the nanosecond representation.
+func (sxts StatxTimestamp) ToNsec() int64 {
+ return int64(sxts.Sec)*1e9 + int64(sxts.Nsec)
+}
+
+// ToNsecCapped returns the safe nanosecond representation.
+func (sxts StatxTimestamp) ToNsecCapped() int64 {
+ if sxts.Sec > maxSecInDuration {
+ return math.MaxInt64
+ }
+ return sxts.ToNsec()
+}
+
// NsecToStatxTimestamp translates nanoseconds to StatxTimestamp.
func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
return StatxTimestamp{
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index 6bc486b62..d99e37b40 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -15,4 +15,5 @@ go_test(
size = "small",
srcs = ["amutex_test.go"],
embed = [":amutex"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go
index 1d7f45641..8a3952f2a 100644
--- a/pkg/amutex/amutex_test.go
+++ b/pkg/amutex/amutex_test.go
@@ -15,9 +15,10 @@
package amutex
import (
- "sync"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
type sleeper struct {
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 36beaade9..6403c60c2 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -20,4 +20,5 @@ go_test(
size = "small",
srcs = ["atomic_bitops_test.go"],
embed = [":atomicbitops"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomic_bitops_test.go
index 965e9be79..9466d3e23 100644
--- a/pkg/atomicbitops/atomic_bitops_test.go
+++ b/pkg/atomicbitops/atomic_bitops_test.go
@@ -16,8 +16,9 @@ package atomicbitops
import (
"runtime"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
const iterations = 100
diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD
index a0b21d4bd..2bb581b18 100644
--- a/pkg/compressio/BUILD
+++ b/pkg/compressio/BUILD
@@ -8,7 +8,10 @@ go_library(
srcs = ["compressio.go"],
importpath = "gvisor.dev/gvisor/pkg/compressio",
visibility = ["//:sandbox"],
- deps = ["//pkg/binary"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/sync",
+ ],
)
go_test(
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index 3b0bb086e..5f52cbe74 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -52,9 +52,9 @@ import (
"hash"
"io"
"runtime"
- "sync"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sync"
)
var bufPool = sync.Pool{
diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD
index 21adf3adf..adbd1e3f8 100644
--- a/pkg/control/server/BUILD
+++ b/pkg/control/server/BUILD
@@ -9,6 +9,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
],
diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go
index a56152d10..41abe1f2d 100644
--- a/pkg/control/server/server.go
+++ b/pkg/control/server/server.go
@@ -22,9 +22,9 @@ package server
import (
"os"
- "sync"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/urpc"
)
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index d37047368..cf50ee53f 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -657,30 +657,28 @@ func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
return strings.Join(s, " ")
}
-// CPUInfo is to generate a section of one cpu in /proc/cpuinfo. This is a
-// minimal /proc/cpuinfo, it is missing some fields like "microcode" that are
+// WriteCPUInfoTo is to generate a section of one cpu in /proc/cpuinfo. This is
+// a minimal /proc/cpuinfo, it is missing some fields like "microcode" that are
// not always printed in Linux. The bogomips field is simply made up.
-func (fs FeatureSet) CPUInfo(cpu uint) string {
- var b bytes.Buffer
- fmt.Fprintf(&b, "processor\t: %d\n", cpu)
- fmt.Fprintf(&b, "vendor_id\t: %s\n", fs.VendorID)
- fmt.Fprintf(&b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
- fmt.Fprintf(&b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
- fmt.Fprintf(&b, "model name\t: %s\n", "unknown") // Unknown for now.
- fmt.Fprintf(&b, "stepping\t: %s\n", "unknown") // Unknown for now.
- fmt.Fprintf(&b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
- fmt.Fprintln(&b, "fpu\t\t: yes")
- fmt.Fprintln(&b, "fpu_exception\t: yes")
- fmt.Fprintf(&b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
- fmt.Fprintln(&b, "wp\t\t: yes")
- fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true))
- fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
- fmt.Fprintf(&b, "clflush size\t: %d\n", fs.CacheLine)
- fmt.Fprintf(&b, "cache_alignment\t: %d\n", fs.CacheLine)
- fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
- fmt.Fprintln(&b, "power management:") // This is always here, but can be blank.
- fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline.
- return b.String()
+func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
+ fmt.Fprintf(b, "processor\t: %d\n", cpu)
+ fmt.Fprintf(b, "vendor_id\t: %s\n", fs.VendorID)
+ fmt.Fprintf(b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
+ fmt.Fprintf(b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
+ fmt.Fprintf(b, "model name\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "stepping\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
+ fmt.Fprintln(b, "fpu\t\t: yes")
+ fmt.Fprintln(b, "fpu_exception\t: yes")
+ fmt.Fprintf(b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
+ fmt.Fprintln(b, "wp\t\t: yes")
+ fmt.Fprintf(b, "flags\t\t: %s\n", fs.FlagsString(true))
+ fmt.Fprintf(b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
+ fmt.Fprintf(b, "clflush size\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "cache_alignment\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
+ fmt.Fprintln(b, "power management:") // This is always here, but can be blank.
+ fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
}
const (
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 0b4b7cc44..9d68682c7 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -15,6 +15,7 @@ go_library(
deps = [
":eventchannel_go_proto",
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_golang_protobuf//ptypes:go_default_library_gen",
@@ -40,6 +41,7 @@ go_test(
srcs = ["event_test.go"],
embed = [":eventchannel"],
deps = [
+ "//pkg/sync",
"@com_github_golang_protobuf//proto:go_default_library",
],
)
diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go
index d37ad0428..9a29c58bd 100644
--- a/pkg/eventchannel/event.go
+++ b/pkg/eventchannel/event.go
@@ -22,13 +22,13 @@ package eventchannel
import (
"encoding/binary"
"fmt"
- "sync"
"syscall"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go
index 3649097d6..7f41b4a27 100644
--- a/pkg/eventchannel/event_test.go
+++ b/pkg/eventchannel/event_test.go
@@ -16,11 +16,11 @@ package eventchannel
import (
"fmt"
- "sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
+ "gvisor.dev/gvisor/pkg/sync"
)
// testEmitter is an emitter that can be used in tests. It records all events
diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD
index 56495cbd9..b0478c672 100644
--- a/pkg/fdchannel/BUILD
+++ b/pkg/fdchannel/BUILD
@@ -15,4 +15,5 @@ go_test(
size = "small",
srcs = ["fdchannel_test.go"],
embed = [":fdchannel"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go
index 5d01dc636..7a8a63a59 100644
--- a/pkg/fdchannel/fdchannel_test.go
+++ b/pkg/fdchannel/fdchannel_test.go
@@ -17,10 +17,11 @@ package fdchannel
import (
"io/ioutil"
"os"
- "sync"
"syscall"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestSendRecvFD(t *testing.T) {
diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD
index aca2d8a82..91a202a30 100644
--- a/pkg/fdnotifier/BUILD
+++ b/pkg/fdnotifier/BUILD
@@ -11,6 +11,7 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/fdnotifier",
visibility = ["//:sandbox"],
deps = [
+ "//pkg/sync",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go
index f4aae1953..a6b63c982 100644
--- a/pkg/fdnotifier/fdnotifier.go
+++ b/pkg/fdnotifier/fdnotifier.go
@@ -22,10 +22,10 @@ package fdnotifier
import (
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index e590a71ba..85bd83af1 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -19,7 +19,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/log",
"//pkg/memutil",
- "//pkg/syncutil",
+ "//pkg/sync",
],
)
@@ -31,4 +31,5 @@ go_test(
"flipcall_test.go",
],
embed = [":flipcall"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go
index 8d88b845d..2e28a149a 100644
--- a/pkg/flipcall/flipcall_example_test.go
+++ b/pkg/flipcall/flipcall_example_test.go
@@ -17,7 +17,8 @@ package flipcall
import (
"bytes"
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func Example() {
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
index 168a487ec..33fd55a44 100644
--- a/pkg/flipcall/flipcall_test.go
+++ b/pkg/flipcall/flipcall_test.go
@@ -16,9 +16,10 @@ package flipcall
import (
"runtime"
- "sync"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
var testPacketWindowSize = pageSize
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
index 27b8939fc..ac974b232 100644
--- a/pkg/flipcall/flipcall_unsafe.go
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -18,7 +18,7 @@ import (
"reflect"
"unsafe"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Packets consist of a 16-byte header followed by an arbitrarily-sized
@@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte {
var ioSync int64
func raceBecomeActive() {
- if syncutil.RaceEnabled {
- syncutil.RaceAcquire((unsafe.Pointer)(&ioSync))
+ if sync.RaceEnabled {
+ sync.RaceAcquire((unsafe.Pointer)(&ioSync))
}
}
func raceBecomeInactive() {
- if syncutil.RaceEnabled {
- syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
+ if sync.RaceEnabled {
+ sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
}
}
diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD
index 4b9321711..f22bd070d 100644
--- a/pkg/gate/BUILD
+++ b/pkg/gate/BUILD
@@ -19,5 +19,6 @@ go_test(
],
deps = [
":gate",
+ "//pkg/sync",
],
)
diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go
index 5dbd8d712..850693df8 100644
--- a/pkg/gate/gate_test.go
+++ b/pkg/gate/gate_test.go
@@ -15,11 +15,11 @@
package gate_test
import (
- "sync"
"testing"
"time"
"gvisor.dev/gvisor/pkg/gate"
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestBasicEnter(t *testing.T) {
diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD
new file mode 100644
index 000000000..5d31e5366
--- /dev/null
+++ b/pkg/goid/BUILD
@@ -0,0 +1,26 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "goid",
+ srcs = [
+ "goid.go",
+ "goid_amd64.s",
+ "goid_race.go",
+ "goid_unsafe.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/goid",
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "goid_test",
+ size = "small",
+ srcs = [
+ "empty_test.go",
+ "goid_test.go",
+ ],
+ embed = [":goid"],
+)
diff --git a/pkg/sentry/fsimpl/proc/proc.go b/pkg/goid/empty_test.go
index 31dec36de..c0a4b17ab 100644
--- a/pkg/sentry/fsimpl/proc/proc.go
+++ b/pkg/goid/empty_test.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,5 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package proc implements a partial in-memory file system for procfs.
-package proc
+// +build !race
+
+package goid
+
+import "testing"
+
+// TestNothing exists to make the build system happy.
+func TestNothing(t *testing.T) {}
diff --git a/pkg/sentry/socket/rpcinet/rpcinet.go b/pkg/goid/goid.go
index 5d4fd4dac..39df30031 100644
--- a/pkg/sentry/socket/rpcinet/rpcinet.go
+++ b/pkg/goid/goid.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,5 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package rpcinet implements sockets using an RPC for each syscall.
-package rpcinet
+// +build !race
+
+// Package goid provides access to the ID of the current goroutine in
+// race/gotsan builds.
+package goid
+
+// Get returns the ID of the current goroutine.
+func Get() int64 {
+ panic("unimplemented for non-race builds")
+}
diff --git a/pkg/sentry/socket/rpcinet/device.go b/pkg/goid/goid_amd64.s
index 8cfd5f6e5..d9f5cd2a3 100644
--- a/pkg/sentry/socket/rpcinet/device.go
+++ b/pkg/goid/goid_amd64.s
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package rpcinet
+#include "textflag.h"
-import "gvisor.dev/gvisor/pkg/sentry/device"
-
-var socketDevice = device.NewAnonDevice()
+// func getg() *g
+TEXT ·getg(SB),NOSPLIT,$0-8
+ MOVQ (TLS), R14
+ MOVQ R14, ret+0(FP)
+ RET
diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/goid/goid_race.go
index 0e016bca5..1766beaee 100644
--- a/pkg/sentry/fsimpl/proc/filesystems.go
+++ b/pkg/goid/goid_race.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package proc
+// Only available in race/gotsan builds.
+// +build race
-// filesystemsData implements vfs.DynamicBytesSource for /proc/filesystems.
-//
-// +stateify savable
-type filesystemsData struct{}
+// Package goid provides access to the ID of the current goroutine in
+// race/gotsan builds.
+package goid
-// TODO(gvisor.dev/issue/1195): Implement vfs.DynamicBytesSource.Generate for
-// filesystemsData. We would need to retrive filesystem names from
-// vfs.VirtualFilesystem. Also needs vfs replacement for
-// fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev.
+// Get returns the ID of the current goroutine.
+func Get() int64 {
+ return goid()
+}
diff --git a/pkg/goid/goid_test.go b/pkg/goid/goid_test.go
new file mode 100644
index 000000000..31970ce79
--- /dev/null
+++ b/pkg/goid/goid_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package goid
+
+import (
+ "runtime"
+ "sync"
+ "testing"
+)
+
+func TestInitialGoID(t *testing.T) {
+ const max = 10000
+ if id := goid(); id < 0 || id > max {
+ t.Errorf("got goid = %d, want 0 < goid <= %d", id, max)
+ }
+}
+
+// TestGoIDSquence verifies that goid returns values which could plausibly be
+// goroutine IDs. If this test breaks or becomes flaky, the structs in
+// goid_unsafe.go may need to be updated.
+func TestGoIDSquence(t *testing.T) {
+ // Goroutine IDs are cached by each P.
+ runtime.GOMAXPROCS(1)
+
+ // Fill any holes in lower range.
+ for i := 0; i < 50; i++ {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ wg.Done()
+
+ // Leak the goroutine to prevent the ID from being
+ // reused.
+ select {}
+ }()
+ wg.Wait()
+ }
+
+ id := goid()
+ for i := 0; i < 100; i++ {
+ var (
+ newID int64
+ wg sync.WaitGroup
+ )
+ wg.Add(1)
+ go func() {
+ newID = goid()
+ wg.Done()
+
+ // Leak the goroutine to prevent the ID from being
+ // reused.
+ select {}
+ }()
+ wg.Wait()
+ if max := id + 100; newID <= id || newID > max {
+ t.Errorf("unexpected goroutine ID pattern, got goid = %d, want %d < goid <= %d (previous = %d)", newID, id, max, id)
+ }
+ id = newID
+ }
+}
diff --git a/pkg/goid/goid_unsafe.go b/pkg/goid/goid_unsafe.go
new file mode 100644
index 000000000..ded8004dd
--- /dev/null
+++ b/pkg/goid/goid_unsafe.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package goid
+
+// Structs from Go runtime. These may change in the future and require
+// updating. These structs are currently the same on both AMD64 and ARM64,
+// but may diverge in the future.
+
+type stack struct {
+ lo uintptr
+ hi uintptr
+}
+
+type gobuf struct {
+ sp uintptr
+ pc uintptr
+ g uintptr
+ ctxt uintptr
+ ret uint64
+ lr uintptr
+ bp uintptr
+}
+
+type g struct {
+ stack stack
+ stackguard0 uintptr
+ stackguard1 uintptr
+
+ _panic uintptr
+ _defer uintptr
+ m uintptr
+ sched gobuf
+ syscallsp uintptr
+ syscallpc uintptr
+ stktopsp uintptr
+ param uintptr
+ atomicstatus uint32
+ stackLock uint32
+ goid int64
+
+ // More fields...
+ //
+ // We only use goid and the fields before it are only listed to
+ // calculate the correct offset.
+}
+
+func getg() *g
+
+// goid returns the ID of the current goroutine.
+func goid() int64 {
+ return getg().goid
+}
diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD
index a5d980d14..bcde6d308 100644
--- a/pkg/linewriter/BUILD
+++ b/pkg/linewriter/BUILD
@@ -8,6 +8,7 @@ go_library(
srcs = ["linewriter.go"],
importpath = "gvisor.dev/gvisor/pkg/linewriter",
visibility = ["//visibility:public"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/linewriter/linewriter.go b/pkg/linewriter/linewriter.go
index cd6e4e2ce..a1b1285d4 100644
--- a/pkg/linewriter/linewriter.go
+++ b/pkg/linewriter/linewriter.go
@@ -17,7 +17,8 @@ package linewriter
import (
"bytes"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// Writer is an io.Writer which buffers input, flushing
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
index fc5f5779b..0df0f2849 100644
--- a/pkg/log/BUILD
+++ b/pkg/log/BUILD
@@ -16,7 +16,10 @@ go_library(
visibility = [
"//visibility:public",
],
- deps = ["//pkg/linewriter"],
+ deps = [
+ "//pkg/linewriter",
+ "//pkg/sync",
+ ],
)
go_test(
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 9387586e6..91a81b288 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -25,12 +25,12 @@ import (
stdlog "log"
"os"
"runtime"
- "sync"
"sync/atomic"
"syscall"
"time"
"gvisor.dev/gvisor/pkg/linewriter"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Level is the log level.
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index dd6ca6d39..9145f3233 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -14,6 +14,7 @@ go_library(
":metric_go_proto",
"//pkg/eventchannel",
"//pkg/log",
+ "//pkg/sync",
],
)
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index eadde06e4..93d4f2b8c 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -18,12 +18,12 @@ package metric
import (
"errors"
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/eventchannel"
"gvisor.dev/gvisor/pkg/log"
pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
)
var (
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index f32244c69..a3e05c96d 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -29,6 +29,7 @@ go_library(
"//pkg/fdchannel",
"//pkg/flipcall",
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 221516c6c..4045e41fa 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -17,12 +17,12 @@ package p9
import (
"errors"
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index de9357389..0254e4ccc 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -165,6 +165,35 @@ func (c *clientFile) SetAttr(valid SetAttrMask, attr SetAttr) error {
return c.client.sendRecv(&Tsetattr{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattr{})
}
+// GetXattr implements File.GetXattr.
+func (c *clientFile) GetXattr(name string, size uint64) (string, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return "", syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return "", syscall.EOPNOTSUPP
+ }
+
+ rgetxattr := Rgetxattr{}
+ if err := c.client.sendRecv(&Tgetxattr{FID: c.fid, Name: name, Size: size}, &rgetxattr); err != nil {
+ return "", err
+ }
+
+ return rgetxattr.Value, nil
+}
+
+// SetXattr implements File.SetXattr.
+func (c *clientFile) SetXattr(name, value string, flags uint32) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tsetxattr{FID: c.fid, Name: name, Value: value, Flags: flags}, &Rsetxattr{})
+}
+
// Allocate implements File.Allocate.
func (c *clientFile) Allocate(mode AllocateMode, offset, length uint64) error {
if atomic.LoadUint32(&c.closed) != 0 {
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index 96d1f2a8e..4607cfcdf 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -89,6 +89,22 @@ type File interface {
// On the server, SetAttr has a write concurrency guarantee.
SetAttr(valid SetAttrMask, attr SetAttr) error
+ // GetXattr returns extended attributes of this node.
+ //
+ // Size indicates the size of the buffer that has been allocated to hold the
+ // attribute value. If the value is larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ //
+ // TODO(b/127675828): Determine concurrency guarantees once implemented.
+ GetXattr(name string, size uint64) (string, error)
+
+ // SetXattr sets extended attributes on this node.
+ //
+ // TODO(b/127675828): Determine concurrency guarantees once implemented.
+ SetXattr(name, value string, flags uint32) error
+
// Allocate allows the caller to directly manipulate the allocated disk space
// for the file. See fallocate(2) for more details.
Allocate(mode AllocateMode, offset, length uint64) error
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index b9582c07f..7d6653a07 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -913,6 +913,35 @@ func (t *Txattrcreate) handle(cs *connState) message {
}
// handle implements handler.handle.
+func (t *Tgetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ val, err := ref.file.GetXattr(t.Name, t.Size)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rgetxattr{Value: val}
+}
+
+// handle implements handler.handle.
+func (t *Tsetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.file.SetXattr(t.Name, t.Value, t.Flags); err != nil {
+ return newErr(err)
+ }
+ return &Rsetxattr{}
+}
+
+// handle implements handler.handle.
func (t *Treaddir) handle(cs *connState) message {
ref, ok := cs.LookupFID(t.Directory)
if !ok {
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index ffdd7e8c6..ceb723d86 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -1611,6 +1611,131 @@ func (r *Rxattrcreate) String() string {
return fmt.Sprintf("Rxattrcreate{}")
}
+// Tgetxattr is a getxattr request.
+type Tgetxattr struct {
+ // FID refers to the file for which to get xattrs.
+ FID FID
+
+ // Name is the xattr to get.
+ Name string
+
+ // Size is the buffer size for the xattr to get.
+ Size uint64
+}
+
+// Decode implements encoder.Decode.
+func (t *Tgetxattr) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Size = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tgetxattr) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.Write64(t.Size)
+}
+
+// Type implements message.Type.
+func (*Tgetxattr) Type() MsgType {
+ return MsgTgetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tgetxattr) String() string {
+ return fmt.Sprintf("Tgetxattr{FID: %d, Name: %s, Size: %d}", t.FID, t.Name, t.Size)
+}
+
+// Rgetxattr is a getxattr response.
+type Rgetxattr struct {
+ // Value is the extended attribute value.
+ Value string
+}
+
+// Decode implements encoder.Decode.
+func (r *Rgetxattr) Decode(b *buffer) {
+ r.Value = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rgetxattr) Encode(b *buffer) {
+ b.WriteString(r.Value)
+}
+
+// Type implements message.Type.
+func (*Rgetxattr) Type() MsgType {
+ return MsgRgetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rgetxattr) String() string {
+ return fmt.Sprintf("Rgetxattr{Value: %s}", r.Value)
+}
+
+// Tsetxattr sets extended attributes.
+type Tsetxattr struct {
+ // FID refers to the file on which to set xattrs.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+
+ // Value is the attribute value.
+ Value string
+
+ // Linux setxattr(2) flags.
+ Flags uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Tsetxattr) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Value = b.ReadString()
+ t.Flags = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tsetxattr) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.WriteString(t.Value)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tsetxattr) Type() MsgType {
+ return MsgTsetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tsetxattr) String() string {
+ return fmt.Sprintf("Tsetxattr{FID: %d, Name: %s, Value: %s, Flags: %d}", t.FID, t.Name, t.Value, t.Flags)
+}
+
+// Rsetxattr is a setxattr response.
+type Rsetxattr struct {
+}
+
+// Decode implements encoder.Decode.
+func (r *Rsetxattr) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (r *Rsetxattr) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rsetxattr) Type() MsgType {
+ return MsgRsetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rsetxattr) String() string {
+ return fmt.Sprintf("Rsetxattr{}")
+}
+
// Treaddir is a readdir request.
type Treaddir struct {
// Directory is the directory FID to read.
@@ -2363,6 +2488,10 @@ func init() {
msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
+ msgRegistry.register(MsgTgetxattr, func() message { return &Tgetxattr{} })
+ msgRegistry.register(MsgRgetxattr, func() message { return &Rgetxattr{} })
+ msgRegistry.register(MsgTsetxattr, func() message { return &Tsetxattr{} })
+ msgRegistry.register(MsgRsetxattr, func() message { return &Rsetxattr{} })
msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} })
msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} })
msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} })
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index 6ba6a1654..825c939da 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -194,6 +194,21 @@ func TestEncodeDecode(t *testing.T) {
Flags: 3,
},
&Rxattrcreate{},
+ &Tgetxattr{
+ FID: 1,
+ Name: "abc",
+ Size: 2,
+ },
+ &Rgetxattr{
+ Value: "xyz",
+ },
+ &Tsetxattr{
+ FID: 1,
+ Name: "abc",
+ Value: "xyz",
+ Flags: 2,
+ },
+ &Rsetxattr{},
&Treaddir{
Directory: 1,
Offset: 2,
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index d3090535a..5ab00d625 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -339,6 +339,10 @@ const (
MsgRxattrwalk = 31
MsgTxattrcreate = 32
MsgRxattrcreate = 33
+ MsgTgetxattr = 34
+ MsgRgetxattr = 35
+ MsgTsetxattr = 36
+ MsgRsetxattr = 37
MsgTreaddir = 40
MsgRreaddir = 41
MsgTfsync = 50
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 28707c0ca..f4edd68b2 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -70,6 +70,7 @@ go_library(
"//pkg/fd",
"//pkg/log",
"//pkg/p9",
+ "//pkg/sync",
"//pkg/unet",
"@com_github_golang_mock//gomock:go_default_library",
],
@@ -83,6 +84,7 @@ go_test(
deps = [
"//pkg/fd",
"//pkg/p9",
+ "//pkg/sync",
"@com_github_golang_mock//gomock:go_default_library",
],
)
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index 6e758148d..6e7bb3db2 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -22,7 +22,6 @@ import (
"os"
"reflect"
"strings"
- "sync"
"syscall"
"testing"
"time"
@@ -30,6 +29,7 @@ import (
"github.com/golang/mock/gomock"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestPanic(t *testing.T) {
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 4d3271b37..dd8b01b6d 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -17,13 +17,13 @@ package p9test
import (
"fmt"
- "sync"
"sync/atomic"
"syscall"
"testing"
"github.com/golang/mock/gomock"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go
index 865459411..72ef53313 100644
--- a/pkg/p9/path_tree.go
+++ b/pkg/p9/path_tree.go
@@ -16,7 +16,8 @@ package p9
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// pathNode is a single node in a path traversal.
diff --git a/pkg/p9/pool.go b/pkg/p9/pool.go
index 52de889e1..2b14a5ce3 100644
--- a/pkg/p9/pool.go
+++ b/pkg/p9/pool.go
@@ -15,7 +15,7 @@
package p9
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// pool is a simple allocator.
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index 40b8fa023..fdfa83648 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -17,7 +17,6 @@ package p9
import (
"io"
"runtime/debug"
- "sync"
"sync/atomic"
"syscall"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/fdchannel"
"gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 6e8b4bbcd..9c11e28ce 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -19,11 +19,11 @@ import (
"fmt"
"io"
"io/ioutil"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index 36a694c58..34a15eb55 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 9
+ highestSupportedVersion uint32 = 10
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -161,3 +161,9 @@ func versionSupportsFlipcall(v uint32) bool {
func VersionSupportsOpenTruncateFlag(v uint32) bool {
return v >= 9
}
+
+// versionSupportsGetSetXattr returns true if version v supports
+// the Tgetxattr and Tsetxattr messages.
+func versionSupportsGetSetXattr(v uint32) bool {
+ return v >= 10
+}
diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD
index 078f084b2..b506813f0 100644
--- a/pkg/procid/BUILD
+++ b/pkg/procid/BUILD
@@ -21,6 +21,7 @@ go_test(
"procid_test.go",
],
embed = [":procid"],
+ deps = ["//pkg/sync"],
)
go_test(
@@ -31,4 +32,5 @@ go_test(
"procid_test.go",
],
embed = [":procid"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go
index 88dd0b3ae..9ec08c3d6 100644
--- a/pkg/procid/procid_test.go
+++ b/pkg/procid/procid_test.go
@@ -17,9 +17,10 @@ package procid
import (
"os"
"runtime"
- "sync"
"syscall"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// runOnMain is used to send functions to run on the main (initial) thread.
diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD
index f4f2001f3..9d5b4859b 100644
--- a/pkg/rand/BUILD
+++ b/pkg/rand/BUILD
@@ -10,5 +10,8 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/rand",
visibility = ["//:sandbox"],
- deps = ["@org_golang_x_sys//unix:go_default_library"],
+ deps = [
+ "//pkg/sync",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
)
diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go
index 2b92db3e6..0bdad5fad 100644
--- a/pkg/rand/rand_linux.go
+++ b/pkg/rand/rand_linux.go
@@ -19,9 +19,9 @@ package rand
import (
"crypto/rand"
"io"
- "sync"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
)
// reader implements an io.Reader that returns pseudorandom bytes.
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 7ad59dfd7..974d9af9b 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -27,6 +27,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/sync",
],
)
@@ -35,4 +36,5 @@ go_test(
size = "small",
srcs = ["refcounter_test.go"],
embed = [":refs"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index ad69e0757..c45ba8200 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -21,10 +21,10 @@ import (
"fmt"
"reflect"
"runtime"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
)
// RefCounter is the interface to be implemented by objects that are reference
diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go
index ffd3d3f07..1ab4a4440 100644
--- a/pkg/refs/refcounter_test.go
+++ b/pkg/refs/refcounter_test.go
@@ -16,8 +16,9 @@ package refs
import (
"reflect"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
type testCounter struct {
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 18c73cc24..65f22af2b 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -9,17 +9,23 @@ go_library(
srcs = [
"aligned.go",
"arch.go",
+ "arch_aarch64.go",
"arch_amd64.go",
"arch_amd64.s",
+ "arch_arm64.go",
+ "arch_state_aarch64.go",
"arch_state_x86.go",
"arch_x86.go",
"auxv.go",
+ "signal.go",
"signal_act.go",
"signal_amd64.go",
+ "signal_arm64.go",
"signal_info.go",
"signal_stack.go",
"stack.go",
"syscalls_amd64.go",
+ "syscalls_arm64.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/arch",
visibility = ["//:sandbox"],
@@ -32,6 +38,7 @@ go_library(
"//pkg/sentry/context",
"//pkg/sentry/limits",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 498ca4669..81ec98a77 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -125,9 +125,9 @@ type Context interface {
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
SetTLS(value uintptr) bool
- // SetRSEQInterruptedIP sets the register that contains the old IP when a
- // restartable sequence is interrupted.
- SetRSEQInterruptedIP(value uintptr)
+ // SetOldRSeqInterruptedIP sets the register that contains the old IP
+ // when an "old rseq" restartable sequence is interrupted.
+ SetOldRSeqInterruptedIP(value uintptr)
// StateData returns a pointer to underlying architecture state.
StateData() *State
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
new file mode 100644
index 000000000..ea4dedbdf
--- /dev/null
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -0,0 +1,293 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+import (
+ "fmt"
+ "io"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/log"
+ rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ // SyscallWidth is the width of insturctions.
+ SyscallWidth = 4
+)
+
+// aarch64FPState is aarch64 floating point state.
+type aarch64FPState []byte
+
+// initAarch64FPState (defined in asm files) sets up initial state.
+func initAarch64FPState(data *FloatingPointData) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+}
+
+func newAarch64FPStateSlice() []byte {
+ return alignedBytes(4096, 32)[:4096]
+}
+
+// newAarch64FPState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by host, even if the app won't use much of it due to a restricted
+// FeatureSet. Since they may still be able to see state not advertised by
+// CPUID we must ensure it does not contain any sentry state.
+func newAarch64FPState() aarch64FPState {
+ f := aarch64FPState(newAarch64FPStateSlice())
+ initAarch64FPState(f.FloatingPointData())
+ return f
+}
+
+// fork creates and returns an identical copy of the aarch64 floating point state.
+func (f aarch64FPState) fork() aarch64FPState {
+ n := aarch64FPState(newAarch64FPStateSlice())
+ copy(n, f)
+ return n
+}
+
+// FloatingPointData returns the raw data pointer.
+func (f aarch64FPState) FloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&f[0])
+}
+
+// NewFloatingPointData returns a new floating point data blob.
+//
+// This is primarily for use in tests.
+func NewFloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&(newAarch64FPState()[0]))
+}
+
+// State contains the common architecture bits for aarch64 (the build tag of this
+// file ensures it's only built on aarch64).
+type State struct {
+ // The system registers.
+ Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"`
+
+ // Our floating point state.
+ aarch64FPState `state:"wait"`
+
+ // FeatureSet is a pointer to the currently active feature set.
+ FeatureSet *cpuid.FeatureSet
+}
+
+// Proto returns a protobuf representation of the system registers in State.
+func (s State) Proto() *rpb.Registers {
+ regs := &rpb.ARM64Registers{
+ R0: s.Regs.Regs[0],
+ R1: s.Regs.Regs[1],
+ R2: s.Regs.Regs[2],
+ R3: s.Regs.Regs[3],
+ R4: s.Regs.Regs[4],
+ R5: s.Regs.Regs[5],
+ R6: s.Regs.Regs[6],
+ R7: s.Regs.Regs[7],
+ R8: s.Regs.Regs[8],
+ R9: s.Regs.Regs[9],
+ R10: s.Regs.Regs[10],
+ R11: s.Regs.Regs[11],
+ R12: s.Regs.Regs[12],
+ R13: s.Regs.Regs[13],
+ R14: s.Regs.Regs[14],
+ R15: s.Regs.Regs[15],
+ R16: s.Regs.Regs[16],
+ R17: s.Regs.Regs[17],
+ R18: s.Regs.Regs[18],
+ R19: s.Regs.Regs[19],
+ R20: s.Regs.Regs[20],
+ R21: s.Regs.Regs[21],
+ R22: s.Regs.Regs[22],
+ R23: s.Regs.Regs[23],
+ R24: s.Regs.Regs[24],
+ R25: s.Regs.Regs[25],
+ R26: s.Regs.Regs[26],
+ R27: s.Regs.Regs[27],
+ R28: s.Regs.Regs[28],
+ R29: s.Regs.Regs[29],
+ R30: s.Regs.Regs[30],
+ Sp: s.Regs.Sp,
+ Pc: s.Regs.Pc,
+ Pstate: s.Regs.Pstate,
+ }
+ return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}}
+}
+
+// Fork creates and returns an identical copy of the state.
+func (s *State) Fork() State {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return State{
+ Regs: s.Regs,
+ FeatureSet: s.FeatureSet,
+ }
+}
+
+// StateData implements Context.StateData.
+func (s *State) StateData() *State {
+ return s
+}
+
+// CPUIDEmulate emulates a cpuid instruction.
+func (s *State) CPUIDEmulate(l log.Logger) {
+ // TODO(gvisor.dev/issue/1255): cpuid is not supported.
+}
+
+// SingleStep implements Context.SingleStep.
+func (s *State) SingleStep() bool {
+ return false
+}
+
+// SetSingleStep enables single stepping.
+func (s *State) SetSingleStep() {
+ // Set the trap flag.
+ // TODO(gvisor.dev/issue/1239): ptrace single-step is not supported.
+}
+
+// ClearSingleStep enables single stepping.
+func (s *State) ClearSingleStep() {
+ // Clear the trap flag.
+ // TODO(gvisor.dev/issue/1239): ptrace single-step is not supported.
+}
+
+// RegisterMap returns a map of all registers.
+func (s *State) RegisterMap() (map[string]uintptr, error) {
+ return map[string]uintptr{
+ "R0": uintptr(s.Regs.Regs[0]),
+ "R1": uintptr(s.Regs.Regs[1]),
+ "R2": uintptr(s.Regs.Regs[2]),
+ "R3": uintptr(s.Regs.Regs[3]),
+ "R4": uintptr(s.Regs.Regs[4]),
+ "R5": uintptr(s.Regs.Regs[5]),
+ "R6": uintptr(s.Regs.Regs[6]),
+ "R7": uintptr(s.Regs.Regs[7]),
+ "R8": uintptr(s.Regs.Regs[8]),
+ "R9": uintptr(s.Regs.Regs[9]),
+ "R10": uintptr(s.Regs.Regs[10]),
+ "R11": uintptr(s.Regs.Regs[11]),
+ "R12": uintptr(s.Regs.Regs[12]),
+ "R13": uintptr(s.Regs.Regs[13]),
+ "R14": uintptr(s.Regs.Regs[14]),
+ "R15": uintptr(s.Regs.Regs[15]),
+ "R16": uintptr(s.Regs.Regs[16]),
+ "R17": uintptr(s.Regs.Regs[17]),
+ "R18": uintptr(s.Regs.Regs[18]),
+ "R19": uintptr(s.Regs.Regs[19]),
+ "R20": uintptr(s.Regs.Regs[20]),
+ "R21": uintptr(s.Regs.Regs[21]),
+ "R22": uintptr(s.Regs.Regs[22]),
+ "R23": uintptr(s.Regs.Regs[23]),
+ "R24": uintptr(s.Regs.Regs[24]),
+ "R25": uintptr(s.Regs.Regs[25]),
+ "R26": uintptr(s.Regs.Regs[26]),
+ "R27": uintptr(s.Regs.Regs[27]),
+ "R28": uintptr(s.Regs.Regs[28]),
+ "R29": uintptr(s.Regs.Regs[29]),
+ "R30": uintptr(s.Regs.Regs[30]),
+ "Sp": uintptr(s.Regs.Sp),
+ "Pc": uintptr(s.Regs.Pc),
+ "Pstate": uintptr(s.Regs.Pstate),
+ }, nil
+}
+
+// PtraceGetRegs implements Context.PtraceGetRegs.
+func (s *State) PtraceGetRegs(dst io.Writer) (int, error) {
+ return dst.Write(binary.Marshal(nil, usermem.ByteOrder, s.ptraceGetRegs()))
+}
+
+func (s *State) ptraceGetRegs() syscall.PtraceRegs {
+ return s.Regs
+}
+
+var ptraceRegsSize = int(binary.Size(syscall.PtraceRegs{}))
+
+// PtraceSetRegs implements Context.PtraceSetRegs.
+func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
+ var regs syscall.PtraceRegs
+ buf := make([]byte, ptraceRegsSize)
+ if _, err := io.ReadFull(src, buf); err != nil {
+ return 0, err
+ }
+ binary.Unmarshal(buf, usermem.ByteOrder, &regs)
+ s.Regs = regs
+ return ptraceRegsSize, nil
+}
+
+// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
+func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return 0, nil
+}
+
+// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
+func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return 0, nil
+}
+
+// Register sets defined in include/uapi/linux/elf.h.
+const (
+ _NT_PRSTATUS = 1
+ _NT_PRFPREG = 2
+)
+
+// PtraceGetRegSet implements Context.PtraceGetRegSet.
+func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < ptraceRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceGetRegs(dst)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// PtraceSetRegSet implements Context.PtraceSetRegSet.
+func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < ptraceRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceSetRegs(src)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// FullRestore indicates whether a full restore is required.
+func (s *State) FullRestore() bool {
+ return false
+}
+
+// New returns a new architecture context.
+func New(arch Arch, fs *cpuid.FeatureSet) Context {
+ switch arch {
+ case ARM64:
+ return &context64{
+ State{
+ FeatureSet: fs,
+ },
+ }
+ }
+ panic(fmt.Sprintf("unknown architecture %v", arch))
+}
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 67daa6c24..2aa08b1a9 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -174,8 +174,8 @@ func (c *context64) SetTLS(value uintptr) bool {
return true
}
-// SetRSEQInterruptedIP implements Context.SetRSEQInterruptedIP.
-func (c *context64) SetRSEQInterruptedIP(value uintptr) {
+// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP.
+func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
c.Regs.R10 = uint64(value)
}
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
new file mode 100644
index 000000000..0d5b7d317
--- /dev/null
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -0,0 +1,266 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "fmt"
+ "math/rand"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+// Host specifies the host architecture.
+const Host = ARM64
+
+// These constants come directly from Linux.
+const (
+ // maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux
+ // for a 64-bit process.
+ maxAddr64 usermem.Addr = (1 << 48)
+
+ // maxStackRand64 is the maximum randomization to apply to the stack.
+ // It is defined by arch/arm64/mm/mmap.c:(STACK_RND_MASK << PAGE_SHIFT) in Linux.
+ maxStackRand64 = 0x3ffff << 12 // 16 GB
+
+ // maxMmapRand64 is the maximum randomization to apply to the mmap
+ // layout. It is defined by arch/arm64/mm/mmap.c:arch_mmap_rnd in Linux.
+ maxMmapRand64 = (1 << 33) * usermem.PageSize
+
+ // minGap64 is the minimum gap to leave at the top of the address space
+ // for the stack. It is defined by arch/arm64/mm/mmap.c:MIN_GAP in Linux.
+ minGap64 = (128 << 20) + maxStackRand64
+
+ // preferredPIELoadAddr is the standard Linux position-independent
+ // executable base load address. It is ELF_ET_DYN_BASE in Linux.
+ //
+ // The Platform {Min,Max}UserAddress() may preclude loading at this
+ // address. See other preferredFoo comments below.
+ preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5
+)
+
+// These constants are selected as heuristics to help make the Platform's
+// potentially limited address space conform as closely to Linux as possible.
+const (
+ preferredTopDownAllocMin usermem.Addr = 0x7e8000000000
+ preferredAllocationGap = 128 << 30 // 128 GB
+ preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
+
+ // minMmapRand64 is the smallest we are willing to make the
+ // randomization to stay above preferredTopDownBaseMin.
+ minMmapRand64 = (1 << 18) * usermem.PageSize
+)
+
+// context64 represents an ARM64 context.
+type context64 struct {
+ State
+}
+
+// Arch implements Context.Arch.
+func (c *context64) Arch() Arch {
+ return ARM64
+}
+
+// Fork returns an exact copy of this context.
+func (c *context64) Fork() Context {
+ return &context64{
+ State: c.State.Fork(),
+ }
+}
+
+// General purpose registers usage on Arm64:
+// R0...R7: parameter/result registers.
+// R8: indirect result location register.
+// R9...R15: temporary rgisters.
+// R16: the first intra-procedure-call scratch register.
+// R17: the second intra-procedure-call scratch register.
+// R18: the platform register.
+// R19...R28: callee-saved registers.
+// R29: the frame pointer.
+// R30: the link register.
+
+// Return returns the current syscall return value.
+func (c *context64) Return() uintptr {
+ return uintptr(c.Regs.Regs[0])
+}
+
+// SetReturn sets the syscall return value.
+func (c *context64) SetReturn(value uintptr) {
+ c.Regs.Regs[0] = uint64(value)
+}
+
+// IP returns the current instruction pointer.
+func (c *context64) IP() uintptr {
+ return uintptr(c.Regs.Pc)
+}
+
+// SetIP sets the current instruction pointer.
+func (c *context64) SetIP(value uintptr) {
+ c.Regs.Pc = uint64(value)
+}
+
+// Stack returns the current stack pointer.
+func (c *context64) Stack() uintptr {
+ return uintptr(c.Regs.Sp)
+}
+
+// SetStack sets the current stack pointer.
+func (c *context64) SetStack(value uintptr) {
+ c.Regs.Sp = uint64(value)
+}
+
+// TLS returns the current TLS pointer.
+func (c *context64) TLS() uintptr {
+ // TODO(gvisor.dev/issue/1238): TLS is not supported.
+ // MRS_TPIDR_EL0
+ return 0
+}
+
+// SetTLS sets the current TLS pointer. Returns false if value is invalid.
+func (c *context64) SetTLS(value uintptr) bool {
+ // TODO(gvisor.dev/issue/1238): TLS is not supported.
+ // MSR_TPIDR_EL0
+ return false
+}
+
+// SetRSEQInterruptedIP implements Context.SetRSEQInterruptedIP.
+func (c *context64) SetRSEQInterruptedIP(value uintptr) {
+ c.Regs.Regs[3] = uint64(value)
+}
+
+// Native returns the native type for the given val.
+func (c *context64) Native(val uintptr) interface{} {
+ v := uint64(val)
+ return &v
+}
+
+// Value returns the generic val for the given native type.
+func (c *context64) Value(val interface{}) uintptr {
+ return uintptr(*val.(*uint64))
+}
+
+// Width returns the byte width of this architecture.
+func (c *context64) Width() uint {
+ return 8
+}
+
+// FeatureSet returns the FeatureSet in use.
+func (c *context64) FeatureSet() *cpuid.FeatureSet {
+ return c.State.FeatureSet
+}
+
+// mmapRand returns a random adjustment for randomizing an mmap layout.
+func mmapRand(max uint64) usermem.Addr {
+ return usermem.Addr(rand.Int63n(int64(max))).RoundDown()
+}
+
+// NewMmapLayout implements Context.NewMmapLayout consistently with Linux.
+func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) {
+ min, ok := min.RoundUp()
+ if !ok {
+ return MmapLayout{}, syscall.EINVAL
+ }
+ if max > maxAddr64 {
+ max = maxAddr64
+ }
+ max = max.RoundDown()
+
+ if min > max {
+ return MmapLayout{}, syscall.EINVAL
+ }
+
+ stackSize := r.Get(limits.Stack)
+
+ // MAX_GAP in Linux.
+ maxGap := (max / 6) * 5
+ gap := usermem.Addr(stackSize.Cur)
+ if gap < minGap64 {
+ gap = minGap64
+ }
+ if gap > maxGap {
+ gap = maxGap
+ }
+ defaultDir := MmapTopDown
+ if stackSize.Cur == limits.Infinity {
+ defaultDir = MmapBottomUp
+ }
+
+ topDownMin := max - gap - maxMmapRand64
+ maxRand := usermem.Addr(maxMmapRand64)
+ if topDownMin < preferredTopDownBaseMin {
+ // Try to keep TopDownBase above preferredTopDownBaseMin by
+ // shrinking maxRand.
+ maxAdjust := maxRand - minMmapRand64
+ needAdjust := preferredTopDownBaseMin - topDownMin
+ if needAdjust <= maxAdjust {
+ maxRand -= needAdjust
+ }
+ }
+
+ rnd := mmapRand(uint64(maxRand))
+ l := MmapLayout{
+ MinAddr: min,
+ MaxAddr: max,
+ // TASK_UNMAPPED_BASE in Linux.
+ BottomUpBase: (max/3 + rnd).RoundDown(),
+ TopDownBase: (max - gap - rnd).RoundDown(),
+ DefaultDirection: defaultDir,
+ // We may have reduced the maximum randomization to keep
+ // TopDownBase above preferredTopDownBaseMin while maintaining
+ // our stack gap. Stack allocations must use that max
+ // randomization to avoiding eating into the gap.
+ MaxStackRand: uint64(maxRand),
+ }
+
+ // Final sanity check on the layout.
+ if !l.Valid() {
+ panic(fmt.Sprintf("Invalid MmapLayout: %+v", l))
+ }
+
+ return l, nil
+}
+
+// PIELoadAddress implements Context.PIELoadAddress.
+func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
+ base := preferredPIELoadAddr
+ max, ok := base.AddLength(maxMmapRand64)
+ if !ok {
+ panic(fmt.Sprintf("preferredPIELoadAddr %#x too large", base))
+ }
+
+ if max > l.MaxAddr {
+ // preferredPIELoadAddr won't fit; fall back to the standard
+ // Linux behavior of 2/3 of TopDownBase. TSAN won't like this.
+ //
+ // Don't bother trying to shrink the randomization for now.
+ base = l.TopDownBase / 3 * 2
+ }
+
+ return base + mmapRand(maxMmapRand64)
+}
+
+// PtracePeekUser implements Context.PtracePeekUser.
+func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+ // TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64.
+ return c.Native(0), nil
+}
+
+// PtracePokeUser implements Context.PtracePokeUser.
+func (c *context64) PtracePokeUser(addr, data uintptr) error {
+ // TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64.
+ return nil
+}
diff --git a/pkg/sentry/arch/arch_state_aarch64.go b/pkg/sentry/arch/arch_state_aarch64.go
new file mode 100644
index 000000000..0136a85ad
--- /dev/null
+++ b/pkg/sentry/arch/arch_state_aarch64.go
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+import (
+ "syscall"
+)
+
+type syscallPtraceRegs struct {
+ Regs [31]uint64
+ Sp uint64
+ Pc uint64
+ Pstate uint64
+}
+
+// saveRegs is invoked by stateify.
+func (s *State) saveRegs() syscallPtraceRegs {
+ return syscallPtraceRegs(s.Regs)
+}
+
+// loadRegs is invoked by stateify.
+func (s *State) loadRegs(r syscallPtraceRegs) {
+ s.Regs = syscall.PtraceRegs(r)
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index 9061fcc86..84f11b0d1 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64 i386
+
package arch
import (
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index 9294ac773..9f41e566f 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -19,7 +19,6 @@ package arch
import (
"fmt"
"io"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/binary"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto
index 9dc83e241..60c027aab 100644
--- a/pkg/sentry/arch/registers.proto
+++ b/pkg/sentry/arch/registers.proto
@@ -48,8 +48,45 @@ message AMD64Registers {
uint64 gs_base = 27;
}
+message ARM64Registers {
+ uint64 r0 = 1;
+ uint64 r1 = 2;
+ uint64 r2 = 3;
+ uint64 r3 = 4;
+ uint64 r4 = 5;
+ uint64 r5 = 6;
+ uint64 r6 = 7;
+ uint64 r7 = 8;
+ uint64 r8 = 9;
+ uint64 r9 = 10;
+ uint64 r10 = 11;
+ uint64 r11 = 12;
+ uint64 r12 = 13;
+ uint64 r13 = 14;
+ uint64 r14 = 15;
+ uint64 r15 = 16;
+ uint64 r16 = 17;
+ uint64 r17 = 18;
+ uint64 r18 = 19;
+ uint64 r19 = 20;
+ uint64 r20 = 21;
+ uint64 r21 = 22;
+ uint64 r22 = 23;
+ uint64 r23 = 24;
+ uint64 r24 = 25;
+ uint64 r25 = 26;
+ uint64 r26 = 27;
+ uint64 r27 = 28;
+ uint64 r28 = 29;
+ uint64 r29 = 30;
+ uint64 r30 = 31;
+ uint64 sp = 32;
+ uint64 pc = 33;
+ uint64 pstate = 34;
+}
message Registers {
oneof arch {
AMD64Registers amd64 = 1;
+ ARM64Registers arm64 = 2;
}
}
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
new file mode 100644
index 000000000..402e46025
--- /dev/null
+++ b/pkg/sentry/arch/signal.go
@@ -0,0 +1,250 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+// SignalAct represents the action that should be taken when a signal is
+// delivered, and is equivalent to struct sigaction.
+//
+// +stateify savable
+type SignalAct struct {
+ Handler uint64
+ Flags uint64
+ Restorer uint64 // Only used on amd64.
+ Mask linux.SignalSet
+}
+
+// SerializeFrom implements NativeSignalAct.SerializeFrom.
+func (s *SignalAct) SerializeFrom(other *SignalAct) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalAct.DeserializeTo.
+func (s *SignalAct) DeserializeTo(other *SignalAct) {
+ *other = *s
+}
+
+// SignalStack represents information about a user stack, and is equivalent to
+// stack_t.
+//
+// +stateify savable
+type SignalStack struct {
+ Addr uint64
+ Flags uint32
+ _ uint32
+ Size uint64
+}
+
+// SerializeFrom implements NativeSignalStack.SerializeFrom.
+func (s *SignalStack) SerializeFrom(other *SignalStack) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalStack.DeserializeTo.
+func (s *SignalStack) DeserializeTo(other *SignalStack) {
+ *other = *s
+}
+
+// SignalInfo represents information about a signal being delivered, and is
+// equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h).
+//
+// +stateify savable
+type SignalInfo struct {
+ Signo int32 // Signal number
+ Errno int32 // Errno value
+ Code int32 // Signal code
+ _ uint32
+
+ // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
+ // are accessed through methods.
+ //
+ // For reference, here is the definition of _sifields: (_sigfault._trapno,
+ // which does not exist on x86, omitted for clarity)
+ //
+ // union {
+ // int _pad[SI_PAD_SIZE];
+ //
+ // /* kill() */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // } _kill;
+ //
+ // /* POSIX.1b timers */
+ // struct {
+ // __kernel_timer_t _tid; /* timer id */
+ // int _overrun; /* overrun count */
+ // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
+ // sigval_t _sigval; /* same as below */
+ // int _sys_private; /* not to be passed to user */
+ // } _timer;
+ //
+ // /* POSIX.1b signals */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // sigval_t _sigval;
+ // } _rt;
+ //
+ // /* SIGCHLD */
+ // struct {
+ // __kernel_pid_t _pid; /* which child */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // int _status; /* exit code */
+ // __ARCH_SI_CLOCK_T _utime;
+ // __ARCH_SI_CLOCK_T _stime;
+ // } _sigchld;
+ //
+ // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
+ // struct {
+ // void *_addr; /* faulting insn/memory ref. */
+ // short _addr_lsb; /* LSB of the reported address */
+ // } _sigfault;
+ //
+ // /* SIGPOLL */
+ // struct {
+ // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
+ // int _fd;
+ // } _sigpoll;
+ //
+ // /* SIGSYS */
+ // struct {
+ // void *_call_addr; /* calling user insn */
+ // int _syscall; /* triggering system call number */
+ // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
+ // } _sigsys;
+ // } _sifields;
+ //
+ // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
+ // bytes.
+ Fields [128 - 16]byte
+}
+
+// FixSignalCodeForUser fixes up si_code.
+//
+// The si_code we get from Linux may contain the kernel-specific code in the
+// top 16 bits if it's positive (e.g., from ptrace). Linux's
+// copy_siginfo_to_user does
+// err |= __put_user((short)from->si_code, &to->si_code);
+// to mask out those bits and we need to do the same.
+func (s *SignalInfo) FixSignalCodeForUser() {
+ if s.Code > 0 {
+ s.Code &= 0x0000ffff
+ }
+}
+
+// Pid returns the si_pid field.
+func (s *SignalInfo) Pid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetPid mutates the si_pid field.
+func (s *SignalInfo) SetPid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Uid returns the si_uid field.
+func (s *SignalInfo) Uid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetUid mutates the si_uid field.
+func (s *SignalInfo) SetUid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
+func (s *SignalInfo) Sigval() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[8:16])
+}
+
+// SetSigval mutates the sigval field.
+func (s *SignalInfo) SetSigval(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[8:16], val)
+}
+
+// TimerID returns the si_timerid field.
+func (s *SignalInfo) TimerID() linux.TimerID {
+ return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetTimerID sets the si_timerid field.
+func (s *SignalInfo) SetTimerID(val linux.TimerID) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Overrun returns the si_overrun field.
+func (s *SignalInfo) Overrun() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetOverrun sets the si_overrun field.
+func (s *SignalInfo) SetOverrun(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Addr returns the si_addr field.
+func (s *SignalInfo) Addr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetAddr sets the si_addr field.
+func (s *SignalInfo) SetAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Status returns the si_status field.
+func (s *SignalInfo) Status() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetStatus mutates the si_status field.
+func (s *SignalInfo) SetStatus(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// CallAddr returns the si_call_addr field.
+func (s *SignalInfo) CallAddr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetCallAddr mutates the si_call_addr field.
+func (s *SignalInfo) SetCallAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Syscall returns the si_syscall field.
+func (s *SignalInfo) Syscall() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetSyscall mutates the si_syscall field.
+func (s *SignalInfo) SetSyscall(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// Arch returns the si_arch field.
+func (s *SignalInfo) Arch() uint32 {
+ return usermem.ByteOrder.Uint32(s.Fields[12:16])
+}
+
+// SetArch mutates the si_arch field.
+func (s *SignalInfo) SetArch(val uint32) {
+ usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
+}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index febd6f9b9..1e4f9c3c2 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -26,236 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
-// SignalAct represents the action that should be taken when a signal is
-// delivered, and is equivalent to struct sigaction on 64-bit x86.
-//
-// +stateify savable
-type SignalAct struct {
- Handler uint64
- Flags uint64
- Restorer uint64
- Mask linux.SignalSet
-}
-
-// SerializeFrom implements NativeSignalAct.SerializeFrom.
-func (s *SignalAct) SerializeFrom(other *SignalAct) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalAct.DeserializeTo.
-func (s *SignalAct) DeserializeTo(other *SignalAct) {
- *other = *s
-}
-
-// SignalStack represents information about a user stack, and is equivalent to
-// stack_t on 64-bit x86.
-//
-// +stateify savable
-type SignalStack struct {
- Addr uint64
- Flags uint32
- _ uint32
- Size uint64
-}
-
-// SerializeFrom implements NativeSignalStack.SerializeFrom.
-func (s *SignalStack) SerializeFrom(other *SignalStack) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalStack.DeserializeTo.
-func (s *SignalStack) DeserializeTo(other *SignalStack) {
- *other = *s
-}
-
-// SignalInfo represents information about a signal being delivered, and is
-// equivalent to struct siginfo on 64-bit x86.
-//
-// +stateify savable
-type SignalInfo struct {
- Signo int32 // Signal number
- Errno int32 // Errno value
- Code int32 // Signal code
- _ uint32
-
- // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
- // are accessed through methods.
- //
- // For reference, here is the definition of _sifields: (_sigfault._trapno,
- // which does not exist on x86, omitted for clarity)
- //
- // union {
- // int _pad[SI_PAD_SIZE];
- //
- // /* kill() */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // } _kill;
- //
- // /* POSIX.1b timers */
- // struct {
- // __kernel_timer_t _tid; /* timer id */
- // int _overrun; /* overrun count */
- // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
- // sigval_t _sigval; /* same as below */
- // int _sys_private; /* not to be passed to user */
- // } _timer;
- //
- // /* POSIX.1b signals */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // sigval_t _sigval;
- // } _rt;
- //
- // /* SIGCHLD */
- // struct {
- // __kernel_pid_t _pid; /* which child */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // int _status; /* exit code */
- // __ARCH_SI_CLOCK_T _utime;
- // __ARCH_SI_CLOCK_T _stime;
- // } _sigchld;
- //
- // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
- // struct {
- // void *_addr; /* faulting insn/memory ref. */
- // short _addr_lsb; /* LSB of the reported address */
- // } _sigfault;
- //
- // /* SIGPOLL */
- // struct {
- // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
- // int _fd;
- // } _sigpoll;
- //
- // /* SIGSYS */
- // struct {
- // void *_call_addr; /* calling user insn */
- // int _syscall; /* triggering system call number */
- // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
- // } _sigsys;
- // } _sifields;
- //
- // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
- // bytes.
- Fields [128 - 16]byte
-}
-
-// FixSignalCodeForUser fixes up si_code.
-//
-// The si_code we get from Linux may contain the kernel-specific code in the
-// top 16 bits if it's positive (e.g., from ptrace). Linux's
-// copy_siginfo_to_user does
-// err |= __put_user((short)from->si_code, &to->si_code);
-// to mask out those bits and we need to do the same.
-func (s *SignalInfo) FixSignalCodeForUser() {
- if s.Code > 0 {
- s.Code &= 0x0000ffff
- }
-}
-
-// Pid returns the si_pid field.
-func (s *SignalInfo) Pid() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetPid mutates the si_pid field.
-func (s *SignalInfo) SetPid(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// Uid returns the si_uid field.
-func (s *SignalInfo) Uid() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetUid mutates the si_uid field.
-func (s *SignalInfo) SetUid(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
-func (s *SignalInfo) Sigval() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[8:16])
-}
-
-// SetSigval mutates the sigval field.
-func (s *SignalInfo) SetSigval(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[8:16], val)
-}
-
-// TimerID returns the si_timerid field.
-func (s *SignalInfo) TimerID() linux.TimerID {
- return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetTimerID sets the si_timerid field.
-func (s *SignalInfo) SetTimerID(val linux.TimerID) {
- usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// Overrun returns the si_overrun field.
-func (s *SignalInfo) Overrun() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetOverrun sets the si_overrun field.
-func (s *SignalInfo) SetOverrun(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Addr returns the si_addr field.
-func (s *SignalInfo) Addr() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetAddr sets the si_addr field.
-func (s *SignalInfo) SetAddr(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Status returns the si_status field.
-func (s *SignalInfo) Status() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetStatus mutates the si_status field.
-func (s *SignalInfo) SetStatus(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// CallAddr returns the si_call_addr field.
-func (s *SignalInfo) CallAddr() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetCallAddr mutates the si_call_addr field.
-func (s *SignalInfo) SetCallAddr(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Syscall returns the si_syscall field.
-func (s *SignalInfo) Syscall() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetSyscall mutates the si_syscall field.
-func (s *SignalInfo) SetSyscall(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// Arch returns the si_arch field.
-func (s *SignalInfo) Arch() uint32 {
- return usermem.ByteOrder.Uint32(s.Fields[12:16])
-}
-
-// SetArch mutates the si_arch field.
-func (s *SignalInfo) SetArch(val uint32) {
- usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
-}
-
// SignalContext64 is equivalent to struct sigcontext, the type passed as the
// second argument to signal handlers set by signal(2).
type SignalContext64 struct {
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
new file mode 100644
index 000000000..7d0e98935
--- /dev/null
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -0,0 +1,126 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "encoding/binary"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+// SignalContext64 is equivalent to struct sigcontext, the type passed as the
+// second argument to signal handlers set by signal(2).
+type SignalContext64 struct {
+ FaultAddr uint64
+ Regs [31]uint64
+ Sp uint64
+ Pc uint64
+ Pstate uint64
+ _pad [8]byte // __attribute__((__aligned__(16)))
+ Reserved [4096]uint8
+}
+
+// UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h).
+type UContext64 struct {
+ Flags uint64
+ Link *UContext64
+ Stack SignalStack
+ Sigset linux.SignalSet
+ // glibc uses a 1024-bit sigset_t
+ _pad [(1024 - 64) / 8]byte
+ // sigcontext must be aligned to 16-byte
+ _pad2 [8]byte
+ // last for future expansion
+ MContext SignalContext64
+}
+
+// NewSignalAct implements Context.NewSignalAct.
+func (c *context64) NewSignalAct() NativeSignalAct {
+ return &SignalAct{}
+}
+
+// NewSignalStack implements Context.NewSignalStack.
+func (c *context64) NewSignalStack() NativeSignalStack {
+ return &SignalStack{}
+}
+
+// SignalSetup implements Context.SignalSetup.
+func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+ sp := st.Bottom
+
+ if !(alt.IsEnabled() && sp == alt.Top()) {
+ sp -= 128
+ }
+
+ // Construct the UContext64 now since we need its size.
+ uc := &UContext64{
+ Flags: 0,
+ Stack: *alt,
+ MContext: SignalContext64{
+ Regs: c.Regs.Regs,
+ Sp: c.Regs.Sp,
+ Pc: c.Regs.Pc,
+ Pstate: c.Regs.Pstate,
+ },
+ Sigset: sigset,
+ }
+
+ ucSize := binary.Size(uc)
+ if ucSize < 0 {
+ panic("can't get size of UContext64")
+ }
+ // st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128.
+ frameSize := int(st.Arch.Width()) + ucSize + 128
+ frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8
+ sp = frameBottom + usermem.Addr(frameSize)
+ st.Bottom = sp
+
+ // Prior to proceeding, figure out if the frame will exhaust the range
+ // for the signal stack. This is not allowed, and should immediately
+ // force signal delivery (reverting to the default handler).
+ if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ return syscall.EFAULT
+ }
+
+ // Adjust the code.
+ info.FixSignalCodeForUser()
+
+ // Set up the stack frame.
+ infoAddr, err := st.Push(info)
+ if err != nil {
+ return err
+ }
+ ucAddr, err := st.Push(uc)
+ if err != nil {
+ return err
+ }
+
+ // Set up registers.
+ c.Regs.Sp = uint64(st.Bottom)
+ c.Regs.Pc = act.Handler
+ c.Regs.Regs[0] = uint64(info.Signo)
+ c.Regs.Regs[1] = uint64(infoAddr)
+ c.Regs.Regs[2] = uint64(ucAddr)
+
+ return nil
+}
+
+// SignalRestore implements Context.SignalRestore.
+// Only used on intel.
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+ return 0, SignalStack{}, nil
+}
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
index 5a3228113..d324da705 100644
--- a/pkg/sentry/arch/signal_stack.go
+++ b/pkg/sentry/arch/signal_stack.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build i386 amd64 arm64
package arch
diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go
new file mode 100644
index 000000000..00d5ef461
--- /dev/null
+++ b/pkg/sentry/arch/syscalls_arm64.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+const restartSyscallNr = uintptr(128)
+
+// SyscallNo returns the syscall number according to the 64-bit convention.
+func (c *context64) SyscallNo() uintptr {
+ return uintptr(c.Regs.Regs[8])
+}
+
+// SyscallArgs provides syscall arguments according to the 64-bit convention.
+//
+// Due to the way addresses are mapped for the sentry this binary *must* be
+// built in 64-bit mode. So we can just assume the syscall numbers that come
+// back match the expected host system call numbers.
+// General purpose registers usage on Arm64:
+// R0...R7: parameter/result registers.
+// R8: indirect result location register.
+// R9...R15: temporary registers.
+// R16: the first intra-procedure-call scratch register.
+// R17: the second intra-procedure-call scratch register.
+// R18: the platform register.
+// R19...R28: callee-saved registers.
+// R29: the frame pointer.
+// R30: the link register.
+func (c *context64) SyscallArgs() SyscallArguments {
+ return SyscallArguments{
+ SyscallArgument{Value: uintptr(c.Regs.Regs[0])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[1])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[2])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[3])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[4])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[5])},
+ }
+}
+
+// RestartSyscall implements Context.RestartSyscall.
+func (c *context64) RestartSyscall() {
+ c.Regs.Pc -= SyscallWidth
+ c.Regs.Regs[8] = uint64(restartSyscallNr)
+}
+
+// RestartSyscallWithRestartBlock implements Context.RestartSyscallWithRestartBlock.
+func (c *context64) RestartSyscallWithRestartBlock() {
+ c.Regs.Pc -= SyscallWidth
+ c.Regs.Regs[8] = uint64(restartSyscallNr)
+}
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index 5522cecd0..2561a6109 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/sentry/strace",
"//pkg/sentry/usage",
"//pkg/sentry/watchdog",
+ "//pkg/sync",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
],
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index e1f2fea60..151808911 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -19,10 +19,10 @@ import (
"runtime"
"runtime/pprof"
"runtime/trace"
- "sync"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/urpc"
)
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
index 1098ed777..97fa1512c 100644
--- a/pkg/sentry/device/BUILD
+++ b/pkg/sentry/device/BUILD
@@ -8,7 +8,10 @@ go_library(
srcs = ["device.go"],
importpath = "gvisor.dev/gvisor/pkg/sentry/device",
visibility = ["//pkg/sentry:internal"],
- deps = ["//pkg/abi/linux"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sync",
+ ],
)
go_test(
diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go
index 47945d1a7..69e71e322 100644
--- a/pkg/sentry/device/device.go
+++ b/pkg/sentry/device/device.go
@@ -19,10 +19,10 @@ package device
import (
"bytes"
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Registry tracks all simple devices and related state on the system for
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index c035ffff7..7d5d72d5a 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -68,7 +68,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/state",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
@@ -115,6 +115,7 @@ go_test(
"//pkg/sentry/fs/tmpfs",
"//pkg/sentry/kernel/contexttest",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index 9ac62c84d..e03e3e417 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -17,12 +17,13 @@ package fs
import (
"fmt"
"io"
- "sync"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -395,12 +396,12 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
// Size and permissions are set on upper when the file content is copied
// and when the file is created respectively.
func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error {
- // Extract attributes fro the lower filesystem.
+ // Extract attributes from the lower filesystem.
lowerAttr, err := lower.UnstableAttr(ctx)
if err != nil {
return err
}
- lowerXattr, err := lower.Listxattr()
+ lowerXattr, err := lower.ListXattr(ctx)
if err != nil && err != syserror.EOPNOTSUPP {
return err
}
@@ -421,11 +422,11 @@ func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error
if isXattrOverlay(name) {
continue
}
- value, err := lower.Getxattr(name)
+ value, err := lower.GetXattr(ctx, name, linux.XATTR_SIZE_MAX)
if err != nil {
return err
}
- if err := upper.InodeOperations.Setxattr(upper, name, value); err != nil {
+ if err := upper.InodeOperations.SetXattr(ctx, upper, name, value, 0 /* flags */); err != nil {
return err
}
}
diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go
index 1d80bf15a..738580c5f 100644
--- a/pkg/sentry/fs/copy_up_test.go
+++ b/pkg/sentry/fs/copy_up_test.go
@@ -19,13 +19,13 @@ import (
"crypto/rand"
"fmt"
"io"
- "sync"
"testing"
"gvisor.dev/gvisor/pkg/sentry/fs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index 3cb73bd78..31fc4d87b 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -18,7 +18,6 @@ import (
"fmt"
"path"
"sort"
- "sync"
"sync/atomic"
"syscall"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go
index 60a15a275..25514ace4 100644
--- a/pkg/sentry/fs/dirent_cache.go
+++ b/pkg/sentry/fs/dirent_cache.go
@@ -16,7 +16,8 @@ package fs
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// DirentCache is an LRU cache of Dirents. The Dirent's refCount is
diff --git a/pkg/sentry/fs/dirent_cache_limiter.go b/pkg/sentry/fs/dirent_cache_limiter.go
index ebb80bd50..525ee25f9 100644
--- a/pkg/sentry/fs/dirent_cache_limiter.go
+++ b/pkg/sentry/fs/dirent_cache_limiter.go
@@ -16,7 +16,8 @@ package fs
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// DirentCacheLimiter acts as a global limit for all dirent caches in the
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index 277ee4c31..cc43de69d 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -23,6 +23,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/safemem",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
index 669ffcb75..5b6cfeb0a 100644
--- a/pkg/sentry/fs/fdpipe/pipe.go
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -17,7 +17,6 @@ package fdpipe
import (
"os"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/fd"
@@ -29,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go
index 29175fb3d..cee87f726 100644
--- a/pkg/sentry/fs/fdpipe/pipe_state.go
+++ b/pkg/sentry/fs/fdpipe/pipe_state.go
@@ -17,10 +17,10 @@ package fdpipe
import (
"fmt"
"io/ioutil"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
)
// beforeSave is invoked by stateify.
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index c0a6e884b..7c4586296 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -16,7 +16,6 @@ package fs
import (
"math"
- "sync"
"sync/atomic"
"time"
@@ -29,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -555,6 +555,10 @@ type lockedWriter struct {
//
// This applies only to Write, not WriteAt.
Offset int64
+
+ // Err contains the first error encountered while copying. This is
+ // useful to determine whether Writer or Reader failed during io.Copy.
+ Err error
}
// Write implements io.Writer.Write.
@@ -590,5 +594,8 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
break
}
}
+ if w.Err == nil {
+ w.Err = err
+ }
return written, err
}
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 225e40186..8991207b4 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -16,13 +16,13 @@ package fs
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -475,7 +475,7 @@ func readdirEntries(ctx context.Context, o *overlayEntry) (*SortedDentryMap, err
// Skip this name if it is a negative entry in the
// upper or there exists a whiteout for it.
if o.upper != nil {
- if overlayHasWhiteout(o.upper, name) {
+ if overlayHasWhiteout(ctx, o.upper, name) {
continue
}
}
diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go
index b157fd228..c5b51620a 100644
--- a/pkg/sentry/fs/filesystems.go
+++ b/pkg/sentry/fs/filesystems.go
@@ -18,9 +18,9 @@ import (
"fmt"
"sort"
"strings"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FilesystemFlags matches include/linux/fs.h:file_system_type.fs_flags.
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index 8b2a5e6b2..26abf49e2 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -54,10 +54,9 @@
package fs
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
var (
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index b2e8d9c77..945b6270d 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -53,7 +53,7 @@ go_template_instance(
"Key": "uint64",
"Range": "memmap.MappableRange",
"Value": "uint64",
- "Functions": "fileRangeSetFunctions",
+ "Functions": "FileRangeSetFunctions",
},
)
@@ -93,6 +93,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/state",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 0a5466b0a..f52d712e3 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -34,25 +34,25 @@ import (
//
// type FileRangeSet <generated by go_generics>
-// fileRangeSetFunctions implements segment.Functions for FileRangeSet.
-type fileRangeSetFunctions struct{}
+// FileRangeSetFunctions implements segment.Functions for FileRangeSet.
+type FileRangeSetFunctions struct{}
// MinKey implements segment.Functions.MinKey.
-func (fileRangeSetFunctions) MinKey() uint64 {
+func (FileRangeSetFunctions) MinKey() uint64 {
return 0
}
// MaxKey implements segment.Functions.MaxKey.
-func (fileRangeSetFunctions) MaxKey() uint64 {
+func (FileRangeSetFunctions) MaxKey() uint64 {
return math.MaxUint64
}
// ClearValue implements segment.Functions.ClearValue.
-func (fileRangeSetFunctions) ClearValue(_ *uint64) {
+func (FileRangeSetFunctions) ClearValue(_ *uint64) {
}
// Merge implements segment.Functions.Merge.
-func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
+func (FileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
if frstart1+mr1.Length() != frstart2 {
return 0, false
}
@@ -60,7 +60,7 @@ func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _
}
// Split implements segment.Functions.Split.
-func (fileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
+func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
return frstart, frstart + (split - mr.Start)
}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index b06a71cc2..837fc70b5 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -16,7 +16,6 @@ package fsutil
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/log"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// HostFileMapper caches mappings of an arbitrary host file descriptor. It is
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index 30475f340..a625f0e26 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -16,7 +16,6 @@ package fsutil
import (
"math"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// HostMappable implements memmap.Mappable and platform.File over a
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index 4e100a402..df7b74855 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -15,13 +15,13 @@
package fsutil
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -203,7 +203,7 @@ func (i *InodeSimpleAttributes) NotifyModificationAndStatusChange(ctx context.Co
}
// InodeSimpleExtendedAttributes implements
-// fs.InodeOperations.{Get,Set,List}xattr.
+// fs.InodeOperations.{Get,Set,List}Xattr.
//
// +stateify savable
type InodeSimpleExtendedAttributes struct {
@@ -212,8 +212,8 @@ type InodeSimpleExtendedAttributes struct {
xattrs map[string]string
}
-// Getxattr implements fs.InodeOperations.Getxattr.
-func (i *InodeSimpleExtendedAttributes) Getxattr(_ *fs.Inode, name string) (string, error) {
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (i *InodeSimpleExtendedAttributes) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) {
i.mu.RLock()
value, ok := i.xattrs[name]
i.mu.RUnlock()
@@ -223,19 +223,31 @@ func (i *InodeSimpleExtendedAttributes) Getxattr(_ *fs.Inode, name string) (stri
return value, nil
}
-// Setxattr implements fs.InodeOperations.Setxattr.
-func (i *InodeSimpleExtendedAttributes) Setxattr(_ *fs.Inode, name, value string) error {
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (i *InodeSimpleExtendedAttributes) SetXattr(_ context.Context, _ *fs.Inode, name, value string, flags uint32) error {
i.mu.Lock()
+ defer i.mu.Unlock()
if i.xattrs == nil {
+ if flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
i.xattrs = make(map[string]string)
}
+
+ _, ok := i.xattrs[name]
+ if ok && flags&linux.XATTR_CREATE != 0 {
+ return syserror.EEXIST
+ }
+ if !ok && flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+
i.xattrs[name] = value
- i.mu.Unlock()
return nil
}
-// Listxattr implements fs.InodeOperations.Listxattr.
-func (i *InodeSimpleExtendedAttributes) Listxattr(_ *fs.Inode) (map[string]struct{}, error) {
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode) (map[string]struct{}, error) {
i.mu.RLock()
names := make(map[string]struct{}, len(i.xattrs))
for name := range i.xattrs {
@@ -437,18 +449,18 @@ func (InodeNotSymlink) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
// extended attributes.
type InodeNoExtendedAttributes struct{}
-// Getxattr implements fs.InodeOperations.Getxattr.
-func (InodeNoExtendedAttributes) Getxattr(*fs.Inode, string) (string, error) {
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (InodeNoExtendedAttributes) GetXattr(context.Context, *fs.Inode, string, uint64) (string, error) {
return "", syserror.EOPNOTSUPP
}
-// Setxattr implements fs.InodeOperations.Setxattr.
-func (InodeNoExtendedAttributes) Setxattr(*fs.Inode, string, string) error {
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (InodeNoExtendedAttributes) SetXattr(context.Context, *fs.Inode, string, string, uint32) error {
return syserror.EOPNOTSUPP
}
-// Listxattr implements fs.InodeOperations.Listxattr.
-func (InodeNoExtendedAttributes) Listxattr(*fs.Inode) (map[string]struct{}, error) {
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (InodeNoExtendedAttributes) ListXattr(context.Context, *fs.Inode) (map[string]struct{}, error) {
return nil, syserror.EOPNOTSUPP
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 798920d18..20a014402 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -17,7 +17,6 @@ package fsutil
import (
"fmt"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -30,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Lock order (compare the lock order model in mm/mm.go):
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 4a005c605..fd870e8e1 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -44,6 +44,7 @@ go_library(
"//pkg/sentry/safemem",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/unet",
diff --git a/pkg/sentry/fs/gofer/context_file.go b/pkg/sentry/fs/gofer/context_file.go
index 44b72582a..2125dafef 100644
--- a/pkg/sentry/fs/gofer/context_file.go
+++ b/pkg/sentry/fs/gofer/context_file.go
@@ -59,6 +59,20 @@ func (c *contextFile) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9
return err
}
+func (c *contextFile) getXattr(ctx context.Context, name string, size uint64) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ val, err := c.file.GetXattr(name, size)
+ ctx.UninterruptibleSleepFinish(false)
+ return val, err
+}
+
+func (c *contextFile) setXattr(ctx context.Context, name, value string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.SetXattr(name, value, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
func (c *contextFile) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
ctx.UninterruptibleSleepStart(false)
err := c.file.Allocate(mode, offset, length)
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 91263ebdc..98d1a8a48 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -16,7 +16,6 @@ package gofer
import (
"errors"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/host"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -38,8 +38,7 @@ import (
//
// +stateify savable
type inodeOperations struct {
- fsutil.InodeNotVirtual `state:"nosave"`
- fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNotVirtual `state:"nosave"`
// fileState implements fs.CachedFileObject. It exists
// to break a circular load dependency between inodeOperations
@@ -604,6 +603,21 @@ func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length
return i.fileState.file.setAttr(ctx, p9.SetAttrMask{Size: true}, p9.SetAttr{Size: uint64(length)})
}
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (i *inodeOperations) GetXattr(ctx context.Context, inode *fs.Inode, name string, size uint64) (string, error) {
+ return i.fileState.file.getXattr(ctx, name, size)
+}
+
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (i *inodeOperations) SetXattr(ctx context.Context, inode *fs.Inode, name string, value string, flags uint32) error {
+ return i.fileState.file.setXattr(ctx, name, value, flags)
+}
+
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (i *inodeOperations) ListXattr(context.Context, *fs.Inode) (map[string]struct{}, error) {
+ return nil, syscall.EOPNOTSUPP
+}
+
// Allocate implements fs.InodeOperations.Allocate.
func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset, length int64) error {
// This can only be called for files anyway.
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 4e358a46a..edc796ce0 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -16,7 +16,6 @@ package gofer
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refs"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 23daeb528..2b581aa69 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -50,6 +50,7 @@ go_library(
"//pkg/sentry/unimpl",
"//pkg/sentry/uniqueid",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index a6e4a09e3..873a1c52d 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -15,7 +15,6 @@
package host
import (
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 107336a3e..c076d5bdd 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -16,7 +16,6 @@ package host
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -30,6 +29,7 @@ import (
unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 90331e3b2..753ef8cd6 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -15,8 +15,6 @@
package host
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -24,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index 91e2fde2f..e4cf5a570 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -15,8 +15,6 @@
package fs
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
@@ -26,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -262,28 +261,28 @@ func (i *Inode) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
return i.InodeOperations.UnstableAttr(ctx, i)
}
-// Getxattr calls i.InodeOperations.Getxattr with i as the Inode.
-func (i *Inode) Getxattr(name string) (string, error) {
+// GetXattr calls i.InodeOperations.GetXattr with i as the Inode.
+func (i *Inode) GetXattr(ctx context.Context, name string, size uint64) (string, error) {
if i.overlay != nil {
- return overlayGetxattr(i.overlay, name)
+ return overlayGetXattr(ctx, i.overlay, name, size)
}
- return i.InodeOperations.Getxattr(i, name)
+ return i.InodeOperations.GetXattr(ctx, i, name, size)
}
-// Setxattr calls i.InodeOperations.Setxattr with i as the Inode.
-func (i *Inode) Setxattr(name, value string) error {
+// SetXattr calls i.InodeOperations.SetXattr with i as the Inode.
+func (i *Inode) SetXattr(ctx context.Context, d *Dirent, name, value string, flags uint32) error {
if i.overlay != nil {
- return overlaySetxattr(i.overlay, name, value)
+ return overlaySetxattr(ctx, i.overlay, d, name, value, flags)
}
- return i.InodeOperations.Setxattr(i, name, value)
+ return i.InodeOperations.SetXattr(ctx, i, name, value, flags)
}
-// Listxattr calls i.InodeOperations.Listxattr with i as the Inode.
-func (i *Inode) Listxattr() (map[string]struct{}, error) {
+// ListXattr calls i.InodeOperations.ListXattr with i as the Inode.
+func (i *Inode) ListXattr(ctx context.Context) (map[string]struct{}, error) {
if i.overlay != nil {
- return overlayListxattr(i.overlay)
+ return overlayListXattr(ctx, i.overlay)
}
- return i.InodeOperations.Listxattr(i)
+ return i.InodeOperations.ListXattr(ctx, i)
}
// CheckPermission will check if the caller may access this file in the
diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go
index 0f2a66a79..efd3c962b 100644
--- a/pkg/sentry/fs/inode_inotify.go
+++ b/pkg/sentry/fs/inode_inotify.go
@@ -16,7 +16,8 @@ package fs
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// Watches is the collection of inotify watches on an inode.
diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go
index 5cde9d215..13261cb81 100644
--- a/pkg/sentry/fs/inode_operations.go
+++ b/pkg/sentry/fs/inode_operations.go
@@ -170,20 +170,27 @@ type InodeOperations interface {
// file system events.
UnstableAttr(ctx context.Context, inode *Inode) (UnstableAttr, error)
- // Getxattr retrieves the value of extended attribute name. Inodes that
- // do not support extended attributes return EOPNOTSUPP. Inodes that
- // support extended attributes but don't have a value at name return
+ // GetXattr retrieves the value of extended attribute specified by name.
+ // Inodes that do not support extended attributes return EOPNOTSUPP. Inodes
+ // that support extended attributes but don't have a value at name return
// ENODATA.
- Getxattr(inode *Inode, name string) (string, error)
+ //
+ // If this is called through the getxattr(2) syscall, size indicates the
+ // size of the buffer that the application has allocated to hold the
+ // attribute value. If the value is larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ GetXattr(ctx context.Context, inode *Inode, name string, size uint64) (string, error)
- // Setxattr sets the value of extended attribute name. Inodes that
- // do not support extended attributes return EOPNOTSUPP.
- Setxattr(inode *Inode, name, value string) error
+ // SetXattr sets the value of extended attribute specified by name. Inodes
+ // that do not support extended attributes return EOPNOTSUPP.
+ SetXattr(ctx context.Context, inode *Inode, name, value string, flags uint32) error
- // Listxattr returns the set of all extended attributes names that
+ // ListXattr returns the set of all extended attributes names that
// have values. Inodes that do not support extended attributes return
// EOPNOTSUPP.
- Listxattr(inode *Inode) (map[string]struct{}, error)
+ ListXattr(ctx context.Context, inode *Inode) (map[string]struct{}, error)
// Check determines whether an Inode can be accessed with the
// requested permission mask using the context (which gives access
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index 13d11e001..c477de837 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -25,13 +25,13 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-func overlayHasWhiteout(parent *Inode, name string) bool {
- s, err := parent.Getxattr(XattrOverlayWhiteout(name))
+func overlayHasWhiteout(ctx context.Context, parent *Inode, name string) bool {
+ s, err := parent.GetXattr(ctx, XattrOverlayWhiteout(name), 1)
return err == nil && s == "y"
}
-func overlayCreateWhiteout(parent *Inode, name string) error {
- return parent.InodeOperations.Setxattr(parent, XattrOverlayWhiteout(name), "y")
+func overlayCreateWhiteout(ctx context.Context, parent *Inode, name string) error {
+ return parent.InodeOperations.SetXattr(ctx, parent, XattrOverlayWhiteout(name), "y", 0 /* flags */)
}
func overlayWriteOut(ctx context.Context, o *overlayEntry) error {
@@ -89,7 +89,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
}
// Are we done?
- if overlayHasWhiteout(parent.upper, name) {
+ if overlayHasWhiteout(ctx, parent.upper, name) {
if upperInode == nil {
parent.copyMu.RUnlock()
if negativeUpperChild {
@@ -345,7 +345,7 @@ func overlayRemove(ctx context.Context, o *overlayEntry, parent *Dirent, child *
}
}
if child.Inode.overlay.lowerExists {
- if err := overlayCreateWhiteout(o.upper, child.name); err != nil {
+ if err := overlayCreateWhiteout(ctx, o.upper, child.name); err != nil {
return err
}
}
@@ -426,7 +426,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
return err
}
if renamed.Inode.overlay.lowerExists {
- if err := overlayCreateWhiteout(oldParent.Inode.overlay.upper, oldName); err != nil {
+ if err := overlayCreateWhiteout(ctx, oldParent.Inode.overlay.upper, oldName); err != nil {
return err
}
}
@@ -528,7 +528,7 @@ func overlayUnstableAttr(ctx context.Context, o *overlayEntry) (UnstableAttr, er
return attr, err
}
-func overlayGetxattr(o *overlayEntry, name string) (string, error) {
+func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uint64) (string, error) {
// Hot path. This is how the overlay checks for whiteout files.
// Avoid defers.
var (
@@ -544,31 +544,38 @@ func overlayGetxattr(o *overlayEntry, name string) (string, error) {
o.copyMu.RLock()
if o.upper != nil {
- s, err = o.upper.Getxattr(name)
+ s, err = o.upper.GetXattr(ctx, name, size)
} else {
- s, err = o.lower.Getxattr(name)
+ s, err = o.lower.GetXattr(ctx, name, size)
}
o.copyMu.RUnlock()
return s, err
}
-// TODO(b/146028302): Support setxattr for overlayfs.
-func overlaySetxattr(o *overlayEntry, name, value string) error {
- return syserror.EOPNOTSUPP
+func overlaySetxattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error {
+ // Don't allow changes to overlay xattrs through a setxattr syscall.
+ if strings.HasPrefix(XattrOverlayPrefix, name) {
+ return syserror.EPERM
+ }
+
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.SetXattr(ctx, d, name, value, flags)
}
-func overlayListxattr(o *overlayEntry) (map[string]struct{}, error) {
+func overlayListXattr(ctx context.Context, o *overlayEntry) (map[string]struct{}, error) {
o.copyMu.RLock()
defer o.copyMu.RUnlock()
var names map[string]struct{}
var err error
if o.upper != nil {
- names, err = o.upper.Listxattr()
+ names, err = o.upper.ListXattr(ctx)
} else {
- names, err = o.lower.Listxattr()
+ names, err = o.lower.ListXattr(ctx)
}
for name := range names {
- // Same as overlayGetxattr, we shouldn't forward along
+ // Same as overlayGetXattr, we shouldn't forward along
// overlay attributes.
if strings.HasPrefix(XattrOverlayPrefix, name) {
delete(names, name)
diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go
index 8935aad65..493d98c36 100644
--- a/pkg/sentry/fs/inode_overlay_test.go
+++ b/pkg/sentry/fs/inode_overlay_test.go
@@ -382,8 +382,8 @@ type dir struct {
ReaddirCalled bool
}
-// Getxattr implements InodeOperations.Getxattr.
-func (d *dir) Getxattr(inode *fs.Inode, name string) (string, error) {
+// GetXattr implements InodeOperations.GetXattr.
+func (d *dir) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) {
for _, n := range d.negative {
if name == fs.XattrOverlayWhiteout(n) {
return "y", nil
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index ba3e0233d..cc7dd1c92 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -16,7 +16,6 @@ package fs
import (
"io"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go
index 0aa0a5e9b..900cba3ca 100644
--- a/pkg/sentry/fs/inotify_watch.go
+++ b/pkg/sentry/fs/inotify_watch.go
@@ -15,10 +15,10 @@
package fs
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Watch represent a particular inotify watch created by inotify_add_watch.
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index 8d62642e7..2c332a82a 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -44,6 +44,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
index 636484424..926538d90 100644
--- a/pkg/sentry/fs/lock/lock.go
+++ b/pkg/sentry/fs/lock/lock.go
@@ -52,9 +52,9 @@ package lock
import (
"fmt"
"math"
- "sync"
"syscall"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -78,6 +78,9 @@ const (
)
// LockEOF is the maximal possible end of a regional file lock.
+//
+// A BSD-style full file lock can be represented as a regional file lock from
+// offset 0 to LockEOF.
const LockEOF = math.MaxUint64
// Lock is a regional file lock. It consists of either a single writer
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index ac0398bd9..db3dfd096 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -19,7 +19,6 @@ import (
"math"
"path"
"strings"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go
index 25573e986..4cad55327 100644
--- a/pkg/sentry/fs/overlay.go
+++ b/pkg/sentry/fs/overlay.go
@@ -17,13 +17,12 @@ package fs
import (
"fmt"
"strings"
- "sync"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -199,7 +198,7 @@ type overlayEntry struct {
upper *Inode
// dirCacheMu protects dirCache.
- dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"`
+ dirCacheMu sync.DowngradableRWMutex `state:"nosave"`
// dirCache is cache of DentAttrs from upper and lower Inodes.
dirCache *SortedDentryMap
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 75cbb0622..cb37c6c6b 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -18,7 +18,6 @@ go_library(
"mounts.go",
"net.go",
"proc.go",
- "rpcinet_proc.go",
"stat.go",
"sys.go",
"sys_net.go",
@@ -46,11 +45,11 @@ go_library(
"//pkg/sentry/limits",
"//pkg/sentry/mm",
"//pkg/sentry/socket",
- "//pkg/sentry/socket/rpcinet",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip/header",
"//pkg/waiter",
diff --git a/pkg/sentry/fs/proc/cgroup.go b/pkg/sentry/fs/proc/cgroup.go
index 05e31c55d..c4abe319d 100644
--- a/pkg/sentry/fs/proc/cgroup.go
+++ b/pkg/sentry/fs/proc/cgroup.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
)
+// LINT.IfChange
+
func newCGroupInode(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string]string) *fs.Inode {
// From man 7 cgroups: "For each cgroup hierarchy of which the process
// is a member, there is one entry containing three colon-separated
@@ -39,3 +41,5 @@ func newCGroupInode(ctx context.Context, msrc *fs.MountSource, cgroupControllers
return newStaticProcInode(ctx, msrc, []byte(data))
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/cpuinfo.go b/pkg/sentry/fs/proc/cpuinfo.go
index 3edf36780..df0c4e3a7 100644
--- a/pkg/sentry/fs/proc/cpuinfo.go
+++ b/pkg/sentry/fs/proc/cpuinfo.go
@@ -15,11 +15,15 @@
package proc
import (
+ "bytes"
+
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
func newCPUInfo(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
k := kernel.KernelFromContext(ctx)
features := k.FeatureSet()
@@ -27,9 +31,11 @@ func newCPUInfo(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
// Kernel is always initialized with a FeatureSet.
panic("cpuinfo read with nil FeatureSet")
}
- contents := make([]byte, 0, 1024)
+ var buf bytes.Buffer
for i, max := uint(0), k.ApplicationCores(); i < max; i++ {
- contents = append(contents, []byte(features.CPUInfo(i))...)
+ features.WriteCPUInfoTo(i, &buf)
}
- return newStaticProcInode(ctx, msrc, contents)
+ return newStaticProcInode(ctx, msrc, buf.Bytes())
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go
index 1d3a2d426..9aaeb780b 100644
--- a/pkg/sentry/fs/proc/exec_args.go
+++ b/pkg/sentry/fs/proc/exec_args.go
@@ -29,6 +29,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// execArgType enumerates the types of exec arguments that are exposed through
// proc.
type execArgType int
@@ -201,3 +203,5 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
}
return int64(n), err
}
+
+// LINT.ThenChange(../../fsimpl/proc/task.go)
diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go
index bee421d76..2fa3cfa7d 100644
--- a/pkg/sentry/fs/proc/fds.go
+++ b/pkg/sentry/fs/proc/fds.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// walkDescriptors finds the descriptor (file-flag pair) for the fd identified
// by p, and calls the toInodeOperations callback with that descriptor. This is a helper
// method for implementing fs.InodeOperations.Lookup.
@@ -277,3 +279,5 @@ func (fdid *fdInfoDir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.
}
return fs.NewFile(ctx, dirent, flags, fops), nil
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/filesystems.go b/pkg/sentry/fs/proc/filesystems.go
index e9250c51c..7b3b974ab 100644
--- a/pkg/sentry/fs/proc/filesystems.go
+++ b/pkg/sentry/fs/proc/filesystems.go
@@ -23,6 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
)
+// LINT.IfChange
+
// filesystemsData backs /proc/filesystems.
//
// +stateify savable
@@ -59,3 +61,5 @@ func (*filesystemsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle
// Return the SeqData and advance the generation counter.
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*filesystemsData)(nil)}}, 1
}
+
+// LINT.ThenChange(../../fsimpl/proc/filesystem.go)
diff --git a/pkg/sentry/fs/proc/fs.go b/pkg/sentry/fs/proc/fs.go
index f14833805..761d24462 100644
--- a/pkg/sentry/fs/proc/fs.go
+++ b/pkg/sentry/fs/proc/fs.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
)
+// LINT.IfChange
+
// filesystem is a procfs.
//
// +stateify savable
@@ -79,3 +81,5 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
// never want them cached.
return New(ctx, fs.NewNonCachingMountSource(ctx, f, flags), cgroups)
}
+
+// LINT.ThenChange(../../fsimpl/proc/filesystem.go)
diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go
index 0c04f81fa..723f6b661 100644
--- a/pkg/sentry/fs/proc/inode.go
+++ b/pkg/sentry/fs/proc/inode.go
@@ -26,6 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
+// LINT.IfChange
+
// taskOwnedInodeOps wraps an fs.InodeOperations and overrides the UnstableAttr
// method to return either the task or root as the owner, depending on the
// task's dumpability.
@@ -131,3 +133,5 @@ func newProcInode(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
}
return fs.NewInode(ctx, iops, msrc, sattr)
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks.go)
diff --git a/pkg/sentry/fs/proc/loadavg.go b/pkg/sentry/fs/proc/loadavg.go
index 8602b7426..d7d2afcb7 100644
--- a/pkg/sentry/fs/proc/loadavg.go
+++ b/pkg/sentry/fs/proc/loadavg.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
)
+// LINT.IfChange
+
// loadavgData backs /proc/loadavg.
//
// +stateify savable
@@ -53,3 +55,5 @@ func (d *loadavgData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
},
}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/meminfo.go b/pkg/sentry/fs/proc/meminfo.go
index 495f3e3ba..313c6a32b 100644
--- a/pkg/sentry/fs/proc/meminfo.go
+++ b/pkg/sentry/fs/proc/meminfo.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
+// LINT.IfChange
+
// meminfoData backs /proc/meminfo.
//
// +stateify savable
@@ -83,3 +85,5 @@ func (d *meminfoData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
fmt.Fprintf(&buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*meminfoData)(nil)}}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index e33c4a460..5aedae799 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
// forEachMountSource runs f for the process root mount and each mount that is a
// descendant of the root.
func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
@@ -195,3 +197,5 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*mountsFile)(nil)}}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 402919924..3f17e98ea 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -38,6 +38,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+// LINT.IfChange
+
// newNet creates a new proc net entry.
func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
@@ -831,3 +833,5 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
}
return data, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_net.go)
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index 56e92721e..29867dc3a 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -27,10 +27,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet"
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// proc is a root proc node.
//
// +stateify savable
@@ -85,15 +86,9 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string
}
// Add more contents that need proc to be initialized.
+ p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc))
p.AddChild(ctx, "sys", p.newSysDir(ctx, msrc))
- // If we're using rpcinet we will let it manage /proc/net.
- if _, ok := p.k.NetworkStack().(*rpcinet.Stack); ok {
- p.AddChild(ctx, "net", newRPCInetProcNet(ctx, msrc))
- } else {
- p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc))
- }
-
return newProcInode(ctx, p, msrc, fs.SpecialDirectory, nil), nil
}
@@ -249,3 +244,5 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent
}
return offset, nil
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks.go)
diff --git a/pkg/sentry/fs/proc/rpcinet_proc.go b/pkg/sentry/fs/proc/rpcinet_proc.go
deleted file mode 100644
index 01ac97530..000000000
--- a/pkg/sentry/fs/proc/rpcinet_proc.go
+++ /dev/null
@@ -1,217 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "io"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// rpcInetInode implements fs.InodeOperations.
-type rpcInetInode struct {
- fsutil.SimpleFileInode
-
- // filepath is the full path of this rpcInetInode.
- filepath string
-
- k *kernel.Kernel
-}
-
-func newRPCInetInode(ctx context.Context, msrc *fs.MountSource, filepath string, mode linux.FileMode) *fs.Inode {
- f := &rpcInetInode{
- SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(mode), linux.PROC_SUPER_MAGIC),
- filepath: filepath,
- k: kernel.KernelFromContext(ctx),
- }
- return newProcInode(ctx, f, msrc, fs.SpecialFile, nil)
-}
-
-// GetFile implements fs.InodeOperations.GetFile.
-func (i *rpcInetInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
- flags.Pread = true
- flags.Pwrite = true
- fops := &rpcInetFile{
- inode: i,
- }
- return fs.NewFile(ctx, dirent, flags, fops), nil
-}
-
-// rpcInetFile implements fs.FileOperations as RPCs.
-type rpcInetFile struct {
- fsutil.FileGenericSeek `state:"nosave"`
- fsutil.FileNoIoctl `state:"nosave"`
- fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopFlush `state:"nosave"`
- fsutil.FileNoopFsync `state:"nosave"`
- fsutil.FileNoopRelease `state:"nosave"`
- fsutil.FileNotDirReaddir `state:"nosave"`
- fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- waiter.AlwaysReady `state:"nosave"`
-
- inode *rpcInetInode
-}
-
-// Read implements fs.FileOperations.Read.
-//
-// This method can panic if an rpcInetInode was created without an rpcinet
-// stack.
-func (f *rpcInetFile) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- s, ok := f.inode.k.NetworkStack().(*rpcinet.Stack)
- if !ok {
- panic("Network stack is not a rpcinet.")
- }
-
- contents, se := s.RPCReadFile(f.inode.filepath)
- if se != nil || offset >= int64(len(contents)) {
- return 0, io.EOF
- }
-
- n, err := dst.CopyOut(ctx, contents[offset:])
- return int64(n), err
-}
-
-// Write implements fs.FileOperations.Write.
-//
-// This method can panic if an rpcInetInode was created without an rpcInet
-// stack.
-func (f *rpcInetFile) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
- s, ok := f.inode.k.NetworkStack().(*rpcinet.Stack)
- if !ok {
- panic("Network stack is not a rpcinet.")
- }
-
- if src.NumBytes() == 0 {
- return 0, nil
- }
-
- b := make([]byte, src.NumBytes(), src.NumBytes())
- n, err := src.CopyIn(ctx, b)
- if err != nil {
- return int64(n), err
- }
-
- written, se := s.RPCWriteFile(f.inode.filepath, b)
- return int64(written), se.ToError()
-}
-
-// newRPCInetProcNet will build an inode for /proc/net.
-func newRPCInetProcNet(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "arp": newRPCInetInode(ctx, msrc, "/proc/net/arp", 0444),
- "dev": newRPCInetInode(ctx, msrc, "/proc/net/dev", 0444),
- "if_inet6": newRPCInetInode(ctx, msrc, "/proc/net/if_inet6", 0444),
- "ipv6_route": newRPCInetInode(ctx, msrc, "/proc/net/ipv6_route", 0444),
- "netlink": newRPCInetInode(ctx, msrc, "/proc/net/netlink", 0444),
- "netstat": newRPCInetInode(ctx, msrc, "/proc/net/netstat", 0444),
- "packet": newRPCInetInode(ctx, msrc, "/proc/net/packet", 0444),
- "protocols": newRPCInetInode(ctx, msrc, "/proc/net/protocols", 0444),
- "psched": newRPCInetInode(ctx, msrc, "/proc/net/psched", 0444),
- "ptype": newRPCInetInode(ctx, msrc, "/proc/net/ptype", 0444),
- "route": newRPCInetInode(ctx, msrc, "/proc/net/route", 0444),
- "tcp": newRPCInetInode(ctx, msrc, "/proc/net/tcp", 0444),
- "tcp6": newRPCInetInode(ctx, msrc, "/proc/net/tcp6", 0444),
- "udp": newRPCInetInode(ctx, msrc, "/proc/net/udp", 0444),
- "udp6": newRPCInetInode(ctx, msrc, "/proc/net/udp6", 0444),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
-
-// newRPCInetProcSysNet will build an inode for /proc/sys/net.
-func newRPCInetProcSysNet(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "ipv4": newRPCInetSysNetIPv4Dir(ctx, msrc),
- "core": newRPCInetSysNetCore(ctx, msrc),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
-
-// newRPCInetSysNetCore builds the /proc/sys/net/core directory.
-func newRPCInetSysNetCore(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "default_qdisc": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/default_qdisc", 0444),
- "message_burst": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/message_burst", 0444),
- "message_cost": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/message_cost", 0444),
- "optmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/optmem_max", 0444),
- "rmem_default": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/rmem_default", 0444),
- "rmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/rmem_max", 0444),
- "somaxconn": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/somaxconn", 0444),
- "wmem_default": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/wmem_default", 0444),
- "wmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/wmem_max", 0444),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
-
-// newRPCInetSysNetIPv4Dir builds the /proc/sys/net/ipv4 directory.
-func newRPCInetSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "ip_local_port_range": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_local_port_range", 0444),
- "ip_local_reserved_ports": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_local_reserved_ports", 0444),
- "ipfrag_time": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ipfrag_time", 0444),
- "ip_nonlocal_bind": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_nonlocal_bind", 0444),
- "ip_no_pmtu_disc": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_no_pmtu_disc", 0444),
- "tcp_allowed_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_allowed_congestion_control", 0444),
- "tcp_available_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_available_congestion_control", 0444),
- "tcp_base_mss": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_base_mss", 0444),
- "tcp_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_congestion_control", 0644),
- "tcp_dsack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_dsack", 0644),
- "tcp_early_retrans": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_early_retrans", 0644),
- "tcp_fack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fack", 0644),
- "tcp_fastopen": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fastopen", 0644),
- "tcp_fastopen_key": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fastopen_key", 0444),
- "tcp_fin_timeout": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fin_timeout", 0644),
- "tcp_invalid_ratelimit": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_invalid_ratelimit", 0444),
- "tcp_keepalive_intvl": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_intvl", 0644),
- "tcp_keepalive_probes": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_probes", 0644),
- "tcp_keepalive_time": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_time", 0644),
- "tcp_mem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_mem", 0444),
- "tcp_mtu_probing": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_mtu_probing", 0644),
- "tcp_no_metrics_save": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_no_metrics_save", 0444),
- "tcp_probe_interval": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_probe_interval", 0444),
- "tcp_probe_threshold": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_probe_threshold", 0444),
- "tcp_retries1": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_retries1", 0644),
- "tcp_retries2": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_retries2", 0644),
- "tcp_rfc1337": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_rfc1337", 0444),
- "tcp_rmem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_rmem", 0444),
- "tcp_sack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_sack", 0644),
- "tcp_slow_start_after_idle": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_slow_start_after_idle", 0644),
- "tcp_synack_retries": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_synack_retries", 0644),
- "tcp_syn_retries": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_syn_retries", 0644),
- "tcp_timestamps": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_timestamps", 0644),
- "tcp_wmem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_wmem", 0444),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index fe7067be1..38b246dff 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/fs/proc/device",
"//pkg/sentry/kernel/time",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go
index 5fe823000..f9af191d5 100644
--- a/pkg/sentry/fs/proc/seqfile/seqfile.go
+++ b/pkg/sentry/fs/proc/seqfile/seqfile.go
@@ -17,7 +17,6 @@ package seqfile
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -26,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/proc/stat.go b/pkg/sentry/fs/proc/stat.go
index b641effbb..bc5b2bc7b 100644
--- a/pkg/sentry/fs/proc/stat.go
+++ b/pkg/sentry/fs/proc/stat.go
@@ -24,6 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
// statData backs /proc/stat.
//
// +stateify savable
@@ -140,3 +142,5 @@ func (s *statData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]
},
}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index cd37776c8..2bdcf5f70 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -26,11 +26,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// mmapMinAddrData backs /proc/sys/vm/mmap_min_addr.
//
// +stateify savable
@@ -104,16 +105,10 @@ func (p *proc) newVMDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
func (p *proc) newSysDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
children := map[string]*fs.Inode{
"kernel": p.newKernelDir(ctx, msrc),
+ "net": p.newSysNetDir(ctx, msrc),
"vm": p.newVMDir(ctx, msrc),
}
- // If we're using rpcinet we will let it manage /proc/sys/net.
- if _, ok := p.k.NetworkStack().(*rpcinet.Stack); ok {
- children["net"] = newRPCInetProcSysNet(ctx, msrc)
- } else {
- children["net"] = p.newSysNetDir(ctx, msrc)
- }
-
d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
}
@@ -160,3 +155,5 @@ func (hf *hostnameFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ
}
var _ fs.FileOperations = (*hostnameFile)(nil)
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_sys.go)
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index bd93f83fa..b9e8ef35f 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -17,7 +17,6 @@ package proc
import (
"fmt"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -27,9 +26,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
type tcpMemDir int
const (
@@ -364,3 +366,5 @@ func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_sys.go)
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 9bf4b4527..7358d6ef9 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -37,6 +37,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// getTaskMM returns t's MemoryManager. If getTaskMM succeeds, the MemoryManager's
// users count is incremented, and must be decremented by the caller when it is
// no longer in use.
@@ -800,3 +802,5 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
n, err := dst.CopyOut(ctx, buf[offset:])
return int64(n), err
}
+
+// LINT.ThenChange(../../fsimpl/proc/task.go|../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go
index eea37d15c..3eacc9265 100644
--- a/pkg/sentry/fs/proc/uid_gid_map.go
+++ b/pkg/sentry/fs/proc/uid_gid_map.go
@@ -30,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// idMapInodeOperations implements fs.InodeOperations for
// /proc/[pid]/{uid,gid}_map.
//
@@ -177,3 +179,5 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u
// count, even if fewer bytes were used.
return int64(srclen), nil
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/uptime.go b/pkg/sentry/fs/proc/uptime.go
index 4e903917a..adfe58adb 100644
--- a/pkg/sentry/fs/proc/uptime.go
+++ b/pkg/sentry/fs/proc/uptime.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// uptime is a file containing the system uptime.
//
// +stateify savable
@@ -85,3 +87,5 @@ func (f *uptimeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
n, err := dst.CopyOut(ctx, s[offset:])
return int64(n), err
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/version.go b/pkg/sentry/fs/proc/version.go
index a6d2c3cd3..27fd5b1cb 100644
--- a/pkg/sentry/fs/proc/version.go
+++ b/pkg/sentry/fs/proc/version.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
// versionData backs /proc/version.
//
// +stateify savable
@@ -76,3 +78,5 @@ func (v *versionData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
},
}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index 012cb3e44..3fb7b0633 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go
index 78e082b8e..dcbb8eb2e 100644
--- a/pkg/sentry/fs/ramfs/dir.go
+++ b/pkg/sentry/fs/ramfs/dir.go
@@ -17,7 +17,6 @@ package ramfs
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/restore.go b/pkg/sentry/fs/restore.go
index f10168125..64c6a6ae9 100644
--- a/pkg/sentry/fs/restore.go
+++ b/pkg/sentry/fs/restore.go
@@ -15,7 +15,7 @@
package fs
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// RestoreEnvironment is the restore environment for file systems. It consists
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index 311798811..389c330a0 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -167,6 +167,11 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
if !srcPipe && !opts.SrcOffset {
atomic.StoreInt64(&src.offset, src.offset+n)
}
+
+ // Don't report any errors if we have some progress without data loss.
+ if w.Err == nil {
+ err = nil
+ }
}
// Drop locks.
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 59ce400c2..3400b940c 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index f86dfaa36..f1c87fe41 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -17,7 +17,6 @@ package tmpfs
import (
"fmt"
"io"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 69089c8a8..0f718e236 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -148,19 +148,19 @@ func (d *Dir) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perms
return d.ramfsDir.CreateFifo(ctx, dir, name, perms)
}
-// Getxattr implements fs.InodeOperations.Getxattr.
-func (d *Dir) Getxattr(i *fs.Inode, name string) (string, error) {
- return d.ramfsDir.Getxattr(i, name)
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (d *Dir) GetXattr(ctx context.Context, i *fs.Inode, name string, size uint64) (string, error) {
+ return d.ramfsDir.GetXattr(ctx, i, name, size)
}
-// Setxattr implements fs.InodeOperations.Setxattr.
-func (d *Dir) Setxattr(i *fs.Inode, name, value string) error {
- return d.ramfsDir.Setxattr(i, name, value)
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (d *Dir) SetXattr(ctx context.Context, i *fs.Inode, name, value string, flags uint32) error {
+ return d.ramfsDir.SetXattr(ctx, i, name, value, flags)
}
-// Listxattr implements fs.InodeOperations.Listxattr.
-func (d *Dir) Listxattr(i *fs.Inode) (map[string]struct{}, error) {
- return d.ramfsDir.Listxattr(i)
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (d *Dir) ListXattr(ctx context.Context, i *fs.Inode) (map[string]struct{}, error) {
+ return d.ramfsDir.ListXattr(ctx, i)
}
// Lookup implements fs.InodeOperations.Lookup.
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 95ad98cb0..f6f60d0cf 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/unimpl",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index 2f639c823..88aa66b24 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -19,7 +19,6 @@ import (
"fmt"
"math"
"strconv"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
index 7cc0eb409..894964260 100644
--- a/pkg/sentry/fs/tty/line_discipline.go
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -16,13 +16,13 @@ package tty
import (
"bytes"
- "sync"
"unicode/utf8"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go
index 231e4e6eb..8b5d4699a 100644
--- a/pkg/sentry/fs/tty/queue.go
+++ b/pkg/sentry/fs/tty/queue.go
@@ -15,13 +15,12 @@
package tty
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index bc90330bc..903874141 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -50,6 +50,7 @@ go_library(
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index 2f46d2d13..a56b03711 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -50,7 +50,9 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{})
+ vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
if err != nil {
f.Close()
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 91802dc1e..8944171c8 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -15,8 +15,6 @@
package ext
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/log"
@@ -25,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
index 567523d32..4110649ab 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
@@ -29,8 +29,12 @@ package disklayout
// byte (i * sb.BlockSize()) to ((i+1) * sb.BlockSize()).
const (
- // ExtentStructsSize is the size of all the three extent on-disk structs.
- ExtentStructsSize = 12
+ // ExtentHeaderSize is the size of the header of an extent tree node.
+ ExtentHeaderSize = 12
+
+ // ExtentEntrySize is the size of an entry in an extent tree node.
+ // This size is the same for both leaf and internal nodes.
+ ExtentEntrySize = 12
// ExtentMagic is the magic number which must be present in the header.
ExtentMagic = 0xf30a
@@ -57,7 +61,7 @@ type ExtentNode struct {
Entries []ExtentEntryPair
}
-// ExtentEntry reprsents an extent tree node entry. The entry can either be
+// ExtentEntry represents an extent tree node entry. The entry can either be
// an ExtentIdx or Extent itself. This exists to simplify navigation logic.
type ExtentEntry interface {
// FileBlock returns the first file block number covered by this entry.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
index b0fad9b71..8762b90db 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
@@ -21,7 +21,7 @@ import (
// TestExtentSize tests that the extent structs are of the correct
// size.
func TestExtentSize(t *testing.T) {
- assertSize(t, ExtentHeader{}, ExtentStructsSize)
- assertSize(t, ExtentIdx{}, ExtentStructsSize)
- assertSize(t, Extent{}, ExtentStructsSize)
+ assertSize(t, ExtentHeader{}, ExtentHeaderSize)
+ assertSize(t, ExtentIdx{}, ExtentEntrySize)
+ assertSize(t, Extent{}, ExtentEntrySize)
}
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 5d6c999bd..6c14a1e2d 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -66,7 +66,9 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{})
+ vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
if err != nil {
f.Close()
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 3d3ebaca6..11dcc0346 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -57,7 +57,7 @@ func newExtentFile(regFile regularFile) (*extentFile, error) {
func (f *extentFile) buildExtTree() error {
rootNodeData := f.regFile.inode.diskInode.Data()
- binary.Unmarshal(rootNodeData[:disklayout.ExtentStructsSize], binary.LittleEndian, &f.root.Header)
+ binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header)
// Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
if f.root.Header.NumEntries > 4 {
@@ -67,7 +67,7 @@ func (f *extentFile) buildExtTree() error {
}
f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries)
- for i, off := uint16(0), disklayout.ExtentStructsSize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+ for i, off := uint16(0), disklayout.ExtentEntrySize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
var curEntry disklayout.ExtentEntry
if f.root.Header.Height == 0 {
// Leaf node.
@@ -76,7 +76,7 @@ func (f *extentFile) buildExtTree() error {
// Internal node.
curEntry = &disklayout.ExtentIdx{}
}
- binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentStructsSize], binary.LittleEndian, curEntry)
+ binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry)
f.root.Entries[i].Entry = curEntry
}
@@ -105,7 +105,7 @@ func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*diskla
}
entries := make([]disklayout.ExtentEntryPair, header.NumEntries)
- for i, off := uint16(0), off+disklayout.ExtentStructsSize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+ for i, off := uint16(0), off+disklayout.ExtentEntrySize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
var curEntry disklayout.ExtentEntry
if header.Height == 0 {
// Leaf node.
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 616fc002a..9afb1a84c 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -17,13 +17,13 @@ package ext
import (
"errors"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index b2cc826c7..8608805bf 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -157,8 +157,6 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
switch in.impl.(type) {
case *regularFile:
var fd regularFileFD
- mnt.IncRef()
- vfsd.IncRef()
fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
case *directory:
@@ -168,8 +166,6 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
return nil, syserror.EISDIR
}
var fd directoryFD
- mnt.IncRef()
- vfsd.IncRef()
fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
case *symlink:
@@ -178,8 +174,6 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
return nil, syserror.ELOOP
}
var fd symlinkFD
- mnt.IncRef()
- vfsd.IncRef()
fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
default:
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index aec33e00a..d11153c90 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -16,7 +16,6 @@ package ext
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 59f7f39e2..809178250 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -25,6 +25,7 @@ go_library(
"inode_impl_util.go",
"kernfs.go",
"slot_list.go",
+ "symlink.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs",
visibility = ["//pkg/sentry:internal"],
@@ -38,6 +39,7 @@ go_library(
"//pkg/sentry/memmap",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
],
)
@@ -55,6 +57,7 @@ go_test(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"@com_github_google_go-cmp//cmp:go_default_library",
],
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 51102ce48..606ca692d 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -15,6 +15,8 @@
package kernfs
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -26,7 +28,10 @@ import (
// DynamicBytesFile implements kernfs.Inode and represents a read-only
// file whose contents are backed by a vfs.DynamicBytesSource.
//
-// Must be initialized with Init before first use.
+// Must be instantiated with NewDynamicBytesFile or initialized with Init
+// before first use.
+//
+// +stateify savable
type DynamicBytesFile struct {
InodeAttrs
InodeNoopRefCount
@@ -36,9 +41,14 @@ type DynamicBytesFile struct {
data vfs.DynamicBytesSource
}
-// Init intializes a dynamic bytes file.
-func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource) {
- f.InodeAttrs.Init(creds, ino, linux.ModeRegular|0444)
+var _ Inode = (*DynamicBytesFile)(nil)
+
+// Init initializes a dynamic bytes file.
+func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ f.InodeAttrs.Init(creds, ino, linux.ModeRegular|perm)
f.data = data
}
@@ -59,6 +69,8 @@ func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error {
// DynamicBytesFile.
//
// Must be initialized with Init before first use.
+//
+// +stateify savable
type DynamicBytesFD struct {
vfs.FileDescriptionDefaultImpl
vfs.DynamicBytesFileDescriptionImpl
@@ -69,8 +81,6 @@ type DynamicBytesFD struct {
// Init initializes a DynamicBytesFD.
func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) {
- m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
- d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
fd.inode = d.Impl().(*Dentry).inode
fd.SetDataSource(data)
fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{})
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index bd402330f..bcf069b5f 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -44,8 +44,6 @@ type GenericDirectoryFD struct {
// Init initializes a GenericDirectoryFD.
func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, flags uint32) {
- m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
- d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
fd.children = children
fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{})
}
@@ -154,7 +152,10 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
fd.off++
}
- return nil
+ var err error
+ relOffset := fd.off - int64(len(fd.children.set)) - 2
+ fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset)
+ return err
}
// Seek implements vfs.FileDecriptionImpl.Seek.
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 79759e0fc..a4600ad47 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -40,7 +39,7 @@ func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingP
return nil, syserror.ENOTDIR
}
// Directory searchable?
- if err := d.inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
afterSymlink:
@@ -182,8 +181,8 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
//
// Preconditions: Filesystem.mu must be locked for at least reading. parentInode
// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true.
-func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) {
- if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) {
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return "", err
}
pc := rp.Component()
@@ -206,7 +205,7 @@ func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInod
// checkDeleteLocked checks that the file represented by vfsd may be deleted.
//
// Preconditions: Filesystem.mu must be locked for at least reading.
-func checkDeleteLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
+func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
parentVFSD := vfsd.Parent()
if parentVFSD == nil {
return syserror.EBUSY
@@ -214,36 +213,12 @@ func checkDeleteLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
if parentVFSD.IsDisowned() {
return syserror.ENOENT
}
- if err := parentVFSD.Impl().(*Dentry).inode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ if err := parentVFSD.Impl().(*Dentry).inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
return nil
}
-// checkRenameLocked checks that a rename operation may be performed on the
-// target dentry across the given set of parent directories. The target dentry
-// may be nil.
-//
-// Precondition: isDir(dstInode) == true.
-func checkRenameLocked(creds *auth.Credentials, src, dstDir *vfs.Dentry, dstInode Inode) error {
- srcDir := src.Parent()
- if srcDir == nil {
- return syserror.EBUSY
- }
- if srcDir.IsDisowned() {
- return syserror.ENOENT
- }
- if dstDir.IsDisowned() {
- return syserror.ENOENT
- }
- // Check for creation permissions on dst dir.
- if err := dstInode.CheckPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
- return err
- }
-
- return nil
-}
-
// Release implements vfs.FilesystemImpl.Release.
func (fs *Filesystem) Release() {
}
@@ -269,7 +244,7 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
if !d.isDir() {
return nil, syserror.ENOTDIR
}
- if err := inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
}
@@ -302,7 +277,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
if err != nil {
return err
}
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -339,7 +314,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if err != nil {
return err
}
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -367,7 +342,7 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if err != nil {
return err
}
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -401,7 +376,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return nil, err
}
- if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil {
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
return inode.Open(rp, vfsd, opts.Flags)
@@ -420,7 +395,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
- if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil {
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
return inode.Open(rp, vfsd, opts.Flags)
@@ -432,7 +407,7 @@ afterTrailingSymlink:
return nil, err
}
// Check for search permission in the parent directory.
- if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
// Reject attempts to open directories with O_CREAT.
@@ -450,7 +425,7 @@ afterTrailingSymlink:
}
if childVFSD == nil {
// Already checked for searchability above; now check for writability.
- if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -485,7 +460,7 @@ afterTrailingSymlink:
goto afterTrailingSymlink
}
}
- if err := childInode.CheckPermissions(rp.Credentials(), ats); err != nil {
+ if err := childInode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
return childInode.Open(rp, childVFSD, opts.Flags)
@@ -545,13 +520,13 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
srcVFSD := &src.vfsd
// Can we remove the src dentry?
- if err := checkDeleteLocked(rp, srcVFSD); err != nil {
+ if err := checkDeleteLocked(ctx, rp, srcVFSD); err != nil {
return err
}
// Can we create the dst dentry?
var dstVFSD *vfs.Dentry
- pc, err := checkCreateLocked(rp, dstDirVFSD, dstDirInode)
+ pc, err := checkCreateLocked(ctx, rp, dstDirVFSD, dstDirInode)
switch err {
case nil:
// Ok, continue with rename as replacement.
@@ -607,7 +582,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(rp, vfsd); err != nil {
+ if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
return err
}
if !vfsd.Impl().(*Dentry).isDir() {
@@ -683,7 +658,7 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
if err != nil {
return err
}
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -712,7 +687,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(rp, vfsd); err != nil {
+ if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
return err
}
if vfsd.Impl().(*Dentry).isDir() {
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 7b45b702a..1700fffd9 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -16,7 +16,6 @@ package kernfs
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -139,6 +139,11 @@ func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry,
panic("Lookup called on non-directory inode")
}
+// IterDirents implements Inode.IterDirents.
+func (*InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ panic("IterDirents called on non-directory inode")
+}
+
// Valid implements Inode.Valid.
func (*InodeNotDirectory) Valid(context.Context) bool {
return true
@@ -156,6 +161,11 @@ func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dent
return nil, syserror.ENOENT
}
+// IterDirents implements Inode.IterDirents.
+func (*InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return offset, nil
+}
+
// Valid implements Inode.Valid.
func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool {
return true
@@ -252,7 +262,7 @@ func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
}
// CheckPermissions implements Inode.CheckPermissions.
-func (a *InodeAttrs) CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+func (a *InodeAttrs) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
mode := a.Mode()
return vfs.GenericCheckPermissions(
creds,
@@ -490,3 +500,57 @@ func (o *OrderedChildren) nthLocked(i int64) *slot {
}
return nil
}
+
+// InodeSymlink partially implements Inode interface for symlinks.
+type InodeSymlink struct {
+ InodeNotDirectory
+}
+
+// Open implements Inode.Open.
+func (InodeSymlink) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ return nil, syserror.ELOOP
+}
+
+// StaticDirectory is a standard implementation of a directory with static
+// contents.
+//
+// +stateify savable
+type StaticDirectory struct {
+ InodeNotSymlink
+ InodeDirectoryNoNewChildren
+ InodeAttrs
+ InodeNoDynamicLookup
+ OrderedChildren
+}
+
+var _ Inode = (*StaticDirectory)(nil)
+
+// NewStaticDir creates a new static directory and returns its dentry.
+func NewStaticDir(creds *auth.Credentials, ino uint64, perm linux.FileMode, children map[string]*Dentry) *Dentry {
+ inode := &StaticDirectory{}
+ inode.Init(creds, ino, perm)
+
+ dentry := &Dentry{}
+ dentry.Init(inode)
+
+ inode.OrderedChildren.Init(OrderedChildrenOptions{})
+ links := inode.OrderedChildren.Populate(dentry, children)
+ inode.IncLinks(links)
+
+ return dentry
+}
+
+// Init initializes StaticDirectory.
+func (s *StaticDirectory) Init(creds *auth.Credentials, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ s.InodeAttrs.Init(creds, ino, linux.ModeDirectory|perm)
+}
+
+// Open implements kernfs.Inode.
+func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index ac802218d..85bcdcc57 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -53,7 +53,6 @@ package kernfs
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -61,6 +60,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FilesystemType implements vfs.FilesystemType.
@@ -320,7 +320,7 @@ type inodeMetadata interface {
// CheckPermissions checks that creds may access this inode for the
// requested access type, per the the rules of
// fs/namei.c:generic_permission().
- CheckPermissions(creds *auth.Credentials, atx vfs.AccessTypes) error
+ CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error
// Mode returns the (struct stat)::st_mode value for this inode. This is
// separated from Stat for performance.
@@ -404,6 +404,15 @@ type inodeDynamicLookup interface {
// Valid should return true if this inode is still valid, or needs to
// be resolved again by a call to Lookup.
Valid(ctx context.Context) bool
+
+ // IterDirents is used to iterate over dynamically created entries. It invokes
+ // cb on each entry in the directory represented by the FileDescription.
+ // 'offset' is the offset for the entire IterDirents call, which may include
+ // results from the caller. 'relOffset' is the offset inside the entries
+ // returned by this IterDirents invocation. In other words,
+ // 'offset+relOffset+1' is the value that should be set in vfs.Dirent.NextOff,
+ // while 'relOffset' is the place where iteration should start from.
+ IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
}
type inodeSymlink interface {
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 73b6e43b5..5c9d580e1 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -19,7 +19,6 @@ import (
"fmt"
"io"
"runtime"
- "sync"
"testing"
"github.com/google/go-cmp/cmp"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -59,7 +59,9 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *TestSystem {
ctx := contexttest.Context(t)
creds := auth.CredentialsFromContext(ctx)
v := vfs.New()
- v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn})
+ v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{})
if err != nil {
t.Fatalf("Failed to create testfs root mount: %v", err)
@@ -133,7 +135,7 @@ type file struct {
func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry {
f := &file{}
f.content = content
- f.DynamicBytesFile.Init(creds, fs.NextIno(), f)
+ f.DynamicBytesFile.Init(creds, fs.NextIno(), f, 0777)
d := &kernfs.Dentry{}
d.Init(f)
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
new file mode 100644
index 000000000..f19f12854
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -0,0 +1,54 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// StaticSymlink provides an Inode implementation for symlinks that point to
+// a immutable target.
+type StaticSymlink struct {
+ InodeAttrs
+ InodeNoopRefCount
+ InodeSymlink
+
+ target string
+}
+
+var _ Inode = (*StaticSymlink)(nil)
+
+// NewStaticSymlink creates a new symlink file pointing to 'target'.
+func NewStaticSymlink(creds *auth.Credentials, ino uint64, target string) *Dentry {
+ inode := &StaticSymlink{}
+ inode.Init(creds, ino, target)
+
+ d := &Dentry{}
+ d.Init(inode)
+ return d
+}
+
+// Init initializes the instance.
+func (s *StaticSymlink) Init(creds *auth.Credentials, ino uint64, target string) {
+ s.target = target
+ s.InodeAttrs.Init(creds, ino, linux.ModeSymlink|0777)
+}
+
+// Readlink implements Inode.
+func (s *StaticSymlink) Readlink(_ context.Context) (string, error) {
+ return s.target, nil
+}
diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go
deleted file mode 100644
index b7f4853b3..000000000
--- a/pkg/sentry/fsimpl/memfs/regular_file.go
+++ /dev/null
@@ -1,154 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package memfs
-
-import (
- "io"
- "sync"
- "sync/atomic"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-type regularFile struct {
- inode inode
-
- mu sync.RWMutex
- data []byte
- // dataLen is len(data), but accessed using atomic memory operations to
- // avoid locking in inode.stat().
- dataLen int64
-}
-
-func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
- file := &regularFile{}
- file.inode.init(file, fs, creds, mode)
- file.inode.nlink = 1 // from parent directory
- return &file.inode
-}
-
-type regularFileFD struct {
- fileDescription
-
- // These are immutable.
- readable bool
- writable bool
-
- // off is the file offset. off is accessed using atomic memory operations.
- // offMu serializes operations that may mutate off.
- off int64
- offMu sync.Mutex
-}
-
-// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *regularFileFD) Release() {
- if fd.writable {
- fd.vfsfd.VirtualDentry().Mount().EndWrite()
- }
-}
-
-// PRead implements vfs.FileDescriptionImpl.PRead.
-func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
- if !fd.readable {
- return 0, syserror.EINVAL
- }
- f := fd.inode().impl.(*regularFile)
- f.mu.RLock()
- if offset >= int64(len(f.data)) {
- f.mu.RUnlock()
- return 0, io.EOF
- }
- n, err := dst.CopyOut(ctx, f.data[offset:])
- f.mu.RUnlock()
- return int64(n), err
-}
-
-// Read implements vfs.FileDescriptionImpl.Read.
-func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- fd.offMu.Lock()
- n, err := fd.PRead(ctx, dst, fd.off, opts)
- fd.off += n
- fd.offMu.Unlock()
- return n, err
-}
-
-// PWrite implements vfs.FileDescriptionImpl.PWrite.
-func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
- if !fd.writable {
- return 0, syserror.EINVAL
- }
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- srclen := src.NumBytes()
- if srclen == 0 {
- return 0, nil
- }
- f := fd.inode().impl.(*regularFile)
- f.mu.Lock()
- end := offset + srclen
- if end < offset {
- // Overflow.
- f.mu.Unlock()
- return 0, syserror.EFBIG
- }
- if end > f.dataLen {
- f.data = append(f.data, make([]byte, end-f.dataLen)...)
- atomic.StoreInt64(&f.dataLen, end)
- }
- n, err := src.CopyIn(ctx, f.data[offset:end])
- f.mu.Unlock()
- return int64(n), err
-}
-
-// Write implements vfs.FileDescriptionImpl.Write.
-func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
- fd.offMu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
- fd.offMu.Unlock()
- return n, err
-}
-
-// Seek implements vfs.FileDescriptionImpl.Seek.
-func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
- fd.offMu.Lock()
- defer fd.offMu.Unlock()
- switch whence {
- case linux.SEEK_SET:
- // use offset as specified
- case linux.SEEK_CUR:
- offset += fd.off
- case linux.SEEK_END:
- offset += atomic.LoadInt64(&fd.inode().impl.(*regularFile).dataLen)
- default:
- return 0, syserror.EINVAL
- }
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- fd.off = offset
- return offset, nil
-}
-
-// Sync implements vfs.FileDescriptionImpl.Sync.
-func (fd *regularFileFD) Sync(ctx context.Context) error {
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index ade6ac946..e92564b5d 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -6,16 +6,14 @@ package(licenses = ["notice"])
go_library(
name = "proc",
srcs = [
- "filesystems.go",
- "loadavg.go",
- "meminfo.go",
- "mounts.go",
- "net.go",
- "proc.go",
- "stat.go",
- "sys.go",
+ "filesystem.go",
+ "subtasks.go",
"task.go",
- "version.go",
+ "task_files.go",
+ "tasks.go",
+ "tasks_files.go",
+ "tasks_net.go",
+ "tasks_sys.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc",
deps = [
@@ -24,27 +22,54 @@ go_library(
"//pkg/log",
"//pkg/sentry/context",
"//pkg/sentry/fs",
+ "//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/mm",
+ "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/syserror",
],
)
go_test(
name = "proc_test",
size = "small",
- srcs = ["net_test.go"],
+ srcs = [
+ "boot_test.go",
+ "tasks_sys_test.go",
+ "tasks_test.go",
+ ],
embed = [":proc"],
deps = [
"//pkg/abi/linux",
+ "//pkg/cpuid",
+ "//pkg/fspath",
+ "//pkg/memutil",
+ "//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
+ "//pkg/sentry/fs",
"//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/sched",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/platform/kvm",
+ "//pkg/sentry/platform/ptrace",
+ "//pkg/sentry/time",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
],
)
diff --git a/pkg/sentry/fsimpl/proc/boot_test.go b/pkg/sentry/fsimpl/proc/boot_test.go
new file mode 100644
index 000000000..84a93ee56
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/boot_test.go
@@ -0,0 +1,149 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "runtime"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/time"
+
+ // Platforms are plugable.
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/kvm"
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/ptrace"
+)
+
+var (
+ platformFlag = flag.String("platform", "ptrace", "specify which platform to use")
+)
+
+// boot initializes a new bare bones kernel for test.
+func boot() (*kernel.Kernel, error) {
+ platformCtr, err := platform.Lookup(*platformFlag)
+ if err != nil {
+ return nil, fmt.Errorf("platform not found: %v", err)
+ }
+ deviceFile, err := platformCtr.OpenDevice()
+ if err != nil {
+ return nil, fmt.Errorf("creating platform: %v", err)
+ }
+ plat, err := platformCtr.New(deviceFile)
+ if err != nil {
+ return nil, fmt.Errorf("creating platform: %v", err)
+ }
+
+ k := &kernel.Kernel{
+ Platform: plat,
+ }
+
+ mf, err := createMemoryFile()
+ if err != nil {
+ return nil, err
+ }
+ k.SetMemoryFile(mf)
+
+ // Pass k as the platform since it is savable, unlike the actual platform.
+ vdso, err := loader.PrepareVDSO(nil, k)
+ if err != nil {
+ return nil, fmt.Errorf("creating vdso: %v", err)
+ }
+
+ // Create timekeeper.
+ tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange())
+ if err != nil {
+ return nil, fmt.Errorf("creating timekeeper: %v", err)
+ }
+ tk.SetClocks(time.NewCalibratedClocks())
+
+ creds := auth.NewRootCredentials(auth.NewRootUserNamespace())
+
+ // Initiate the Kernel object, which is required by the Context passed
+ // to createVFS in order to mount (among other things) procfs.
+ if err = k.Init(kernel.InitKernelArgs{
+ ApplicationCores: uint(runtime.GOMAXPROCS(-1)),
+ FeatureSet: cpuid.HostFeatureSet(),
+ Timekeeper: tk,
+ RootUserNamespace: creds.UserNamespace,
+ Vdso: vdso,
+ RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace),
+ RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace),
+ RootAbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace),
+ }); err != nil {
+ return nil, fmt.Errorf("initializing kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+
+ // Create mount namespace without root as it's the minimum required to create
+ // the global thread group.
+ mntns, err := fs.NewMountNamespace(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ ls, err := limits.NewLinuxLimitSet()
+ if err != nil {
+ return nil, err
+ }
+ tg := k.NewThreadGroup(mntns, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls)
+ k.TestOnly_SetGlobalInit(tg)
+
+ return k, nil
+}
+
+// createTask creates a new bare bones task for tests.
+func createTask(ctx context.Context, name string, tc *kernel.ThreadGroup) (*kernel.Task, error) {
+ k := kernel.KernelFromContext(ctx)
+ config := &kernel.TaskConfig{
+ Kernel: k,
+ ThreadGroup: tc,
+ TaskContext: &kernel.TaskContext{Name: name},
+ Credentials: auth.CredentialsFromContext(ctx),
+ AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
+ UTSNamespace: kernel.UTSNamespaceFromContext(ctx),
+ IPCNamespace: kernel.IPCNamespaceFromContext(ctx),
+ AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ }
+ return k.TaskSet().NewTask(config)
+}
+
+func createMemoryFile() (*pgalloc.MemoryFile, error) {
+ const memfileName = "test-memory"
+ memfd, err := memutil.CreateMemFD(memfileName, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating memfd: %v", err)
+ }
+ memfile := os.NewFile(uintptr(memfd), memfileName)
+ mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{})
+ if err != nil {
+ memfile.Close()
+ return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err)
+ }
+ return mf, nil
+}
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
new file mode 100644
index 000000000..e9cb7895f
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -0,0 +1,80 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package proc implements a partial in-memory file system for procfs.
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// procFSType is the factory class for procfs.
+//
+// +stateify savable
+type procFSType struct{}
+
+var _ vfs.FilesystemType = (*procFSType)(nil)
+
+// GetFilesystem implements vfs.FilesystemType.
+func (ft *procFSType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ k := kernel.KernelFromContext(ctx)
+ if k == nil {
+ return nil, nil, fmt.Errorf("procfs requires a kernel")
+ }
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return nil, nil, fmt.Errorf("procfs requires a PID namespace")
+ }
+
+ procfs := &kernfs.Filesystem{}
+ procfs.VFSFilesystem().Init(vfsObj, procfs)
+
+ _, dentry := newTasksInode(procfs, k, pidns)
+ return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
+}
+
+// dynamicInode is an overfitted interface for common Inodes with
+// dynamicByteSource types used in procfs.
+type dynamicInode interface {
+ kernfs.Inode
+ vfs.DynamicBytesSource
+
+ Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
+}
+
+func newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+ inode.Init(creds, ino, inode, perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+type staticFile struct {
+ kernfs.DynamicBytesFile
+ vfs.StaticData
+}
+
+var _ dynamicInode = (*staticFile)(nil)
+
+func newStaticFile(data string) *staticFile {
+ return &staticFile{StaticData: vfs.StaticData{Data: data}}
+}
diff --git a/pkg/sentry/fsimpl/proc/loadavg.go b/pkg/sentry/fsimpl/proc/loadavg.go
deleted file mode 100644
index 9135afef1..000000000
--- a/pkg/sentry/fsimpl/proc/loadavg.go
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// loadavgData backs /proc/loadavg.
-//
-// +stateify savable
-type loadavgData struct{}
-
-var _ vfs.DynamicBytesSource = (*loadavgData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- // TODO(b/62345059): Include real data in fields.
- // Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods.
- // Column 4-5: currently running processes and the total number of processes.
- // Column 6: the last process ID used.
- fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/meminfo.go b/pkg/sentry/fsimpl/proc/meminfo.go
deleted file mode 100644
index 9a827cd66..000000000
--- a/pkg/sentry/fsimpl/proc/meminfo.go
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
-//
-// +stateify savable
-type meminfoData struct {
- // k is the owning Kernel.
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*meminfoData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- mf := d.k.MemoryFile()
- mf.UpdateUsage()
- snapshot, totalUsage := usage.MemoryAccounting.Copy()
- totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
- anon := snapshot.Anonymous + snapshot.Tmpfs
- file := snapshot.PageCache + snapshot.Mapped
- // We don't actually have active/inactive LRUs, so just make up numbers.
- activeFile := (file / 2) &^ (usermem.PageSize - 1)
- inactiveFile := file - activeFile
-
- fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
- memFree := (totalSize - totalUsage) / 1024
- // We use MemFree as MemAvailable because we don't swap.
- // TODO(rahat): When reclaim is implemented the value of MemAvailable
- // should change.
- fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree)
- fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree)
- fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices
- fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
- // Emulate a system with no swap, which disables inactivation of anon pages.
- fmt.Fprintf(buf, "SwapCache: 0 kB\n")
- fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024)
- fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024)
- fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024)
- fmt.Fprintf(buf, "Inactive(anon): 0 kB\n")
- fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024)
- fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
- fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
- fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
- fmt.Fprintf(buf, "SwapTotal: 0 kB\n")
- fmt.Fprintf(buf, "SwapFree: 0 kB\n")
- fmt.Fprintf(buf, "Dirty: 0 kB\n")
- fmt.Fprintf(buf, "Writeback: 0 kB\n")
- fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024)
- fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
- fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/fsimpl/proc/mounts.go
deleted file mode 100644
index 8683cf677..000000000
--- a/pkg/sentry/fsimpl/proc/mounts.go
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import "gvisor.dev/gvisor/pkg/sentry/kernel"
-
-// TODO(gvisor.dev/issue/1195): Implement mountInfoFile and mountsFile.
-
-// mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo.
-//
-// +stateify savable
-type mountInfoFile struct {
- t *kernel.Task
-}
-
-// mountsFile implements vfs.DynamicBytesSource for /proc/[pid]/mounts.
-//
-// +stateify savable
-type mountsFile struct {
- t *kernel.Task
-}
diff --git a/pkg/sentry/fsimpl/proc/stat.go b/pkg/sentry/fsimpl/proc/stat.go
deleted file mode 100644
index 720db3828..000000000
--- a/pkg/sentry/fsimpl/proc/stat.go
+++ /dev/null
@@ -1,127 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// cpuStats contains the breakdown of CPU time for /proc/stat.
-type cpuStats struct {
- // user is time spent in userspace tasks with non-positive niceness.
- user uint64
-
- // nice is time spent in userspace tasks with positive niceness.
- nice uint64
-
- // system is time spent in non-interrupt kernel context.
- system uint64
-
- // idle is time spent idle.
- idle uint64
-
- // ioWait is time spent waiting for IO.
- ioWait uint64
-
- // irq is time spent in interrupt context.
- irq uint64
-
- // softirq is time spent in software interrupt context.
- softirq uint64
-
- // steal is involuntary wait time.
- steal uint64
-
- // guest is time spent in guests with non-positive niceness.
- guest uint64
-
- // guestNice is time spent in guests with positive niceness.
- guestNice uint64
-}
-
-// String implements fmt.Stringer.
-func (c cpuStats) String() string {
- return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
-}
-
-// statData implements vfs.DynamicBytesSource for /proc/stat.
-//
-// +stateify savable
-type statData struct {
- // k is the owning Kernel.
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*statData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- // TODO(b/37226836): We currently export only zero CPU stats. We could
- // at least provide some aggregate stats.
- var cpu cpuStats
- fmt.Fprintf(buf, "cpu %s\n", cpu)
-
- for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
- fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
- }
-
- // The total number of interrupts is dependent on the CPUs and PCI
- // devices on the system. See arch_probe_nr_irqs.
- //
- // Since we don't report real interrupt stats, just choose an arbitrary
- // value from a representative VM.
- const numInterrupts = 256
-
- // The Kernel doesn't handle real interrupts, so report all zeroes.
- // TODO(b/37226836): We could count page faults as #PF.
- fmt.Fprintf(buf, "intr 0") // total
- for i := 0; i < numInterrupts; i++ {
- fmt.Fprintf(buf, " 0")
- }
- fmt.Fprintf(buf, "\n")
-
- // Total number of context switches.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "ctxt 0\n")
-
- // CLOCK_REALTIME timestamp from boot, in seconds.
- fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
-
- // Total number of clones.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "processes 0\n")
-
- // Number of runnable tasks.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "procs_running 0\n")
-
- // Number of tasks waiting on IO.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "procs_blocked 0\n")
-
- // Number of each softirq handled.
- fmt.Fprintf(buf, "softirq 0") // total
- for i := 0; i < linux.NumSoftIRQ; i++ {
- fmt.Fprintf(buf, " 0")
- }
- fmt.Fprintf(buf, "\n")
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
new file mode 100644
index 000000000..8892c5a11
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -0,0 +1,126 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// subtasksInode represents the inode for /proc/[pid]/task/ directory.
+//
+// +stateify savable
+type subtasksInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+
+ task *kernel.Task
+ pidns *kernel.PIDNamespace
+ inoGen InoGenerator
+}
+
+var _ kernfs.Inode = (*subtasksInode)(nil)
+
+func newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, inoGen InoGenerator) *kernfs.Dentry {
+ subInode := &subtasksInode{
+ task: task,
+ pidns: pidns,
+ inoGen: inoGen,
+ }
+ // Note: credentials are overridden by taskOwnedInode.
+ subInode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555)
+ subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ inode := &taskOwnedInode{Inode: subInode, owner: task}
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ return dentry
+}
+
+// Valid implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) Valid(ctx context.Context) bool {
+ return true
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ tid, err := strconv.ParseUint(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+
+ subTask := i.pidns.TaskWithID(kernel.ThreadID(tid))
+ if subTask == nil {
+ return nil, syserror.ENOENT
+ }
+ if subTask.ThreadGroup() != i.task.ThreadGroup() {
+ return nil, syserror.ENOENT
+ }
+
+ subTaskDentry := newTaskInode(i.inoGen, subTask, i.pidns, false)
+ return subTaskDentry.VFSDentry(), nil
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
+ if len(tasks) == 0 {
+ return offset, syserror.ENOENT
+ }
+
+ tids := make([]int, 0, len(tasks))
+ for _, tid := range tasks {
+ tids = append(tids, int(tid))
+ }
+
+ sort.Ints(tids)
+ for _, tid := range tids[relOffset:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(tid), 10),
+ Type: linux.DT_DIR,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+// Open implements kernfs.Inode.
+func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
+
+// Stat implements kernfs.Inode.
+func (i *subtasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx {
+ stat := i.InodeAttrs.Stat(vsfs)
+ stat.Nlink += uint32(i.task.ThreadGroup().Count())
+ return stat
+}
diff --git a/pkg/sentry/fsimpl/proc/sys.go b/pkg/sentry/fsimpl/proc/sys.go
deleted file mode 100644
index b88256e12..000000000
--- a/pkg/sentry/fsimpl/proc/sys.go
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// mmapMinAddrData implements vfs.DynamicBytesSource for
-// /proc/sys/vm/mmap_min_addr.
-//
-// +stateify savable
-type mmapMinAddrData struct {
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*mmapMinAddrData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress())
- return nil
-}
-
-// +stateify savable
-type overcommitMemory struct{}
-
-var _ vfs.DynamicBytesSource = (*overcommitMemory)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *overcommitMemory) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "0\n")
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 0d87be52b..621c17cfe 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -15,251 +15,215 @@
package proc
import (
- "bytes"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
)
-// mapsCommon is embedded by mapsData and smapsData.
-type mapsCommon struct {
- t *kernel.Task
-}
-
-// mm gets the kernel task's MemoryManager. No additional reference is taken on
-// mm here. This is safe because MemoryManager.destroy is required to leave the
-// MemoryManager in a state where it's still usable as a DynamicBytesSource.
-func (md *mapsCommon) mm() *mm.MemoryManager {
- var tmm *mm.MemoryManager
- md.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- tmm = mm
- }
- })
- return tmm
-}
-
-// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
+// taskInode represents the inode for /proc/PID/ directory.
//
// +stateify savable
-type mapsData struct {
- mapsCommon
+type taskInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+
+ task *kernel.Task
}
-var _ vfs.DynamicBytesSource = (*mapsData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (md *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- if mm := md.mm(); mm != nil {
- mm.ReadMapsDataInto(ctx, buf)
+var _ kernfs.Inode = (*taskInode)(nil)
+
+func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool) *kernfs.Dentry {
+ contents := map[string]*kernfs.Dentry{
+ "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}),
+ "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
+ "comm": newComm(task, inoGen.NextIno(), 0444),
+ "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
+ //"exe": newExe(t, msrc),
+ //"fd": newFdDir(t, msrc),
+ //"fdinfo": newFdInfoDir(t, msrc),
+ "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}),
+ "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)),
+ "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}),
+ //"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
+ //"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
+ "ns": newTaskOwnedDir(task, inoGen.NextIno(), 0511, map[string]*kernfs.Dentry{
+ "net": newNamespaceSymlink(task, inoGen.NextIno(), "net"),
+ "pid": newNamespaceSymlink(task, inoGen.NextIno(), "pid"),
+ "user": newNamespaceSymlink(task, inoGen.NextIno(), "user"),
+ }),
+ "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}),
+ "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
+ "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}),
+ "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
- return nil
+ if isThreadGroup {
+ contents["task"] = newSubtasks(task, pidns, inoGen)
+ }
+ //if len(p.cgroupControllers) > 0 {
+ // contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers)
+ //}
+
+ taskInode := &taskInode{task: task}
+ // Note: credentials are overridden by taskOwnedInode.
+ taskInode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555)
+
+ inode := &taskOwnedInode{Inode: taskInode, owner: task}
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ taskInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := taskInode.OrderedChildren.Populate(dentry, contents)
+ taskInode.IncLinks(links)
+
+ return dentry
}
-// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps.
-//
-// +stateify savable
-type smapsData struct {
- mapsCommon
+// Valid implements kernfs.inodeDynamicLookup. This inode remains valid as long
+// as the task is still running. When it's dead, another tasks with the same
+// PID could replace it.
+func (i *taskInode) Valid(ctx context.Context) bool {
+ return i.task.ExitState() != kernel.TaskExitDead
}
-var _ vfs.DynamicBytesSource = (*smapsData)(nil)
+// Open implements kernfs.Inode.
+func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (sd *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- if mm := sd.mm(); mm != nil {
- mm.ReadSmapsDataInto(ctx, buf)
+// SetStat implements kernfs.Inode.
+func (i *taskInode) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
+ stat := opts.Stat
+ if stat.Mask&linux.STATX_MODE != 0 {
+ return syserror.EPERM
}
return nil
}
-// +stateify savable
-type taskStatData struct {
- t *kernel.Task
-
- // If tgstats is true, accumulate fault stats (not implemented) and CPU
- // time across all tasks in t's thread group.
- tgstats bool
+// taskOwnedInode implements kernfs.Inode and overrides inode owner with task
+// effective user and group.
+type taskOwnedInode struct {
+ kernfs.Inode
- // pidns is the PID namespace associated with the proc filesystem that
- // includes the file using this statData.
- pidns *kernel.PIDNamespace
+ // owner is the task that owns this inode.
+ owner *kernel.Task
}
-var _ vfs.DynamicBytesSource = (*taskStatData)(nil)
+var _ kernfs.Inode = (*taskOwnedInode)(nil)
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.t))
- fmt.Fprintf(buf, "(%s) ", s.t.Name())
- fmt.Fprintf(buf, "%c ", s.t.StateStatus()[0])
- ppid := kernel.ThreadID(0)
- if parent := s.t.Parent(); parent != nil {
- ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
- }
- fmt.Fprintf(buf, "%d ", ppid)
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup()))
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session()))
- fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */)
- fmt.Fprintf(buf, "0 " /* flags */)
- fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
- var cputime usage.CPUStats
- if s.tgstats {
- cputime = s.t.ThreadGroup().CPUStats()
- } else {
- cputime = s.t.CPUStats()
- }
- fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
- cputime = s.t.ThreadGroup().JoinedChildCPUStats()
- fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
- fmt.Fprintf(buf, "%d %d ", s.t.Priority(), s.t.Niceness())
- fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Count())
-
- // itrealvalue. Since kernel 2.6.17, this field is no longer
- // maintained, and is hard coded as 0.
- fmt.Fprintf(buf, "0 ")
-
- // Start time is relative to boot time, expressed in clock ticks.
- fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
-
- var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
- fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize)
-
- // rsslim.
- fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur)
-
- fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
- fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
- fmt.Fprintf(buf, "0 0 " /* nswap cnswap */)
- terminationSignal := linux.Signal(0)
- if s.t == s.t.ThreadGroup().Leader() {
- terminationSignal = s.t.ThreadGroup().TerminationSignal()
- }
- fmt.Fprintf(buf, "%d ", terminationSignal)
- fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */)
- fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
- fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
- fmt.Fprintf(buf, "0\n" /* exit_code */)
+func newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), ino, inode, perm)
- return nil
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
}
-// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm.
-//
-// +stateify savable
-type statmData struct {
- t *kernel.Task
-}
+func newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &kernfs.StaticDirectory{}
-var _ vfs.DynamicBytesSource = (*statmData)(nil)
+ // Note: credentials are overridden by taskOwnedInode.
+ dir.Init(task.Credentials(), ino, perm)
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
+ inode := &taskOwnedInode{Inode: dir, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(inode)
- fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
- return nil
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := dir.OrderedChildren.Populate(d, children)
+ dir.IncLinks(links)
+
+ return d
}
-// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
-//
-// +stateify savable
-type statusData struct {
- t *kernel.Task
- pidns *kernel.PIDNamespace
+// Stat implements kernfs.Inode.
+func (i *taskOwnedInode) Stat(fs *vfs.Filesystem) linux.Statx {
+ stat := i.Inode.Stat(fs)
+ uid, gid := i.getOwner(linux.FileMode(stat.Mode))
+ stat.UID = uint32(uid)
+ stat.GID = uint32(gid)
+ return stat
}
-var _ vfs.DynamicBytesSource = (*statusData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "Name:\t%s\n", s.t.Name())
- fmt.Fprintf(buf, "State:\t%s\n", s.t.StateStatus())
- fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup()))
- fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t))
- ppid := kernel.ThreadID(0)
- if parent := s.t.Parent(); parent != nil {
- ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+// CheckPermissions implements kernfs.Inode.
+func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := i.Mode()
+ uid, gid := i.getOwner(mode)
+ return vfs.GenericCheckPermissions(
+ creds,
+ ats,
+ mode.FileType() == linux.ModeDirectory,
+ uint16(mode),
+ uid,
+ gid,
+ )
+}
+
+func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) {
+ // By default, set the task owner as the file owner.
+ creds := i.owner.Credentials()
+ uid := creds.EffectiveKUID
+ gid := creds.EffectiveKGID
+
+ // Linux doesn't apply dumpability adjustments to world readable/executable
+ // directories so that applications can stat /proc/PID to determine the
+ // effective UID of a process. See fs/proc/base.c:task_dump_owner.
+ if mode.FileType() == linux.ModeDirectory && mode.Permissions() == 0555 {
+ return uid, gid
}
- fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
- tpid := kernel.ThreadID(0)
- if tracer := s.t.Tracer(); tracer != nil {
- tpid = s.pidns.IDOfTask(tracer)
+
+ // If the task is not dumpable, then root (in the namespace preferred)
+ // owns the file.
+ m := getMM(i.owner)
+ if m == nil {
+ return auth.RootKUID, auth.RootKGID
}
- fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
- var fds int
- var vss, rss, data uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.Size()
+ if m.Dumpability() != mm.UserDumpable {
+ uid = auth.RootKUID
+ if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() {
+ uid = kuid
}
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- data = mm.VirtualDataSize()
+ gid = auth.RootKGID
+ if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() {
+ gid = kgid
}
- })
- fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
- fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
- fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
- fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
- fmt.Fprintf(buf, "Threads:\t%d\n", s.t.ThreadGroup().Count())
- creds := s.t.Credentials()
- fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
- fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
- fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
- fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
- fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode())
- // We unconditionally report a single NUMA node. See
- // pkg/sentry/syscalls/linux/sys_mempolicy.go.
- fmt.Fprintf(buf, "Mems_allowed:\t1\n")
- fmt.Fprintf(buf, "Mems_allowed_list:\t0\n")
- return nil
-}
-
-// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
-type ioUsage interface {
- // IOUsage returns the io usage data.
- IOUsage() *usage.IO
+ }
+ return uid, gid
}
-// +stateify savable
-type ioData struct {
- ioUsage
+func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
+ if isThreadGroup {
+ return &ioData{ioUsage: t.ThreadGroup()}
+ }
+ return &ioData{ioUsage: t}
}
-var _ vfs.DynamicBytesSource = (*ioData)(nil)
+func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
+ // Namespace symlinks should contain the namespace name and the inode number
+ // for the namespace instance, so for example user:[123456]. We currently fake
+ // the inode number by sticking the symlink inode in its place.
+ target := fmt.Sprintf("%s:[%d]", ns, ino)
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- io := usage.IO{}
- io.Accumulate(i.IOUsage())
+ inode := &kernfs.StaticSymlink{}
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), ino, target)
- fmt.Fprintf(buf, "char: %d\n", io.CharsRead)
- fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten)
- fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls)
- fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls)
- fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead)
- fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten)
- fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
- return nil
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
new file mode 100644
index 000000000..7bc352ae9
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -0,0 +1,527 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// mm gets the kernel task's MemoryManager. No additional reference is taken on
+// mm here. This is safe because MemoryManager.destroy is required to leave the
+// MemoryManager in a state where it's still usable as a DynamicBytesSource.
+func getMM(task *kernel.Task) *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// getMMIncRef returns t's MemoryManager. If getMMIncRef succeeds, the
+// MemoryManager's users count is incremented, and must be decremented by the
+// caller when it is no longer in use.
+func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) {
+ if task.ExitState() == kernel.TaskExitDead {
+ return nil, syserror.ESRCH
+ }
+ var m *mm.MemoryManager
+ task.WithMuLocked(func(t *kernel.Task) {
+ m = t.MemoryManager()
+ })
+ if m == nil || !m.IncUsers() {
+ return nil, io.EOF
+ }
+ return m, nil
+}
+
+type bufferWriter struct {
+ buf *bytes.Buffer
+}
+
+// WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns
+// the number of bytes written. It may return a partial write without an
+// error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not
+// return a full write with an error (i.e. srcs.NumBytes(), err) where err
+// != nil).
+func (w *bufferWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ written := srcs.NumBytes()
+ for !srcs.IsEmpty() {
+ w.buf.Write(srcs.Head().ToSlice())
+ srcs = srcs.Tail()
+ }
+ return written, nil
+}
+
+// auxvData implements vfs.DynamicBytesSource for /proc/[pid]/auxv.
+//
+// +stateify savable
+type auxvData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*auxvData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ m, err := getMMIncRef(d.task)
+ if err != nil {
+ return err
+ }
+ defer m.DecUsers(ctx)
+
+ // Space for buffer with AT_NULL (0) terminator at the end.
+ auxv := m.Auxv()
+ buf.Grow((len(auxv) + 1) * 16)
+ for _, e := range auxv {
+ var tmp [8]byte
+ usermem.ByteOrder.PutUint64(tmp[:], e.Key)
+ buf.Write(tmp[:])
+
+ usermem.ByteOrder.PutUint64(tmp[:], uint64(e.Value))
+ buf.Write(tmp[:])
+ }
+ return nil
+}
+
+// execArgType enumerates the types of exec arguments that are exposed through
+// proc.
+type execArgType int
+
+const (
+ cmdlineDataArg execArgType = iota
+ environDataArg
+)
+
+// cmdlineData implements vfs.DynamicBytesSource for /proc/[pid]/cmdline.
+//
+// +stateify savable
+type cmdlineData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+
+ // arg is the type of exec argument this file contains.
+ arg execArgType
+}
+
+var _ dynamicInode = (*cmdlineData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ m, err := getMMIncRef(d.task)
+ if err != nil {
+ return err
+ }
+ defer m.DecUsers(ctx)
+
+ // Figure out the bounds of the exec arg we are trying to read.
+ var ar usermem.AddrRange
+ switch d.arg {
+ case cmdlineDataArg:
+ ar = usermem.AddrRange{
+ Start: m.ArgvStart(),
+ End: m.ArgvEnd(),
+ }
+ case environDataArg:
+ ar = usermem.AddrRange{
+ Start: m.EnvvStart(),
+ End: m.EnvvEnd(),
+ }
+ default:
+ panic(fmt.Sprintf("unknown exec arg type %v", d.arg))
+ }
+ if ar.Start == 0 || ar.End == 0 {
+ // Don't attempt to read before the start/end are set up.
+ return io.EOF
+ }
+
+ // N.B. Technically this should be usermem.IOOpts.IgnorePermissions = true
+ // until Linux 4.9 (272ddc8b3735 "proc: don't use FOLL_FORCE for reading
+ // cmdline and environment").
+ writer := &bufferWriter{buf: buf}
+ if n, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(ar), writer, usermem.IOOpts{}); n == 0 || err != nil {
+ // Nothing to copy or something went wrong.
+ return err
+ }
+
+ // On Linux, if the NULL byte at the end of the argument vector has been
+ // overwritten, it continues reading the environment vector as part of
+ // the argument vector.
+ if d.arg == cmdlineDataArg && buf.Bytes()[buf.Len()-1] != 0 {
+ if end := bytes.IndexByte(buf.Bytes(), 0); end != -1 {
+ // If we found a NULL character somewhere else in argv, truncate the
+ // return up to the NULL terminator (including it).
+ buf.Truncate(end)
+ return nil
+ }
+
+ // There is no NULL terminator in the string, return into envp.
+ arEnvv := usermem.AddrRange{
+ Start: m.EnvvStart(),
+ End: m.EnvvEnd(),
+ }
+
+ // Upstream limits the returned amount to one page of slop.
+ // https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208
+ // we'll return one page total between argv and envp because of the
+ // above page restrictions.
+ if buf.Len() >= usermem.PageSize {
+ // Returned at least one page already, nothing else to add.
+ return nil
+ }
+ remaining := usermem.PageSize - buf.Len()
+ if int(arEnvv.Length()) > remaining {
+ end, ok := arEnvv.Start.AddLength(uint64(remaining))
+ if !ok {
+ return syserror.EFAULT
+ }
+ arEnvv.End = end
+ }
+ if _, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(arEnvv), writer, usermem.IOOpts{}); err != nil {
+ return err
+ }
+
+ // Linux will return envp up to and including the first NULL character,
+ // so find it.
+ if end := bytes.IndexByte(buf.Bytes()[ar.Length():], 0); end != -1 {
+ buf.Truncate(end)
+ }
+ }
+
+ return nil
+}
+
+// +stateify savable
+type commInode struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+func newComm(task *kernel.Task, ino uint64, perm linux.FileMode) *kernfs.Dentry {
+ inode := &commInode{task: task}
+ inode.DynamicBytesFile.Init(task.Credentials(), ino, &commData{task: task}, perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (i *commInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ // This file can always be read or written by members of the same thread
+ // group. See fs/proc/base.c:proc_tid_comm_permission.
+ //
+ // N.B. This check is currently a no-op as we don't yet support writing and
+ // this file is world-readable anyways.
+ t := kernel.TaskFromContext(ctx)
+ if t != nil && t.ThreadGroup() == i.task.ThreadGroup() && !ats.MayExec() {
+ return nil
+ }
+
+ return i.DynamicBytesFile.CheckPermissions(ctx, creds, ats)
+}
+
+// commData implements vfs.DynamicBytesSource for /proc/[pid]/comm.
+//
+// +stateify savable
+type commData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*commData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *commData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(d.task.Name())
+ buf.WriteString("\n")
+ return nil
+}
+
+// idMapData implements vfs.DynamicBytesSource for /proc/[pid]/{gid_map|uid_map}.
+//
+// +stateify savable
+type idMapData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+ gids bool
+}
+
+var _ dynamicInode = (*idMapData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var entries []auth.IDMapEntry
+ if d.gids {
+ entries = d.task.UserNamespace().GIDMap()
+ } else {
+ entries = d.task.UserNamespace().UIDMap()
+ }
+ for _, e := range entries {
+ fmt.Fprintf(buf, "%10d %10d %10d\n", e.FirstID, e.FirstParentID, e.Length)
+ }
+ return nil
+}
+
+// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
+//
+// +stateify savable
+type mapsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := getMM(d.task); mm != nil {
+ mm.ReadMapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
+// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps.
+//
+// +stateify savable
+type smapsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*smapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := getMM(d.task); mm != nil {
+ mm.ReadSmapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
+// +stateify savable
+type taskStatData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+
+ // If tgstats is true, accumulate fault stats (not implemented) and CPU
+ // time across all tasks in t's thread group.
+ tgstats bool
+
+ // pidns is the PID namespace associated with the proc filesystem that
+ // includes the file using this statData.
+ pidns *kernel.PIDNamespace
+}
+
+var _ dynamicInode = (*taskStatData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.task))
+ fmt.Fprintf(buf, "(%s) ", s.task.Name())
+ fmt.Fprintf(buf, "%c ", s.task.StateStatus()[0])
+ ppid := kernel.ThreadID(0)
+ if parent := s.task.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "%d ", ppid)
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.task.ThreadGroup().ProcessGroup()))
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.task.ThreadGroup().Session()))
+ fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */)
+ fmt.Fprintf(buf, "0 " /* flags */)
+ fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
+ var cputime usage.CPUStats
+ if s.tgstats {
+ cputime = s.task.ThreadGroup().CPUStats()
+ } else {
+ cputime = s.task.CPUStats()
+ }
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ cputime = s.task.ThreadGroup().JoinedChildCPUStats()
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ fmt.Fprintf(buf, "%d %d ", s.task.Priority(), s.task.Niceness())
+ fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Count())
+
+ // itrealvalue. Since kernel 2.6.17, this field is no longer
+ // maintained, and is hard coded as 0.
+ fmt.Fprintf(buf, "0 ")
+
+ // Start time is relative to boot time, expressed in clock ticks.
+ fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.task.StartTime().Sub(s.task.Kernel().Timekeeper().BootTime())))
+
+ var vss, rss uint64
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+ fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize)
+
+ // rsslim.
+ fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Limits().Get(limits.Rss).Cur)
+
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
+ fmt.Fprintf(buf, "0 0 " /* nswap cnswap */)
+ terminationSignal := linux.Signal(0)
+ if s.task == s.task.ThreadGroup().Leader() {
+ terminationSignal = s.task.ThreadGroup().TerminationSignal()
+ }
+ fmt.Fprintf(buf, "%d ", terminationSignal)
+ fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */)
+ fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
+ fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
+ fmt.Fprintf(buf, "0\n" /* exit_code */)
+
+ return nil
+}
+
+// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm.
+//
+// +stateify savable
+type statmData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*statmData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var vss, rss uint64
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+
+ fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
+ return nil
+}
+
+// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
+//
+// +stateify savable
+type statusData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+ pidns *kernel.PIDNamespace
+}
+
+var _ dynamicInode = (*statusData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "Name:\t%s\n", s.task.Name())
+ fmt.Fprintf(buf, "State:\t%s\n", s.task.StateStatus())
+ fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.task.ThreadGroup()))
+ fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.task))
+ ppid := kernel.ThreadID(0)
+ if parent := s.task.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
+ tpid := kernel.ThreadID(0)
+ if tracer := s.task.Tracer(); tracer != nil {
+ tpid = s.pidns.IDOfTask(tracer)
+ }
+ fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
+ var fds int
+ var vss, rss, data uint64
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.Size()
+ }
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
+ })
+ fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
+ fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
+ fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
+ fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
+ fmt.Fprintf(buf, "Threads:\t%d\n", s.task.ThreadGroup().Count())
+ creds := s.task.Credentials()
+ fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
+ fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
+ fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
+ fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
+ fmt.Fprintf(buf, "Seccomp:\t%d\n", s.task.SeccompMode())
+ // We unconditionally report a single NUMA node. See
+ // pkg/sentry/syscalls/linux/sys_mempolicy.go.
+ fmt.Fprintf(buf, "Mems_allowed:\t1\n")
+ fmt.Fprintf(buf, "Mems_allowed_list:\t0\n")
+ return nil
+}
+
+// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
+type ioUsage interface {
+ // IOUsage returns the io usage data.
+ IOUsage() *usage.IO
+}
+
+// +stateify savable
+type ioData struct {
+ kernfs.DynamicBytesFile
+
+ ioUsage
+}
+
+var _ dynamicInode = (*ioData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ io := usage.IO{}
+ io.Accumulate(i.IOUsage())
+
+ fmt.Fprintf(buf, "char: %d\n", io.CharsRead)
+ fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten)
+ fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls)
+ fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls)
+ fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead)
+ fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten)
+ fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
new file mode 100644
index 000000000..a97b1753a
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -0,0 +1,235 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ selfName = "self"
+ threadSelfName = "thread-self"
+)
+
+// InoGenerator generates unique inode numbers for a given filesystem.
+type InoGenerator interface {
+ NextIno() uint64
+}
+
+// tasksInode represents the inode for /proc/ directory.
+//
+// +stateify savable
+type tasksInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+
+ inoGen InoGenerator
+ pidns *kernel.PIDNamespace
+
+ // '/proc/self' and '/proc/thread-self' have custom directory offsets in
+ // Linux. So handle them outside of OrderedChildren.
+ selfSymlink *vfs.Dentry
+ threadSelfSymlink *vfs.Dentry
+}
+
+var _ kernfs.Inode = (*tasksInode)(nil)
+
+func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNamespace) (*tasksInode, *kernfs.Dentry) {
+ root := auth.NewRootCredentials(pidns.UserNamespace())
+ contents := map[string]*kernfs.Dentry{
+ "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(cpuInfoData(k))),
+ //"filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}),
+ "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}),
+ "sys": newSysDir(root, inoGen),
+ "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
+ "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
+ "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}),
+ "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
+ "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}),
+ }
+
+ inode := &tasksInode{
+ pidns: pidns,
+ inoGen: inoGen,
+ selfSymlink: newSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
+ threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
+ }
+ inode.InodeAttrs.Init(root, inoGen.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := inode.OrderedChildren.Populate(dentry, contents)
+ inode.IncLinks(links)
+
+ return inode, dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ // Try to lookup a corresponding task.
+ tid, err := strconv.ParseUint(name, 10, 64)
+ if err != nil {
+ // If it failed to parse, check if it's one of the special handled files.
+ switch name {
+ case selfName:
+ return i.selfSymlink, nil
+ case threadSelfName:
+ return i.threadSelfSymlink, nil
+ }
+ return nil, syserror.ENOENT
+ }
+
+ task := i.pidns.TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return nil, syserror.ENOENT
+ }
+
+ taskDentry := newTaskInode(i.inoGen, task, i.pidns, true)
+ return taskDentry.VFSDentry(), nil
+}
+
+// Valid implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) Valid(ctx context.Context) bool {
+ return true
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
+ // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
+ const FIRST_PROCESS_ENTRY = 256
+
+ // Use maxTaskID to shortcut searches that will result in 0 entries.
+ const maxTaskID = kernel.TasksLimit + 1
+ if offset >= maxTaskID {
+ return offset, nil
+ }
+
+ // According to Linux (fs/proc/base.c:proc_pid_readdir()), process directories
+ // start at offset FIRST_PROCESS_ENTRY with '/proc/self', followed by
+ // '/proc/thread-self' and then '/proc/[pid]'.
+ if offset < FIRST_PROCESS_ENTRY {
+ offset = FIRST_PROCESS_ENTRY
+ }
+
+ if offset == FIRST_PROCESS_ENTRY {
+ dirent := vfs.Dirent{
+ Name: selfName,
+ Type: linux.DT_LNK,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
+ if offset == FIRST_PROCESS_ENTRY+1 {
+ dirent := vfs.Dirent{
+ Name: threadSelfName,
+ Type: linux.DT_LNK,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
+
+ // Collect all tasks that TGIDs are greater than the offset specified. Per
+ // Linux we only include in directory listings if it's the leader. But for
+ // whatever crazy reason, you can still walk to the given node.
+ var tids []int
+ startTid := offset - FIRST_PROCESS_ENTRY - 2
+ for _, tg := range i.pidns.ThreadGroups() {
+ tid := i.pidns.IDOfThreadGroup(tg)
+ if int64(tid) < startTid {
+ continue
+ }
+ if leader := tg.Leader(); leader != nil {
+ tids = append(tids, int(tid))
+ }
+ }
+
+ if len(tids) == 0 {
+ return offset, nil
+ }
+
+ sort.Ints(tids)
+ for _, tid := range tids {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(tid), 10),
+ Type: linux.DT_DIR,
+ Ino: i.inoGen.NextIno(),
+ NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
+ return maxTaskID, nil
+}
+
+// Open implements kernfs.Inode.
+func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
+
+func (i *tasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx {
+ stat := i.InodeAttrs.Stat(vsfs)
+
+ // Add dynamic children to link count.
+ for _, tg := range i.pidns.ThreadGroups() {
+ if leader := tg.Leader(); leader != nil {
+ stat.Nlink++
+ }
+ }
+
+ return stat
+}
+
+func cpuInfoData(k *kernel.Kernel) string {
+ features := k.FeatureSet()
+ if features == nil {
+ // Kernel is always initialized with a FeatureSet.
+ panic("cpuinfo read with nil FeatureSet")
+ }
+ var buf bytes.Buffer
+ for i, max := uint(0), k.ApplicationCores(); i < max; i++ {
+ features.WriteCPUInfoTo(i, &buf)
+ }
+ return buf.String()
+}
+
+func shmData(v uint64) dynamicInode {
+ return newStaticFile(strconv.FormatUint(v, 10))
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
new file mode 100644
index 000000000..ad3760e39
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -0,0 +1,337 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+type selfSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ pidns *kernel.PIDNamespace
+}
+
+var _ kernfs.Inode = (*selfSymlink)(nil)
+
+func newSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+ inode := &selfSymlink{pidns: pidns}
+ inode.Init(creds, ino, linux.ModeSymlink|perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *selfSymlink) Readlink(ctx context.Context) (string, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // Who is reading this link?
+ return "", syserror.EINVAL
+ }
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ if tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return strconv.FormatUint(uint64(tgid), 10), nil
+}
+
+type threadSelfSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ pidns *kernel.PIDNamespace
+}
+
+var _ kernfs.Inode = (*threadSelfSymlink)(nil)
+
+func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+ inode := &threadSelfSymlink{pidns: pidns}
+ inode.Init(creds, ino, linux.ModeSymlink|perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // Who is reading this link?
+ return "", syserror.EINVAL
+ }
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ tid := s.pidns.IDOfTask(t)
+ if tid == 0 || tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return fmt.Sprintf("%d/task/%d", tgid, tid), nil
+}
+
+// cpuStats contains the breakdown of CPU time for /proc/stat.
+type cpuStats struct {
+ // user is time spent in userspace tasks with non-positive niceness.
+ user uint64
+
+ // nice is time spent in userspace tasks with positive niceness.
+ nice uint64
+
+ // system is time spent in non-interrupt kernel context.
+ system uint64
+
+ // idle is time spent idle.
+ idle uint64
+
+ // ioWait is time spent waiting for IO.
+ ioWait uint64
+
+ // irq is time spent in interrupt context.
+ irq uint64
+
+ // softirq is time spent in software interrupt context.
+ softirq uint64
+
+ // steal is involuntary wait time.
+ steal uint64
+
+ // guest is time spent in guests with non-positive niceness.
+ guest uint64
+
+ // guestNice is time spent in guests with positive niceness.
+ guestNice uint64
+}
+
+// String implements fmt.Stringer.
+func (c cpuStats) String() string {
+ return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
+}
+
+// statData implements vfs.DynamicBytesSource for /proc/stat.
+//
+// +stateify savable
+type statData struct {
+ kernfs.DynamicBytesFile
+
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*statData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/37226836): We currently export only zero CPU stats. We could
+ // at least provide some aggregate stats.
+ var cpu cpuStats
+ fmt.Fprintf(buf, "cpu %s\n", cpu)
+
+ for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
+ fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
+ }
+
+ // The total number of interrupts is dependent on the CPUs and PCI
+ // devices on the system. See arch_probe_nr_irqs.
+ //
+ // Since we don't report real interrupt stats, just choose an arbitrary
+ // value from a representative VM.
+ const numInterrupts = 256
+
+ // The Kernel doesn't handle real interrupts, so report all zeroes.
+ // TODO(b/37226836): We could count page faults as #PF.
+ fmt.Fprintf(buf, "intr 0") // total
+ for i := 0; i < numInterrupts; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+
+ // Total number of context switches.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "ctxt 0\n")
+
+ // CLOCK_REALTIME timestamp from boot, in seconds.
+ fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
+
+ // Total number of clones.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "processes 0\n")
+
+ // Number of runnable tasks.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_running 0\n")
+
+ // Number of tasks waiting on IO.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_blocked 0\n")
+
+ // Number of each softirq handled.
+ fmt.Fprintf(buf, "softirq 0") // total
+ for i := 0; i < linux.NumSoftIRQ; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+ return nil
+}
+
+// loadavgData backs /proc/loadavg.
+//
+// +stateify savable
+type loadavgData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*loadavgData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/62345059): Include real data in fields.
+ // Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods.
+ // Column 4-5: currently running processes and the total number of processes.
+ // Column 6: the last process ID used.
+ fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
+ return nil
+}
+
+// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
+//
+// +stateify savable
+type meminfoData struct {
+ kernfs.DynamicBytesFile
+
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*meminfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ mf := d.k.MemoryFile()
+ mf.UpdateUsage()
+ snapshot, totalUsage := usage.MemoryAccounting.Copy()
+ totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+ anon := snapshot.Anonymous + snapshot.Tmpfs
+ file := snapshot.PageCache + snapshot.Mapped
+ // We don't actually have active/inactive LRUs, so just make up numbers.
+ activeFile := (file / 2) &^ (usermem.PageSize - 1)
+ inactiveFile := file - activeFile
+
+ fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
+ memFree := (totalSize - totalUsage) / 1024
+ // We use MemFree as MemAvailable because we don't swap.
+ // TODO(rahat): When reclaim is implemented the value of MemAvailable
+ // should change.
+ fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree)
+ fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree)
+ fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices
+ fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
+ // Emulate a system with no swap, which disables inactivation of anon pages.
+ fmt.Fprintf(buf, "SwapCache: 0 kB\n")
+ fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024)
+ fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Inactive(anon): 0 kB\n")
+ fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024)
+ fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "SwapTotal: 0 kB\n")
+ fmt.Fprintf(buf, "SwapFree: 0 kB\n")
+ fmt.Fprintf(buf, "Dirty: 0 kB\n")
+ fmt.Fprintf(buf, "Writeback: 0 kB\n")
+ fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
+ fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
+ return nil
+}
+
+// uptimeData implements vfs.DynamicBytesSource for /proc/uptime.
+//
+// +stateify savable
+type uptimeData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*uptimeData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*uptimeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ now := time.NowFromContext(ctx)
+
+ // Pretend that we've spent zero time sleeping (second number).
+ fmt.Fprintf(buf, "%.2f 0.00\n", now.Sub(k.Timekeeper().BootTime()).Seconds())
+ return nil
+}
+
+// versionData implements vfs.DynamicBytesSource for /proc/version.
+//
+// +stateify savable
+type versionData struct {
+ kernfs.DynamicBytesFile
+
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*versionData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ init := v.k.GlobalInit()
+ if init == nil {
+ // Attempted to read before the init Task is created. This can
+ // only occur during startup, which should never need to read
+ // this file.
+ panic("Attempted to read version before initial Task is available")
+ }
+
+ // /proc/version takes the form:
+ //
+ // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
+ // (COMPILER_VERSION) VERSION"
+ //
+ // where:
+ // - SYSNAME, RELEASE, and VERSION are the same as returned by
+ // sys_utsname
+ // - COMPILE_USER is the user that build the kernel
+ // - COMPILE_HOST is the hostname of the machine on which the kernel
+ // was built
+ // - COMPILER_VERSION is the version reported by the building compiler
+ //
+ // Since we don't really want to expose build information to
+ // applications, those fields are omitted.
+ //
+ // FIXME(mpratt): Using Version from the init task SyscallTable
+ // disregards the different version a task may have (e.g., in a uts
+ // namespace).
+ ver := init.Leader().SyscallTable().Version
+ fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/net.go b/pkg/sentry/fsimpl/proc/tasks_net.go
index fd46eebf8..06dc43c26 100644
--- a/pkg/sentry/fsimpl/proc/net.go
+++ b/pkg/sentry/fsimpl/proc/tasks_net.go
@@ -46,8 +46,7 @@ func (n *ifinet6) contents() []string {
for id, naddrs := range n.s.InterfaceAddrs() {
nic, ok := nics[id]
if !ok {
- // NIC was added after NICNames was called. We'll just
- // ignore it.
+ // NIC was added after NICNames was called. We'll just ignore it.
continue
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
new file mode 100644
index 000000000..aabf2bf0c
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -0,0 +1,143 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// newSysDir returns the dentry corresponding to /proc/sys directory.
+func newSysDir(root *auth.Credentials, inoGen InoGenerator) *kernfs.Dentry {
+ return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "kernel": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "hostname": newDentry(root, inoGen.NextIno(), 0444, &hostnameData{}),
+ "shmall": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMALL)),
+ "shmmax": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMMAX)),
+ "shmmni": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMMNI)),
+ }),
+ "vm": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "mmap_min_addr": newDentry(root, inoGen.NextIno(), 0444, &mmapMinAddrData{}),
+ "overcommit_memory": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0\n")),
+ }),
+ "net": newSysNetDir(root, inoGen),
+ })
+}
+
+// newSysNetDir returns the dentry corresponding to /proc/sys/net directory.
+func newSysNetDir(root *auth.Credentials, inoGen InoGenerator) *kernfs.Dentry {
+ return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "net": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ // Add tcp_sack.
+ // TODO(gvisor.dev/issue/1195): tcp_sack allows write(2)
+ // "tcp_sack": newTCPSackInode(ctx, msrc, s),
+
+ // The following files are simple stubs until they are implemented in
+ // netstack, most of these files are configuration related. We use the
+ // value closest to the actual netstack behavior or any empty file, all
+ // of these files will have mode 0444 (read-only for all users).
+ "ip_local_port_range": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("16000 65535")),
+ "ip_local_reserved_ports": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")),
+ "ipfrag_time": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("30")),
+ "ip_nonlocal_bind": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "ip_no_pmtu_disc": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+
+ // tcp_allowed_congestion_control tell the user what they are able to
+ // do as an unprivledged process so we leave it empty.
+ "tcp_allowed_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")),
+ "tcp_available_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("reno")),
+ "tcp_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("reno")),
+
+ // Many of the following stub files are features netstack doesn't
+ // support. The unsupported features return "0" to indicate they are
+ // disabled.
+ "tcp_base_mss": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1280")),
+ "tcp_dsack": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_early_retrans": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fack": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fastopen": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fastopen_key": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")),
+ "tcp_invalid_ratelimit": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_intvl": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_probes": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_time": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("7200")),
+ "tcp_mtu_probing": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_no_metrics_save": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+ "tcp_probe_interval": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_probe_threshold": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_retries1": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("3")),
+ "tcp_retries2": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("15")),
+ "tcp_rfc1337": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+ "tcp_slow_start_after_idle": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+ "tcp_synack_retries": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("5")),
+ "tcp_syn_retries": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("3")),
+ "tcp_timestamps": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+ }),
+ "core": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "default_qdisc": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("pfifo_fast")),
+ "message_burst": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("10")),
+ "message_cost": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("5")),
+ "optmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "rmem_default": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")),
+ "rmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")),
+ "somaxconn": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("128")),
+ "wmem_default": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")),
+ "wmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")),
+ }),
+ }),
+ })
+}
+
+// mmapMinAddrData implements vfs.DynamicBytesSource for
+// /proc/sys/vm/mmap_min_addr.
+//
+// +stateify savable
+type mmapMinAddrData struct {
+ kernfs.DynamicBytesFile
+
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*mmapMinAddrData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress())
+ return nil
+}
+
+// hostnameData implements vfs.DynamicBytesSource for /proc/sys/kernel/hostname.
+//
+// +stateify savable
+type hostnameData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*hostnameData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*hostnameData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ utsns := kernel.UTSNamespaceFromContext(ctx)
+ buf.WriteString(utsns.HostName())
+ buf.WriteString("\n")
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/net_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
index 20a77a8ca..20a77a8ca 100644
--- a/pkg/sentry/fsimpl/proc/net_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
new file mode 100644
index 000000000..6b58c16b9
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -0,0 +1,566 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "math"
+ "path"
+ "strconv"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var (
+ // Next offset 256 by convention. Adds 1 for the next offset.
+ selfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 0 + 1}
+ threadSelfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 1 + 1}
+
+ // /proc/[pid] next offset starts at 256+2 (files above), then adds the
+ // PID, and adds 1 for the next offset.
+ proc1 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 1 + 1}
+ proc2 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 2 + 1}
+ proc3 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 3 + 1}
+)
+
+type testIterDirentsCallback struct {
+ dirents []vfs.Dirent
+}
+
+func (t *testIterDirentsCallback) Handle(d vfs.Dirent) bool {
+ t.dirents = append(t.dirents, d)
+ return true
+}
+
+func checkDots(dirs []vfs.Dirent) ([]vfs.Dirent, error) {
+ if got := len(dirs); got < 2 {
+ return dirs, fmt.Errorf("wrong number of dirents, want at least: 2, got: %d: %v", got, dirs)
+ }
+ for i, want := range []string{".", ".."} {
+ if got := dirs[i].Name; got != want {
+ return dirs, fmt.Errorf("wrong name, want: %s, got: %s", want, got)
+ }
+ if got := dirs[i].Type; got != linux.DT_DIR {
+ return dirs, fmt.Errorf("wrong type, want: %d, got: %d", linux.DT_DIR, got)
+ }
+ }
+ return dirs[2:], nil
+}
+
+func checkTasksStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) {
+ wants := map[string]vfs.Dirent{
+ "cpuinfo": {Type: linux.DT_REG},
+ "loadavg": {Type: linux.DT_REG},
+ "meminfo": {Type: linux.DT_REG},
+ "mounts": {Type: linux.DT_LNK},
+ "self": selfLink,
+ "stat": {Type: linux.DT_REG},
+ "sys": {Type: linux.DT_DIR},
+ "thread-self": threadSelfLink,
+ "uptime": {Type: linux.DT_REG},
+ "version": {Type: linux.DT_REG},
+ }
+ return checkFiles(gots, wants)
+}
+
+func checkTaskStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) {
+ wants := map[string]vfs.Dirent{
+ "auxv": {Type: linux.DT_REG},
+ "cmdline": {Type: linux.DT_REG},
+ "comm": {Type: linux.DT_REG},
+ "environ": {Type: linux.DT_REG},
+ "gid_map": {Type: linux.DT_REG},
+ "io": {Type: linux.DT_REG},
+ "maps": {Type: linux.DT_REG},
+ "ns": {Type: linux.DT_DIR},
+ "smaps": {Type: linux.DT_REG},
+ "stat": {Type: linux.DT_REG},
+ "statm": {Type: linux.DT_REG},
+ "status": {Type: linux.DT_REG},
+ "task": {Type: linux.DT_DIR},
+ "uid_map": {Type: linux.DT_REG},
+ }
+ return checkFiles(gots, wants)
+}
+
+func checkFiles(gots []vfs.Dirent, wants map[string]vfs.Dirent) ([]vfs.Dirent, error) {
+ // Go over all files, when there is a match, the file is removed from both
+ // 'gots' and 'wants'. wants is expected to reach 0, as all files must
+ // be present. Remaining files in 'gots', is returned to caller to decide
+ // whether this is valid or not.
+ for i := 0; i < len(gots); i++ {
+ got := gots[i]
+ want, ok := wants[got.Name]
+ if !ok {
+ continue
+ }
+ if want.Type != got.Type {
+ return gots, fmt.Errorf("wrong file type, want: %v, got: %v: %+v", want.Type, got.Type, got)
+ }
+ if want.NextOff != 0 && want.NextOff != got.NextOff {
+ return gots, fmt.Errorf("wrong dirent offset, want: %v, got: %v: %+v", want.NextOff, got.NextOff, got)
+ }
+
+ delete(wants, got.Name)
+ gots = append(gots[0:i], gots[i+1:]...)
+ i--
+ }
+ if len(wants) != 0 {
+ return gots, fmt.Errorf("not all files were found, missing: %+v", wants)
+ }
+ return gots, nil
+}
+
+func setup() (context.Context, *vfs.VirtualFilesystem, vfs.VirtualDentry, error) {
+ k, err := boot()
+ if err != nil {
+ return nil, nil, vfs.VirtualDentry{}, fmt.Errorf("creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := vfs.New()
+ vfsObj.MustRegisterFilesystemType("procfs", &procFSType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "procfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace(): %v", err)
+ }
+ return ctx, vfsObj, mntns.Root(), nil
+}
+
+func TestTasksEmpty(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt failed: %v", err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ cb.dirents, err = checkTasksStaticFiles(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if len(cb.dirents) != 0 {
+ t.Errorf("found more files than expected: %+v", cb.dirents)
+ }
+}
+
+func TestTasks(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(ctx)
+ var tasks []*kernel.Task
+ for i := 0; i < 5; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := createTask(ctx, fmt.Sprintf("name-%d", i), tc)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ tasks = append(tasks, task)
+ }
+
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ cb.dirents, err = checkTasksStaticFiles(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ lastPid := 0
+ for _, d := range cb.dirents {
+ pid, err := strconv.Atoi(d.Name)
+ if err != nil {
+ t.Fatalf("Invalid process directory %q", d.Name)
+ }
+ if lastPid > pid {
+ t.Errorf("pids not in order: %v", cb.dirents)
+ }
+ found := false
+ for _, t := range tasks {
+ if k.TaskSet().Root.IDOfTask(t) == kernel.ThreadID(pid) {
+ found = true
+ }
+ }
+ if !found {
+ t.Errorf("Additional task ID %d listed: %v", pid, tasks)
+ }
+ // Next offset starts at 256+2 ('self' and 'thread-self'), then adds the
+ // PID, and adds 1 for the next offset.
+ if want := int64(256 + 2 + pid + 1); d.NextOff != want {
+ t.Errorf("Wrong dirent offset want: %d got: %d: %+v", want, d.NextOff, d)
+ }
+ }
+
+ // Test lookup.
+ for _, path := range []string{"/1", "/2"} {
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse(path)},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err)
+ }
+ buf := make([]byte, 1)
+ bufIOSeq := usermem.BytesIOSequence(buf)
+ if _, err := fd.Read(ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR {
+ t.Errorf("wrong error reading directory: %v", err)
+ }
+ }
+
+ if _, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/9999")},
+ &vfs.OpenOptions{},
+ ); err != syserror.ENOENT {
+ t.Fatalf("wrong error from vfsfs.OpenAt(/9999): %v", err)
+ }
+}
+
+func TestTasksOffset(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(ctx)
+ for i := 0; i < 3; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ if _, err := createTask(ctx, fmt.Sprintf("name-%d", i), tc); err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ }
+
+ for _, tc := range []struct {
+ name string
+ offset int64
+ wants map[string]vfs.Dirent
+ }{
+ {
+ name: "small offset",
+ offset: 100,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "offset at start",
+ offset: 256,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip /proc/self",
+ offset: 257,
+ wants: map[string]vfs.Dirent{
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip symlinks",
+ offset: 258,
+ wants: map[string]vfs.Dirent{
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip first process",
+ offset: 260,
+ wants: map[string]vfs.Dirent{
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "last process",
+ offset: 261,
+ wants: map[string]vfs.Dirent{
+ "3": proc3,
+ },
+ },
+ {
+ name: "after last",
+ offset: 262,
+ wants: nil,
+ },
+ {
+ name: "TaskLimit+1",
+ offset: kernel.TasksLimit + 1,
+ wants: nil,
+ },
+ {
+ name: "max",
+ offset: math.MaxInt64,
+ wants: nil,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ }
+ if _, err := fd.Impl().Seek(ctx, tc.offset, linux.SEEK_SET); err != nil {
+ t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ if cb.dirents, err = checkFiles(cb.dirents, tc.wants); err != nil {
+ t.Error(err.Error())
+ }
+ if len(cb.dirents) != 0 {
+ t.Errorf("found more files than expected: %+v", cb.dirents)
+ }
+ })
+ }
+}
+
+func TestTask(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(ctx)
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ _, err = createTask(ctx, "name", tc)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/1")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/1) failed: %v", err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ cb.dirents, err = checkTaskStaticFiles(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if len(cb.dirents) != 0 {
+ t.Errorf("found more files than expected: %+v", cb.dirents)
+ }
+}
+
+func TestProcSelf(t *testing.T) {
+ ctx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(ctx)
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := createTask(ctx, "name", tc)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+
+ fd, err := vfsObj.OpenAt(
+ task,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/self/"), FollowFinalSymlink: true},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/self/) failed: %v", err)
+ }
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ cb.dirents, err = checkTaskStaticFiles(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if len(cb.dirents) != 0 {
+ t.Errorf("found more files than expected: %+v", cb.dirents)
+ }
+}
+
+func iterateDir(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, fd *vfs.FileDescription) {
+ t.Logf("Iterating: /proc%s", fd.MappedName(ctx))
+
+ cb := testIterDirentsCallback{}
+ if err := fd.Impl().IterDirents(ctx, &cb); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ var err error
+ cb.dirents, err = checkDots(cb.dirents)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ for _, d := range cb.dirents {
+ childPath := path.Join(fd.MappedName(ctx), d.Name)
+ if d.Type == linux.DT_LNK {
+ link, err := vfsObj.ReadlinkAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse(childPath)},
+ )
+ if err != nil {
+ t.Errorf("vfsfs.ReadlinkAt(%v) failed: %v", childPath, err)
+ } else {
+ t.Logf("Skipping symlink: /proc%s => %s", childPath, link)
+ }
+ continue
+ }
+
+ t.Logf("Opening: /proc%s", childPath)
+ child, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse(childPath)},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Errorf("vfsfs.OpenAt(%v) failed: %v", childPath, err)
+ continue
+ }
+ stat, err := child.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Errorf("Stat(%v) failed: %v", childPath, err)
+ }
+ if got := linux.FileMode(stat.Mode).DirentType(); got != d.Type {
+ t.Errorf("wrong file mode, stat: %v, dirent: %v", got, d.Type)
+ }
+ if d.Type == linux.DT_DIR {
+ // Found another dir, let's do it again!
+ iterateDir(ctx, t, vfsObj, root, child)
+ }
+ }
+}
+
+// TestTree iterates all directories and stats every file.
+func TestTree(t *testing.T) {
+ uberCtx, vfsObj, root, err := setup()
+ if err != nil {
+ t.Fatalf("Setup failed: %v", err)
+ }
+ defer root.DecRef()
+
+ k := kernel.KernelFromContext(uberCtx)
+ var tasks []*kernel.Task
+ for i := 0; i < 5; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := createTask(uberCtx, fmt.Sprintf("name-%d", i), tc)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ tasks = append(tasks, task)
+ }
+
+ ctx := tasks[0]
+ fd, err := vfsObj.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(uberCtx),
+ &vfs.PathOperation{Root: root, Start: root, Path: fspath.Parse("/")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ }
+ iterateDir(ctx, t, vfsObj, root, fd)
+}
diff --git a/pkg/sentry/fsimpl/proc/version.go b/pkg/sentry/fsimpl/proc/version.go
deleted file mode 100644
index e1643d4e0..000000000
--- a/pkg/sentry/fsimpl/proc/version.go
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// versionData implements vfs.DynamicBytesSource for /proc/version.
-//
-// +stateify savable
-type versionData struct {
- // k is the owning Kernel.
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*versionData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- init := v.k.GlobalInit()
- if init == nil {
- // Attempted to read before the init Task is created. This can
- // only occur during startup, which should never need to read
- // this file.
- panic("Attempted to read version before initial Task is available")
- }
-
- // /proc/version takes the form:
- //
- // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
- // (COMPILER_VERSION) VERSION"
- //
- // where:
- // - SYSNAME, RELEASE, and VERSION are the same as returned by
- // sys_utsname
- // - COMPILE_USER is the user that build the kernel
- // - COMPILE_HOST is the hostname of the machine on which the kernel
- // was built
- // - COMPILER_VERSION is the version reported by the building compiler
- //
- // Since we don't really want to expose build information to
- // applications, those fields are omitted.
- //
- // FIXME(mpratt): Using Version from the init task SyscallTable
- // disregards the different version a task may have (e.g., in a uts
- // namespace).
- ver := init.Leader().SyscallTable().Version
- fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
- return nil
-}
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 5689bed3b..7601c7c04 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -1,14 +1,13 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "dentry_list",
out = "dentry_list.go",
- package = "memfs",
+ package = "tmpfs",
prefix = "dentry",
template = "//pkg/ilist:generic_list",
types = {
@@ -18,27 +17,38 @@ go_template_instance(
)
go_library(
- name = "memfs",
+ name = "tmpfs",
srcs = [
"dentry_list.go",
"directory.go",
"filesystem.go",
- "memfs.go",
"named_pipe.go",
"regular_file.go",
"symlink.go",
+ "tmpfs.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs",
+ importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs",
deps = [
"//pkg/abi/linux",
"//pkg/amutex",
"//pkg/fspath",
+ "//pkg/log",
"//pkg/sentry/arch",
"//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/safemem",
+ "//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
],
)
@@ -48,7 +58,7 @@ go_test(
size = "small",
srcs = ["benchmark_test.go"],
deps = [
- ":memfs",
+ ":tmpfs",
"//pkg/abi/linux",
"//pkg/fspath",
"//pkg/refs",
@@ -63,16 +73,21 @@ go_test(
)
go_test(
- name = "memfs_test",
+ name = "tmpfs_test",
size = "small",
- srcs = ["pipe_test.go"],
- embed = [":memfs"],
+ srcs = [
+ "pipe_test.go",
+ "regular_file_test.go",
+ "stat_test.go",
+ ],
+ embed = [":tmpfs"],
deps = [
"//pkg/abi/linux",
"//pkg/fspath",
"//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index 6e987af88..d88c83499 100644
--- a/pkg/sentry/fsimpl/memfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -27,7 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -176,8 +176,10 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
@@ -365,8 +367,10 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
@@ -395,7 +399,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
}
defer mountPoint.DecRef()
// Create and mount the submount.
- if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil {
+ if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
}
filePathBuilder.WriteString(mountPointName)
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index 0bd82e480..887ca2619 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 4a83f310c..4cd7e9aea 100644
--- a/pkg/sentry/fsimpl/memfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"fmt"
@@ -50,13 +50,14 @@ afterSymlink:
return nil, err
}
if nextVFSD == nil {
- // Since the Dentry tree is the sole source of truth for memfs, if it's
+ // Since the Dentry tree is the sole source of truth for tmpfs, if it's
// not in the Dentry tree, it doesn't exist.
return nil, syserror.ENOENT
}
next := nextVFSD.Impl().(*dentry)
if symlink, ok := next.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO: symlink traversals update access time
+ // TODO(gvisor.dev/issues/1197): Symlink traversals updates
+ // access time.
if err := rp.HandleSymlink(symlink.target); err != nil {
return nil, err
}
@@ -348,13 +349,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, flags uint32,
}
// mnt.EndWrite() is called by regularFileFD.Release().
}
- mnt.IncRef()
- d.IncRef()
fd.vfsfd.Init(&fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{})
if flags&linux.O_TRUNC != 0 {
impl.mu.Lock()
- impl.data = impl.data[:0]
- atomic.StoreInt64(&impl.dataLen, 0)
+ impl.data.Truncate(0, impl.memFile)
+ atomic.StoreUint64(&impl.size, 0)
impl.mu.Unlock()
}
return &fd.vfsfd, nil
@@ -364,8 +363,6 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, flags uint32,
return nil, syserror.EISDIR
}
var fd directoryFD
- mnt.IncRef()
- d.IncRef()
fd.vfsfd.Init(&fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
case *symlink:
@@ -505,7 +502,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
oldParent.inode.decLinksLocked()
newParent.inode.incLinksLocked()
}
- // TODO: update timestamps and parent directory sizes
+ // TODO(gvisor.dev/issues/1197): Update timestamps and parent directory
+ // sizes.
vfsObj.CommitRenameReplaceDentry(renamedVFSD, &newParent.vfsd, newName, replacedVFSD)
return nil
}
@@ -559,15 +557,11 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return err
}
- if opts.Stat.Mask == 0 {
- return nil
- }
- // TODO: implement inode.setStat
- return syserror.EPERM
+ return d.inode.setStat(opts.Stat)
}
// StatAt implements vfs.FilesystemImpl.StatAt.
@@ -591,7 +585,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
if err != nil {
return linux.Statfs{}, err
}
- // TODO: actually implement statfs
+ // TODO(gvisor.dev/issues/1197): Actually implement statfs.
return linux.Statfs{}, syserror.ENOSYS
}
diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index d5060850e..40bde54de 100644
--- a/pkg/sentry/fsimpl/memfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -55,8 +55,6 @@ func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, v
return nil, err
}
mnt := rp.Mount()
- mnt.IncRef()
- vfsd.IncRef()
fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
}
diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index be917aeee..70b42a6ec 100644
--- a/pkg/sentry/fsimpl/memfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"bytes"
@@ -152,8 +152,10 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
// Create VFS.
vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
new file mode 100644
index 000000000..5fa70cc6d
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -0,0 +1,392 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "io"
+ "math"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+type regularFile struct {
+ inode inode
+
+ // memFile is a platform.File used to allocate pages to this regularFile.
+ memFile *pgalloc.MemoryFile
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ // data maps offsets into the file to offsets into memFile that store
+ // the file's data.
+ data fsutil.FileRangeSet
+
+ // size is the size of data, but accessed using atomic memory
+ // operations to avoid locking in inode.stat().
+ size uint64
+
+ // seals represents file seals on this inode.
+ seals uint32
+}
+
+func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
+ file := &regularFile{
+ memFile: fs.memFile,
+ }
+ file.inode.init(file, fs, creds, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
+
+// truncate grows or shrinks the file to the given size. It returns true if the
+// file size was updated.
+func (rf *regularFile) truncate(size uint64) (bool, error) {
+ rf.mu.Lock()
+ defer rf.mu.Unlock()
+
+ if size == rf.size {
+ // Nothing to do.
+ return false, nil
+ }
+
+ if size > rf.size {
+ // Growing the file.
+ if rf.seals&linux.F_SEAL_GROW != 0 {
+ // Seal does not allow growth.
+ return false, syserror.EPERM
+ }
+ rf.size = size
+ return true, nil
+ }
+
+ // Shrinking the file
+ if rf.seals&linux.F_SEAL_SHRINK != 0 {
+ // Seal does not allow shrink.
+ return false, syserror.EPERM
+ }
+
+ // TODO(gvisor.dev/issues/1197): Invalidate mappings once we have
+ // mappings.
+
+ rf.data.Truncate(size, rf.memFile)
+ rf.size = size
+ return true, nil
+}
+
+type regularFileFD struct {
+ fileDescription
+
+ // These are immutable.
+ readable bool
+ writable bool
+
+ // off is the file offset. off is accessed using atomic memory operations.
+ // offMu serializes operations that may mutate off.
+ off int64
+ offMu sync.Mutex
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release() {
+ if fd.writable {
+ fd.vfsfd.VirtualDentry().Mount().EndWrite()
+ }
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if !fd.readable {
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putRegularFileReadWriter(rw)
+ return int64(n), err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if !fd.writable {
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ srclen := src.NumBytes()
+ if srclen == 0 {
+ return 0, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ end := offset + srclen
+ if end < offset {
+ // Overflow.
+ return 0, syserror.EFBIG
+ }
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := src.CopyInTo(ctx, rw)
+ putRegularFileReadWriter(rw)
+ return n, err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PWrite(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.offMu.Lock()
+ defer fd.offMu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // use offset as specified
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END:
+ offset += int64(atomic.LoadUint64(&fd.inode().impl.(*regularFile).size))
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *regularFileFD) Sync(ctx context.Context) error {
+ return nil
+}
+
+// regularFileReadWriter implements safemem.Reader and Safemem.Writer.
+type regularFileReadWriter struct {
+ file *regularFile
+
+ // Offset into the file to read/write at. Note that this may be
+ // different from the FD offset if PRead/PWrite is used.
+ off uint64
+}
+
+var regularFileReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &regularFileReadWriter{}
+ },
+}
+
+func getRegularFileReadWriter(file *regularFile, offset int64) *regularFileReadWriter {
+ rw := regularFileReadWriterPool.Get().(*regularFileReadWriter)
+ rw.file = file
+ rw.off = uint64(offset)
+ return rw
+}
+
+func putRegularFileReadWriter(rw *regularFileReadWriter) {
+ rw.file = nil
+ regularFileReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ rw.file.mu.RLock()
+
+ // Compute the range to read (limited by file size and overflow-checked).
+ if rw.off >= rw.file.size {
+ rw.file.mu.RUnlock()
+ return 0, io.EOF
+ }
+ end := rw.file.size
+ if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
+ end = rend
+ }
+
+ var done uint64
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ rw.file.mu.RUnlock()
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ rw.file.mu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Tmpfs holes are zero-filled.
+ gapmr := gap.Range().Intersect(mr)
+ dst := dsts.TakeFirst64(gapmr.Length())
+ n, err := safemem.ZeroSeq(dst)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ rw.file.mu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+ rw.file.mu.RUnlock()
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ rw.file.mu.Lock()
+
+ // Compute the range to write (overflow-checked).
+ end := rw.off + srcs.NumBytes()
+ if end <= rw.off {
+ end = math.MaxInt64
+ }
+
+ // Check if seals prevent either file growth or all writes.
+ switch {
+ case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
+ rw.file.mu.Unlock()
+ return 0, syserror.EPERM
+ case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
+ // When growth is sealed, Linux effectively allows writes which would
+ // normally grow the file to partially succeed up to the current EOF,
+ // rounded down to the page boundary before the EOF.
+ //
+ // This happens because writes (and thus the growth check) for tmpfs
+ // files proceed page-by-page on Linux, and the final write to the page
+ // containing EOF fails, resulting in a partial write up to the start of
+ // that page.
+ //
+ // To emulate this behaviour, artifically truncate the write to the
+ // start of the page containing the current EOF.
+ //
+ // See Linux, mm/filemap.c:generic_perform_write() and
+ // mm/shmem.c:shmem_write_begin().
+ if pgstart := uint64(usermem.Addr(rw.file.size).RoundDown()); end > pgstart {
+ end = pgstart
+ }
+ if end <= rw.off {
+ // Truncation would result in no data being written.
+ rw.file.mu.Unlock()
+ return 0, syserror.EPERM
+ }
+ }
+
+ // Page-aligned mr for when we need to allocate memory. RoundUp can't
+ // overflow since end is an int64.
+ pgstartaddr := usermem.Addr(rw.off).RoundDown()
+ pgendaddr, _ := usermem.Addr(end).RoundUp()
+ pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)}
+
+ var (
+ done uint64
+ retErr error
+ )
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.off += uint64(n)
+ srcs = srcs.DropFirst64(n)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Allocate memory for the write.
+ gapMR := gap.Range().Intersect(pgMR)
+ fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Write to that memory as usual.
+ seg, gap = rw.file.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
+ }
+ }
+exitLoop:
+ // If the write ends beyond the file's previous size, it causes the
+ // file to grow.
+ if rw.off > rw.file.size {
+ atomic.StoreUint64(&rw.file.size, rw.off)
+ }
+
+ rw.file.mu.Unlock()
+ return done, retErr
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
new file mode 100644
index 000000000..034a29fdb
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -0,0 +1,436 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "sync/atomic"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// nextFileID is used to generate unique file names.
+var nextFileID int64
+
+// newTmpfsRoot creates a new tmpfs mount, and returns the root. If the error
+// is not nil, then cleanup should be called when the root is no longer needed.
+func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := vfs.New()
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
+ }
+ root := mntns.Root()
+ return vfsObj, root, func() {
+ root.DecRef()
+ mntns.DecRef(vfsObj)
+ }, nil
+}
+
+// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If
+// the returned err is not nil, then cleanup should be called when the FD is no
+// longer needed.
+func newFileFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ filename := fmt.Sprintf("tmpfs-test-file-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the file that will be write/read.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.ModeRegular | mode,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// newDirFD is like newFileFD, but for directories.
+func newDirFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ dirname := fmt.Sprintf("tmpfs-test-dir-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the dir.
+ if err := vfsObj.MkdirAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.MkdirOptions{
+ Mode: linux.ModeDirectory | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create directory %q: %v", dirname, err)
+ }
+
+ // Open the dir and return it.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open directory %q: %v", dirname, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// newPipeFD is like newFileFD, but for pipes.
+func newPipeFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ pipename := fmt.Sprintf("tmpfs-test-pipe-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the pipe.
+ if err := vfsObj.MknodAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(pipename),
+ }, &vfs.MknodOptions{
+ Mode: linux.ModeNamedPipe | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create pipe %q: %v", pipename, err)
+ }
+
+ // Open the pipe and return it.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(pipename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open pipe %q: %v", pipename, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// Test that we can write some data to a file and read it back.`
+func TestSimpleWriteRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write.
+ data := []byte("foobarbaz")
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Write left offset at %d, want %d", got, want)
+ }
+
+ // Seek back to beginning.
+ if _, err := fd.Seek(ctx, 0, linux.SEEK_SET); err != nil {
+ t.Fatalf("fd.Seek failed: %v", err)
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(0); got != want {
+ t.Errorf("fd.Seek(0) left offset at %d, want %d", got, want)
+ }
+
+ // Read.
+ buf := make([]byte, len(data))
+ n, err = fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Read got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(buf), string(data); got != want {
+ t.Errorf("Read got %q want %s", got, want)
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Write left offset at %d, want %d", got, want)
+ }
+}
+
+func TestPWrite(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Fill file with 1k 'a's.
+ data := bytes.Repeat([]byte{'a'}, 1000)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Write "gVisor is awesome" at various offsets.
+ buf := []byte("gVisor is awesome")
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PWrite offset=%d", offset)
+ t.Run(name, func(t *testing.T) {
+ n, err := fd.PWrite(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.WriteOptions{})
+ if err != nil {
+ t.Errorf("fd.PWrite got err %v want nil", err)
+ }
+ if n != int64(len(buf)) {
+ t.Errorf("fd.PWrite got %d bytes want %d", n, len(buf))
+ }
+
+ // Update data to reflect expected file contents.
+ if len(data) < offset+len(buf) {
+ data = append(data, make([]byte, (offset+len(buf))-len(data))...)
+ }
+ copy(data[offset:], buf)
+
+ // Read the whole file and compare with data.
+ readBuf := make([]byte, len(data))
+ n, err = fd.PRead(ctx, usermem.BytesIOSequence(readBuf), 0, vfs.ReadOptions{})
+ if err != nil {
+ t.Fatalf("fd.PRead failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.PRead got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(readBuf), string(data); got != want {
+ t.Errorf("PRead got %q want %s", got, want)
+ }
+
+ })
+ }
+}
+
+func TestPRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write 100 sequences of 'gVisor is awesome'.
+ data := bytes.Repeat([]byte("gVisor is awsome"), 100)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Read various sizes from various offsets.
+ sizes := []int{0, 1, 2, 10, 20, 50, 100, 1000}
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, 1000, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+
+ for _, size := range sizes {
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PRead offset=%d size=%d", offset, size)
+ t.Run(name, func(t *testing.T) {
+ var (
+ wantRead []byte
+ wantErr error
+ )
+ if offset < len(data) {
+ wantRead = data[offset:]
+ } else if size > 0 {
+ wantErr = io.EOF
+ }
+ if offset+size < len(data) {
+ wantRead = wantRead[:size]
+ }
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.ReadOptions{})
+ if err != wantErr {
+ t.Errorf("fd.PRead got err %v want %v", err, wantErr)
+ }
+ if n != int64(len(wantRead)) {
+ t.Errorf("fd.PRead got %d bytes want %d", n, len(wantRead))
+ }
+ if got := string(buf[:n]); got != string(wantRead) {
+ t.Errorf("fd.PRead got %q want %q", got, string(wantRead))
+ }
+ })
+ }
+ }
+}
+
+func TestTruncate(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Fill the file with some data.
+ data := bytes.Repeat([]byte("gVisor is awsome"), 100)
+ written, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+
+ // Size should be same as written.
+ sizeStatOpts := vfs.StatOptions{Mask: linux.STATX_SIZE}
+ stat, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := int64(stat.Size), written; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+
+ // Truncate down.
+ newSize := uint64(10)
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ // Size should be updated.
+ statAfterTruncateDown, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := statAfterTruncateDown.Size, newSize; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+ // We should only read newSize worth of data.
+ buf := make([]byte, 1000)
+ if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead failed: %v", err)
+ } else if uint64(n) != newSize {
+ t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+ }
+ // Mtime and Ctime should be bumped.
+ if got := statAfterTruncateDown.Mtime.ToNsec(); got <= stat.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want > %v", got, stat.Mtime)
+ }
+ if got := statAfterTruncateDown.Ctime.ToNsec(); got <= stat.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+ }
+
+ // Truncate up.
+ newSize = 100
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ // Size should be updated.
+ statAfterTruncateUp, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := statAfterTruncateUp.Size, newSize; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+ // We should read newSize worth of data.
+ buf = make([]byte, 1000)
+ if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead failed: %v", err)
+ } else if uint64(n) != newSize {
+ t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+ }
+ // Bytes should be null after 10, since we previously truncated to 10.
+ for i := uint64(10); i < newSize; i++ {
+ if buf[i] != 0 {
+ t.Errorf("fd.PRead got byte %d=%x, want 0", i, buf[i])
+ break
+ }
+ }
+ // Mtime and Ctime should be bumped.
+ if got := statAfterTruncateUp.Mtime.ToNsec(); got <= statAfterTruncateDown.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want > %v", got, statAfterTruncateDown.Mtime)
+ }
+ if got := statAfterTruncateUp.Ctime.ToNsec(); got <= statAfterTruncateDown.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+ }
+
+ // Truncate to the current size.
+ newSize = statAfterTruncateUp.Size
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ statAfterTruncateNoop, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ // Mtime and Ctime should not be bumped, since operation is a noop.
+ if got := statAfterTruncateNoop.Mtime.ToNsec(); got != statAfterTruncateUp.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want %v", got, statAfterTruncateUp.Mtime)
+ }
+ if got := statAfterTruncateNoop.Ctime.ToNsec(); got != statAfterTruncateUp.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want %v", got, statAfterTruncateUp.Ctime)
+ }
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go
new file mode 100644
index 000000000..ebe035dee
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go
@@ -0,0 +1,232 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func TestStatAfterCreate(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mode := linux.FileMode(0644)
+
+ // Run with different file types.
+ // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets.
+ for _, typ := range []string{"file", "dir", "pipe"} {
+ t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
+ var (
+ fd *vfs.FileDescription
+ cleanup func()
+ err error
+ )
+ switch typ {
+ case "file":
+ fd, cleanup, err = newFileFD(ctx, mode)
+ case "dir":
+ fd, cleanup, err = newDirFD(ctx, mode)
+ case "pipe":
+ fd, cleanup, err = newPipeFD(ctx, mode)
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ got, err := fd.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Atime, Ctime, Mtime should all be current time (non-zero).
+ atime, ctime, mtime := got.Atime.ToNsec(), got.Ctime.ToNsec(), got.Mtime.ToNsec()
+ if atime != ctime || ctime != mtime {
+ t.Errorf("got atime=%d ctime=%d mtime=%d, wanted equal values", atime, ctime, mtime)
+ }
+ if atime == 0 {
+ t.Errorf("got atime=%d, want non-zero", atime)
+ }
+
+ // Btime should be 0, as it is not set by tmpfs.
+ if btime := got.Btime.ToNsec(); btime != 0 {
+ t.Errorf("got btime %d, want 0", got.Btime.ToNsec())
+ }
+
+ // Size should be 0.
+ if got.Size != 0 {
+ t.Errorf("got size %d, want 0", got.Size)
+ }
+
+ // Nlink should be 1 for files, 2 for dirs.
+ wantNlink := uint32(1)
+ if typ == "dir" {
+ wantNlink = 2
+ }
+ if got.Nlink != wantNlink {
+ t.Errorf("got nlink %d, want %d", got.Nlink, wantNlink)
+ }
+
+ // UID and GID are set from context creds.
+ creds := auth.CredentialsFromContext(ctx)
+ if got.UID != uint32(creds.EffectiveKUID) {
+ t.Errorf("got uid %d, want %d", got.UID, uint32(creds.EffectiveKUID))
+ }
+ if got.GID != uint32(creds.EffectiveKGID) {
+ t.Errorf("got gid %d, want %d", got.GID, uint32(creds.EffectiveKGID))
+ }
+
+ // Mode.
+ wantMode := uint16(mode)
+ switch typ {
+ case "file":
+ wantMode |= linux.S_IFREG
+ case "dir":
+ wantMode |= linux.S_IFDIR
+ case "pipe":
+ wantMode |= linux.S_IFIFO
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+
+ if got.Mode != wantMode {
+ t.Errorf("got mode %x, want %x", got.Mode, wantMode)
+ }
+
+ // Ino.
+ if got.Ino == 0 {
+ t.Errorf("got ino %d, want not 0", got.Ino)
+ }
+ })
+ }
+}
+
+func TestSetStatAtime(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL}
+
+ // Get initial stat.
+ initialStat, err := fd.Stat(ctx, allStatOptions)
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Set atime, but without the mask.
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{
+ Mask: 0,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }}); err != nil {
+ t.Errorf("SetStat atime without mask failed: %v")
+ }
+ // Atime should be unchanged.
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != initialStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime)
+ }
+
+ // Set atime, this time included in the mask.
+ setStat := linux.Statx{
+ Mask: linux.STATX_ATIME,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
+ t.Errorf("SetStat atime with mask failed: %v")
+ }
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != setStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime)
+ }
+}
+
+func TestSetStat(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mode := linux.FileMode(0644)
+
+ // Run with different file types.
+ // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets.
+ for _, typ := range []string{"file", "dir", "pipe"} {
+ t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
+ var (
+ fd *vfs.FileDescription
+ cleanup func()
+ err error
+ )
+ switch typ {
+ case "file":
+ fd, cleanup, err = newFileFD(ctx, mode)
+ case "dir":
+ fd, cleanup, err = newDirFD(ctx, mode)
+ case "pipe":
+ fd, cleanup, err = newPipeFD(ctx, mode)
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL}
+
+ // Get initial stat.
+ initialStat, err := fd.Stat(ctx, allStatOptions)
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Set atime, but without the mask.
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{
+ Mask: 0,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }}); err != nil {
+ t.Errorf("SetStat atime without mask failed: %v")
+ }
+ // Atime should be unchanged.
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != initialStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime)
+ }
+
+ // Set atime, this time included in the mask.
+ setStat := linux.Statx{
+ Mask: linux.STATX_ATIME,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
+ t.Errorf("SetStat atime with mask failed: %v")
+ }
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != setStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/memfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
index b2ac2cbeb..5246aca84 100644
--- a/pkg/sentry/fsimpl/memfs/symlink.go
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 8d0167c93..1d4889c89 100644
--- a/pkg/sentry/fsimpl/memfs/memfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -12,31 +12,29 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package memfs provides a filesystem implementation that behaves like tmpfs:
+// Package tmpfs provides a filesystem implementation that behaves like tmpfs:
// the Dentry tree is the sole source of truth for the state of the filesystem.
//
-// memfs is intended primarily to demonstrate filesystem implementation
-// patterns. Real uses cases for an in-memory filesystem should use tmpfs
-// instead.
-//
// Lock order:
//
// filesystem.mu
// regularFileFD.offMu
// regularFile.mu
// inode.mu
-package memfs
+package tmpfs
import (
"fmt"
"math"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -47,6 +45,12 @@ type FilesystemType struct{}
type filesystem struct {
vfsfs vfs.Filesystem
+ // memFile is used to allocate pages to for regular files.
+ memFile *pgalloc.MemoryFile
+
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock time.Clock
+
// mu serializes changes to the Dentry tree.
mu sync.RWMutex
@@ -55,7 +59,15 @@ type filesystem struct {
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
- var fs filesystem
+ memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
+ if memFileProvider == nil {
+ panic("MemoryFileProviderFromContext returned nil")
+ }
+ clock := time.RealtimeClockFromContext(ctx)
+ fs := filesystem{
+ memFile: memFileProvider.MemoryFile(),
+ clock: clock,
+ }
fs.vfsfs.Init(vfsObj, &fs)
root := fs.newDentry(fs.newDirectory(creds, 01777))
return &fs.vfsfs, &root.vfsd, nil
@@ -74,11 +86,11 @@ type dentry struct {
// immutable.
inode *inode
- // memfs doesn't count references on dentries; because the dentry tree is
+ // tmpfs doesn't count references on dentries; because the dentry tree is
// the sole source of truth, it is by definition always consistent with the
// state of the filesystem. However, it does count references on inodes,
// because inode resources are released when all references are dropped.
- // (memfs doesn't really have resources to release, but we implement
+ // (tmpfs doesn't really have resources to release, but we implement
// reference counting because tmpfs regular files will.)
// dentryEntry (ugh) links dentries into their parent directory.childList.
@@ -110,6 +122,9 @@ func (d *dentry) DecRef() {
// inode represents a filesystem object.
type inode struct {
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock time.Clock
+
// refs is a reference count. refs is accessed using atomic memory
// operations.
//
@@ -120,26 +135,37 @@ type inode struct {
// filesystem.RmdirAt() drops the reference.
refs int64
- // Inode metadata; protected by mu and accessed using atomic memory
- // operations unless otherwise specified.
- mu sync.RWMutex
+ // Inode metadata. Writing multiple fields atomically requires holding
+ // mu, othewise atomic operations can be used.
+ mu sync.Mutex
mode uint32 // excluding file type bits, which are based on impl
nlink uint32 // protected by filesystem.mu instead of inode.mu
uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
gid uint32 // auth.KGID, but ...
ino uint64 // immutable
+ // Linux's tmpfs has no concept of btime.
+ atime int64 // nanoseconds
+ ctime int64 // nanoseconds
+ mtime int64 // nanoseconds
+
impl interface{} // immutable
}
const maxLinks = math.MaxUint32
func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
+ i.clock = fs.clock
i.refs = 1
i.mode = uint32(mode)
i.uid = uint32(creds.EffectiveKUID)
i.gid = uint32(creds.EffectiveKGID)
i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1)
+ // Tmpfs creation sets atime, ctime, and mtime to current time.
+ now := i.clock.Now().Nanoseconds()
+ i.atime = now
+ i.ctime = now
+ i.mtime = now
// i.nlink initialized by caller
i.impl = impl
}
@@ -150,7 +176,7 @@ func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials,
// i.nlink < maxLinks.
func (i *inode) incLinksLocked() {
if i.nlink == 0 {
- panic("memfs.inode.incLinksLocked() called with no existing links")
+ panic("tmpfs.inode.incLinksLocked() called with no existing links")
}
if i.nlink == maxLinks {
panic("memfs.inode.incLinksLocked() called with maximum link count")
@@ -163,14 +189,14 @@ func (i *inode) incLinksLocked() {
// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
func (i *inode) decLinksLocked() {
if i.nlink == 0 {
- panic("memfs.inode.decLinksLocked() called with no existing links")
+ panic("tmpfs.inode.decLinksLocked() called with no existing links")
}
atomic.AddUint32(&i.nlink, ^uint32(0))
}
func (i *inode) incRef() {
if atomic.AddInt64(&i.refs, 1) <= 1 {
- panic("memfs.inode.incRef() called without holding a reference")
+ panic("tmpfs.inode.incRef() called without holding a reference")
}
}
@@ -189,14 +215,14 @@ func (i *inode) tryIncRef() bool {
func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
// This is unnecessary; it's mostly to simulate what tmpfs would do.
- if regfile, ok := i.impl.(*regularFile); ok {
- regfile.mu.Lock()
- regfile.data = nil
- atomic.StoreInt64(&regfile.dataLen, 0)
- regfile.mu.Unlock()
+ if regFile, ok := i.impl.(*regularFile); ok {
+ regFile.mu.Lock()
+ regFile.data.DropAll(regFile.memFile)
+ atomic.StoreUint64(&regFile.size, 0)
+ regFile.mu.Unlock()
}
} else if refs < 0 {
- panic("memfs.inode.decRef() called without holding a reference")
+ panic("tmpfs.inode.decRef() called without holding a reference")
}
}
@@ -207,20 +233,29 @@ func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, i
// Go won't inline this function, and returning linux.Statx (which is quite
// big) means spending a lot of time in runtime.duffcopy(), so instead it's an
// output parameter.
+//
+// Note that Linux does not guarantee to return consistent data (in the case of
+// a concurrent modification), so we do not require holding inode.mu.
func (i *inode) statTo(stat *linux.Statx) {
- stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK |
+ linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_ATIME |
+ linux.STATX_BTIME | linux.STATX_CTIME | linux.STATX_MTIME
stat.Blksize = 1 // usermem.PageSize in tmpfs
stat.Nlink = atomic.LoadUint32(&i.nlink)
stat.UID = atomic.LoadUint32(&i.uid)
stat.GID = atomic.LoadUint32(&i.gid)
stat.Mode = uint16(atomic.LoadUint32(&i.mode))
stat.Ino = i.ino
- // TODO: device number
+ // Linux's tmpfs has no concept of btime, so zero-value is returned.
+ stat.Atime = linux.NsecToStatxTimestamp(i.atime)
+ stat.Ctime = linux.NsecToStatxTimestamp(i.ctime)
+ stat.Mtime = linux.NsecToStatxTimestamp(i.mtime)
+ // TODO(gvisor.dev/issues/1197): Device number.
switch impl := i.impl.(type) {
case *regularFile:
stat.Mode |= linux.S_IFREG
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
- stat.Size = uint64(atomic.LoadInt64(&impl.dataLen))
+ stat.Size = uint64(atomic.LoadUint64(&impl.size))
// In tmpfs, this will be FileRangeSet.Span() / 512 (but also cached in
// a uint64 accessed using atomic memory operations to avoid taking
// locks).
@@ -239,6 +274,75 @@ func (i *inode) statTo(stat *linux.Statx) {
}
}
+func (i *inode) setStat(stat linux.Statx) error {
+ if stat.Mask == 0 {
+ return nil
+ }
+ i.mu.Lock()
+ var (
+ needsMtimeBump bool
+ needsCtimeBump bool
+ )
+ mask := stat.Mask
+ if mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&i.mode, uint32(stat.Mode))
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&i.uid, stat.UID)
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&i.gid, stat.GID)
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_SIZE != 0 {
+ switch impl := i.impl.(type) {
+ case *regularFile:
+ updated, err := impl.truncate(stat.Size)
+ if err != nil {
+ return err
+ }
+ if updated {
+ needsMtimeBump = true
+ needsCtimeBump = true
+ }
+ case *directory:
+ return syserror.EISDIR
+ case *symlink:
+ return syserror.EINVAL
+ case *namedPipe:
+ // Nothing.
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", i.impl))
+ }
+ }
+ if mask&linux.STATX_ATIME != 0 {
+ atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped())
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_MTIME != 0 {
+ atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+ needsCtimeBump = true
+ // Ignore the mtime bump, since we just set it ourselves.
+ needsMtimeBump = false
+ }
+ if mask&linux.STATX_CTIME != 0 {
+ atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped())
+ // Ignore the ctime bump, since we just set it ourselves.
+ needsCtimeBump = false
+ }
+ now := i.clock.Now().Nanoseconds()
+ if needsMtimeBump {
+ atomic.StoreInt64(&i.mtime, now)
+ }
+ if needsCtimeBump {
+ atomic.StoreInt64(&i.ctime, now)
+ }
+ i.mu.Unlock()
+ return nil
+}
+
// allocatedBlocksForSize returns the number of 512B blocks needed to
// accommodate the given size in bytes, as appropriate for struct
// stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block
@@ -261,7 +365,7 @@ func (i *inode) direntType() uint8 {
}
}
-// fileDescription is embedded by memfs implementations of
+// fileDescription is embedded by tmpfs implementations of
// vfs.FileDescriptionImpl.
type fileDescription struct {
vfsfd vfs.FileDescription
@@ -285,9 +389,5 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- if opts.Stat.Mask == 0 {
- return nil
- }
- // TODO: implement inode.setStat
- return syserror.EPERM
+ return fd.inode().setStat(opts.Stat)
}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 2706927ff..ac85ba0c8 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -35,7 +35,7 @@ go_template_instance(
out = "seqatomic_taskgoroutineschedinfo_unsafe.go",
package = "kernel",
suffix = "TaskGoroutineSchedInfo",
- template = "//pkg/syncutil:generic_seqatomic",
+ template = "//pkg/sync:generic_seqatomic",
types = {
"Value": "TaskGoroutineSchedInfo",
},
@@ -209,7 +209,7 @@ go_library(
"//pkg/sentry/usermem",
"//pkg/state",
"//pkg/state/statefile",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
@@ -241,6 +241,7 @@ go_test(
"//pkg/sentry/time",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
index 244655b5c..920fe4329 100644
--- a/pkg/sentry/kernel/abstract_socket_namespace.go
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -15,11 +15,11 @@
package kernel
import (
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
)
// +stateify savable
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 04c244447..1aa72fa47 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "atomicptr_credentials_unsafe.go",
package = "auth",
suffix = "Credentials",
- template = "//pkg/syncutil:generic_atomicptr",
+ template = "//pkg/sync:generic_atomicptr",
types = {
"Value": "Credentials",
},
@@ -64,6 +64,7 @@ go_library(
"//pkg/bits",
"//pkg/log",
"//pkg/sentry/context",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/auth/user_namespace.go b/pkg/sentry/kernel/auth/user_namespace.go
index af28ccc65..9dd52c860 100644
--- a/pkg/sentry/kernel/auth/user_namespace.go
+++ b/pkg/sentry/kernel/auth/user_namespace.go
@@ -16,8 +16,8 @@ package auth
import (
"math"
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index 3361e8b7d..c47f6b6fc 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 9c0a4e1b4..430311cc0 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -18,7 +18,6 @@ package epoll
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/refs"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index e65b961e8..c831fbab2 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
index 12f0d429b..687690679 100644
--- a/pkg/sentry/kernel/eventfd/eventfd.go
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -18,7 +18,6 @@ package eventfd
import (
"math"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index 49d81b712..6b36bc63e 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -12,6 +12,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sync",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index 6b0bb0324..d32c3e90a 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -16,12 +16,11 @@
package fasync
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 11f613a11..cd1501f85 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -18,7 +18,6 @@ import (
"bytes"
"fmt"
"math"
- "sync"
"sync/atomic"
"syscall"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FDFlags define flags for an individual descriptor.
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
index 2bcb6216a..eccb7d1e7 100644
--- a/pkg/sentry/kernel/fd_table_test.go
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -16,7 +16,6 @@ package kernel
import (
"runtime"
- "sync"
"testing"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/filetest"
"gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index ded27d668..2448c1d99 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -16,10 +16,10 @@ package kernel
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FSContext contains filesystem context.
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 75ec31761..50db443ce 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "atomicptr_bucket_unsafe.go",
package = "futex",
suffix = "Bucket",
- template = "//pkg/syncutil:generic_atomicptr",
+ template = "//pkg/sync:generic_atomicptr",
types = {
"Value": "bucket",
},
@@ -42,6 +42,7 @@ go_library(
"//pkg/sentry/context",
"//pkg/sentry/memmap",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
@@ -51,5 +52,8 @@ go_test(
size = "small",
srcs = ["futex_test.go"],
embed = [":futex"],
- deps = ["//pkg/sentry/usermem"],
+ deps = [
+ "//pkg/sentry/usermem",
+ "//pkg/sync",
+ ],
)
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index 278cc8143..d1931c8f4 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -18,11 +18,10 @@
package futex
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go
index 65e5d1428..c23126ca5 100644
--- a/pkg/sentry/kernel/futex/futex_test.go
+++ b/pkg/sentry/kernel/futex/futex_test.go
@@ -17,13 +17,13 @@ package futex
import (
"math"
"runtime"
- "sync"
"sync/atomic"
"syscall"
"testing"
"unsafe"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// testData implements the Target interface, and allows us to
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index bd3fb4c03..c85e97fef 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -36,7 +36,6 @@ import (
"fmt"
"io"
"path/filepath"
- "sync"
"sync/atomic"
"time"
@@ -67,6 +66,7 @@ import (
uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -762,7 +762,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
mounts.IncRef()
}
- tg := k.newThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits, k.monotonicClock)
+ tg := k.NewThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits)
ctx := args.NewContext(k)
// Get the root directory from the MountNamespace.
@@ -1191,6 +1191,11 @@ func (k *Kernel) GlobalInit() *ThreadGroup {
return k.globalInit
}
+// TestOnly_SetGlobalInit sets the thread group with ID 1 in the root PID namespace.
+func (k *Kernel) TestOnly_SetGlobalInit(tg *ThreadGroup) {
+ k.globalInit = tg
+}
+
// ApplicationCores returns the number of CPUs visible to sandboxed
// applications.
func (k *Kernel) ApplicationCores() uint {
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
index d7a7d1169..7f36252a9 100644
--- a/pkg/sentry/kernel/memevent/BUILD
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/metric",
"//pkg/sentry/kernel",
"//pkg/sentry/usage",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go
index b0d98e7f0..200565bb8 100644
--- a/pkg/sentry/kernel/memevent/memory_events.go
+++ b/pkg/sentry/kernel/memevent/memory_events.go
@@ -17,7 +17,6 @@
package memevent
import (
- "sync"
"time"
"gvisor.dev/gvisor/pkg/eventchannel"
@@ -26,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
)
var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.")
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 9d34f6d4d..5eeaeff66 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -43,6 +43,7 @@ go_library(
"//pkg/sentry/safemem",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
index 95bee2d37..1c0f34269 100644
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ b/pkg/sentry/kernel/pipe/buffer.go
@@ -16,9 +16,9 @@ package pipe
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// buffer encapsulates a queueable byte buffer.
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index 4a19ab7ce..716f589af 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -15,12 +15,11 @@
package pipe
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 1a1b38f83..e4fd7d420 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -17,12 +17,12 @@ package pipe
import (
"fmt"
- "sync"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index ef9641e6a..8394eb78b 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -17,7 +17,6 @@ package pipe
import (
"io"
"math"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 6416e0dd8..bf7461cbb 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -15,13 +15,12 @@
package pipe
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index 24ea002ba..b14429854 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -15,17 +15,29 @@
package kernel
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
-// Restartable sequences, as described in https://lwn.net/Articles/650333/.
+// Restartable sequences.
+//
+// We support two different APIs for restartable sequences.
+//
+// 1. The upstream interface added in v4.18.
+// 2. The interface described in https://lwn.net/Articles/650333/.
+//
+// Throughout this file and other parts of the kernel, the latter is referred
+// to as "old rseq". This interface was never merged upstream, but is supported
+// for a limited set of applications that use it regardless.
-// RSEQCriticalRegion describes a restartable sequence critical region.
+// OldRSeqCriticalRegion describes an old rseq critical region.
//
// +stateify savable
-type RSEQCriticalRegion struct {
+type OldRSeqCriticalRegion struct {
// When a task in this thread group has its CPU preempted (as defined by
// platform.ErrContextCPUPreempted) or has a signal delivered to an
// application handler while its instruction pointer is in CriticalSection,
@@ -35,86 +47,359 @@ type RSEQCriticalRegion struct {
Restart usermem.Addr
}
-// RSEQAvailable returns true if t supports restartable sequences.
-func (t *Task) RSEQAvailable() bool {
+// RSeqAvailable returns true if t supports (old and new) restartable sequences.
+func (t *Task) RSeqAvailable() bool {
return t.k.useHostCores && t.k.Platform.DetectsCPUPreemption()
}
-// RSEQCriticalRegion returns a copy of t's thread group's current restartable
-// sequence.
-func (t *Task) RSEQCriticalRegion() RSEQCriticalRegion {
- return *t.tg.rscr.Load().(*RSEQCriticalRegion)
+// SetRSeq registers addr as this thread's rseq structure.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr != 0 {
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EINVAL
+ }
+ return syserror.EBUSY
+ }
+
+ // rseq must be aligned and correctly sized.
+ if addr&(linux.AlignOfRSeq-1) != 0 {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if _, ok := t.MemoryManager().CheckIORange(addr, linux.SizeOfRSeq); !ok {
+ return syserror.EFAULT
+ }
+
+ t.rseqAddr = addr
+ t.rseqSignature = signature
+
+ // Initialize the CPUID.
+ //
+ // Linux implicitly does this on return from userspace, where failure
+ // would cause SIGSEGV.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return syserror.EFAULT
+ }
+
+ return nil
}
-// SetRSEQCriticalRegion replaces t's thread group's restartable sequence.
+// ClearRSeq unregisters addr as this thread's rseq structure.
//
-// Preconditions: t.RSEQAvailable() == true.
-func (t *Task) SetRSEQCriticalRegion(rscr RSEQCriticalRegion) error {
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) ClearRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr == 0 {
+ return syserror.EINVAL
+ }
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EPERM
+ }
+
+ if err := t.rseqClearCPU(); err != nil {
+ return err
+ }
+
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ if t.oldRSeqCPUAddr == 0 {
+ // rseqCPU no longer needed.
+ t.rseqCPU = -1
+ }
+
+ return nil
+}
+
+// OldRSeqCriticalRegion returns a copy of t's thread group's current
+// old restartable sequence.
+func (t *Task) OldRSeqCriticalRegion() OldRSeqCriticalRegion {
+ return *t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+}
+
+// SetOldRSeqCriticalRegion replaces t's thread group's old restartable
+// sequence.
+//
+// Preconditions: t.RSeqAvailable() == true.
+func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error {
// These checks are somewhat more lenient than in Linux, which (bizarrely)
- // requires rscr.CriticalSection to be non-empty and rscr.Restart to be
- // outside of rscr.CriticalSection, even if rscr.CriticalSection.Start == 0
+ // requires r.CriticalSection to be non-empty and r.Restart to be
+ // outside of r.CriticalSection, even if r.CriticalSection.Start == 0
// (which disables the critical region).
- if rscr.CriticalSection.Start == 0 {
- rscr.CriticalSection.End = 0
- rscr.Restart = 0
- t.tg.rscr.Store(&rscr)
+ if r.CriticalSection.Start == 0 {
+ r.CriticalSection.End = 0
+ r.Restart = 0
+ t.tg.oldRSeqCritical.Store(&r)
return nil
}
- if rscr.CriticalSection.Start >= rscr.CriticalSection.End {
+ if r.CriticalSection.Start >= r.CriticalSection.End {
return syserror.EINVAL
}
- if rscr.CriticalSection.Contains(rscr.Restart) {
+ if r.CriticalSection.Contains(r.Restart) {
return syserror.EINVAL
}
- // TODO(jamieliu): check that rscr.CriticalSection and rscr.Restart are in
- // the application address range, for consistency with Linux
- t.tg.rscr.Store(&rscr)
+ // TODO(jamieliu): check that r.CriticalSection and r.Restart are in
+ // the application address range, for consistency with Linux.
+ t.tg.oldRSeqCritical.Store(&r)
return nil
}
-// RSEQCPUAddr returns the address that RSEQ will keep updated with t's CPU
-// number.
+// OldRSeqCPUAddr returns the address that old rseq will keep updated with t's
+// CPU number.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) RSEQCPUAddr() usermem.Addr {
- return t.rseqCPUAddr
+func (t *Task) OldRSeqCPUAddr() usermem.Addr {
+ return t.oldRSeqCPUAddr
}
-// SetRSEQCPUAddr replaces the address that RSEQ will keep updated with t's CPU
-// number.
+// SetOldRSeqCPUAddr replaces the address that old rseq will keep updated with
+// t's CPU number.
//
-// Preconditions: t.RSEQAvailable() == true. The caller must be running on the
+// Preconditions: t.RSeqAvailable() == true. The caller must be running on the
// task goroutine. t's AddressSpace must be active.
-func (t *Task) SetRSEQCPUAddr(addr usermem.Addr) error {
- t.rseqCPUAddr = addr
- if addr != 0 {
- t.rseqCPU = int32(hostcpu.GetCPU())
- if err := t.rseqCopyOutCPU(); err != nil {
- t.rseqCPUAddr = 0
- t.rseqCPU = -1
- return syserror.EINVAL // yes, EINVAL, not err or EFAULT
- }
- } else {
- t.rseqCPU = -1
+func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error {
+ t.oldRSeqCPUAddr = addr
+
+ // Check that addr is writable.
+ //
+ // N.B. rseqUpdateCPU may fail on a bad t.rseqAddr as well. That's
+ // unfortunate, but unlikely in a correct program.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.oldRSeqCPUAddr = 0
+ return syserror.EINVAL // yes, EINVAL, not err or EFAULT
}
return nil
}
// Preconditions: The caller must be running on the task goroutine. t's
// AddressSpace must be active.
-func (t *Task) rseqCopyOutCPU() error {
+func (t *Task) rseqUpdateCPU() error {
+ if t.rseqAddr == 0 && t.oldRSeqCPUAddr == 0 {
+ t.rseqCPU = -1
+ return nil
+ }
+
+ t.rseqCPU = int32(hostcpu.GetCPU())
+
+ // Update both CPUs, even if one fails.
+ rerr := t.rseqCopyOutCPU()
+ oerr := t.oldRSeqCopyOutCPU()
+
+ if rerr != nil {
+ return rerr
+ }
+ return oerr
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) oldRSeqCopyOutCPU() error {
+ if t.oldRSeqCPUAddr == 0 {
+ return nil
+ }
+
buf := t.CopyScratchBuffer(4)
usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU))
- _, err := t.CopyOutBytes(t.rseqCPUAddr, buf)
+ _, err := t.CopyOutBytes(t.oldRSeqCPUAddr, buf)
+ return err
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqCopyOutCPU() error {
+ if t.rseqAddr == 0 {
+ return nil
+ }
+
+ buf := t.CopyScratchBuffer(8)
+ // CPUIDStart and CPUID are the first two fields in linux.RSeq.
+ usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) // CPUIDStart
+ usermem.ByteOrder.PutUint32(buf[4:], uint32(t.rseqCPU)) // CPUID
+ // N.B. This write is not atomic, but since this occurs on the task
+ // goroutine then as long as userspace uses a single-instruction read
+ // it can't see an invalid value.
+ _, err := t.CopyOutBytes(t.rseqAddr, buf)
+ return err
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqClearCPU() error {
+ buf := t.CopyScratchBuffer(8)
+ // CPUIDStart and CPUID are the first two fields in linux.RSeq.
+ usermem.ByteOrder.PutUint32(buf, 0) // CPUIDStart
+ usermem.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID
+ // N.B. This write is not atomic, but since this occurs on the task
+ // goroutine then as long as userspace uses a single-instruction read
+ // it can't see an invalid value.
+ _, err := t.CopyOutBytes(t.rseqAddr, buf)
return err
}
+// rseqAddrInterrupt checks if IP is in a critical section, and aborts if so.
+//
+// This is a bit complex since both the RSeq and RSeqCriticalSection structs
+// are stored in userspace. So we must:
+//
+// 1. Copy in the address of RSeqCriticalSection from RSeq.
+// 2. Copy in RSeqCriticalSection itself.
+// 3. Validate critical section struct version, address range, abort address.
+// 4. Validate the abort signature (4 bytes preceding abort IP match expected
+// signature).
+// 5. Clear address of RSeqCriticalSection from RSeq.
+// 6. Finally, conditionally abort.
+//
+// See kernel/rseq.c:rseq_ip_fixup for reference.
+//
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqAddrInterrupt() {
+ if t.rseqAddr == 0 {
+ return
+ }
+
+ critAddrAddr, ok := t.rseqAddr.AddLength(linux.OffsetOfRSeqCriticalSection)
+ if !ok {
+ // SetRSeq should validate this.
+ panic(fmt.Sprintf("t.rseqAddr (%#x) not large enough", t.rseqAddr))
+ }
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ t.Debugf("Only 64-bit rseq supported.")
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ buf := t.CopyScratchBuffer(8)
+ if _, err := t.CopyInBytes(critAddrAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section address from %#x for rseq: %v", critAddrAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ critAddr := usermem.Addr(usermem.ByteOrder.Uint64(buf))
+ if critAddr == 0 {
+ return
+ }
+
+ buf = t.CopyScratchBuffer(linux.SizeOfRSeqCriticalSection)
+ if _, err := t.CopyInBytes(critAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Manually marshal RSeqCriticalSection as this is in the hot path when
+ // rseq is enabled. It must be as fast as possible.
+ //
+ // TODO(b/130243041): Replace with go_marshal.
+ cs := linux.RSeqCriticalSection{
+ Version: usermem.ByteOrder.Uint32(buf[0:4]),
+ Flags: usermem.ByteOrder.Uint32(buf[4:8]),
+ Start: usermem.ByteOrder.Uint64(buf[8:16]),
+ PostCommitOffset: usermem.ByteOrder.Uint64(buf[16:24]),
+ Abort: usermem.ByteOrder.Uint64(buf[24:32]),
+ }
+
+ if cs.Version != 0 {
+ t.Debugf("Unknown version in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ start := usermem.Addr(cs.Start)
+ critRange, ok := start.ToRange(cs.PostCommitOffset)
+ if !ok {
+ t.Debugf("Invalid start and offset in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ abort := usermem.Addr(cs.Abort)
+ if critRange.Contains(abort) {
+ t.Debugf("Abort in critical section in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Verify signature.
+ sigAddr := abort - linux.SizeOfRSeqSignature
+
+ buf = t.CopyScratchBuffer(linux.SizeOfRSeqSignature)
+ if _, err := t.CopyInBytes(sigAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section signature from %#x for rseq: %v", sigAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ sig := usermem.ByteOrder.Uint32(buf)
+ if sig != t.rseqSignature {
+ t.Debugf("Mismatched rseq signature %d != %d", sig, t.rseqSignature)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Clear the critical section address.
+ //
+ // NOTE(b/143949567): We don't support any rseq flags, so we always
+ // restart if we are in the critical section, and thus *always* clear
+ // critAddrAddr.
+ if _, err := t.MemoryManager().ZeroOut(t, critAddrAddr, int64(t.Arch().Width()), usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ t.Debugf("Failed to clear critical section address from %#x for rseq: %v", critAddrAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Finally we can actually decide whether or not to restart.
+ if !critRange.Contains(usermem.Addr(t.Arch().IP())) {
+ return
+ }
+
+ t.Arch().SetIP(uintptr(cs.Abort))
+}
+
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) rseqInterrupt() {
- rscr := t.tg.rscr.Load().(*RSEQCriticalRegion)
- if ip := t.Arch().IP(); rscr.CriticalSection.Contains(usermem.Addr(ip)) {
- t.Debugf("Interrupted RSEQ critical section at %#x; restarting at %#x", ip, rscr.Restart)
- t.Arch().SetIP(uintptr(rscr.Restart))
- t.Arch().SetRSEQInterruptedIP(ip)
+func (t *Task) oldRSeqInterrupt() {
+ r := t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+ if ip := t.Arch().IP(); r.CriticalSection.Contains(usermem.Addr(ip)) {
+ t.Debugf("Interrupted rseq critical section at %#x; restarting at %#x", ip, r.Restart)
+ t.Arch().SetIP(uintptr(r.Restart))
+ t.Arch().SetOldRSeqInterruptedIP(ip)
}
}
+
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) rseqInterrupt() {
+ t.rseqAddrInterrupt()
+ t.oldRSeqInterrupt()
+}
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index f4c00cd86..13a961594 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index de9617e9d..18299814e 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -17,7 +17,6 @@ package semaphore
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index cd48945e6..7321b22ed 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -24,6 +24,7 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 5bd610f68..8ddef7eb8 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -35,7 +35,6 @@ package shm
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
@@ -49,6 +48,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -71,9 +71,20 @@ type Registry struct {
mu sync.Mutex `state:"nosave"`
// shms maps segment ids to segments.
+ //
+ // shms holds all referenced segments, which are removed on the last
+ // DecRef. Thus, it cannot itself hold a reference on the Shm.
+ //
+ // Since removal only occurs after the last (unlocked) DecRef, there
+ // exists a short window during which a Shm still exists in Shm, but is
+ // unreferenced. Users must use TryIncRef to determine if the Shm is
+ // still valid.
shms map[ID]*Shm
// keysToShms maps segment keys to segments.
+ //
+ // Shms in keysToShms are guaranteed to be referenced, as they are
+ // removed by disassociateKey before the last DecRef.
keysToShms map[Key]*Shm
// Sum of the sizes of all existing segments rounded up to page size, in
@@ -95,10 +106,18 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
}
// FindByID looks up a segment given an ID.
+//
+// FindByID returns a reference on Shm.
func (r *Registry) FindByID(id ID) *Shm {
r.mu.Lock()
defer r.mu.Unlock()
- return r.shms[id]
+ s := r.shms[id]
+ // Take a reference on s. If TryIncRef fails, s has reached the last
+ // DecRef, but hasn't quite been removed from r.shms yet.
+ if s != nil && s.TryIncRef() {
+ return s
+ }
+ return nil
}
// dissociateKey removes the association between a segment and its key,
@@ -119,6 +138,8 @@ func (r *Registry) dissociateKey(s *Shm) {
// FindOrCreate looks up or creates a segment in the registry. It's functionally
// analogous to open(2).
+//
+// FindOrCreate returns a reference on Shm.
func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) {
// "A new segment was to be created and size is less than SHMMIN or
@@ -166,6 +187,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
return nil, syserror.EEXIST
}
+ shm.IncRef()
return shm, nil
}
@@ -193,7 +215,14 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
// Need to create a new segment.
creator := fs.FileOwnerFromContext(ctx)
perms := fs.FilePermsFromMode(mode)
- return r.newShm(ctx, pid, key, creator, perms, size)
+ s, err := r.newShm(ctx, pid, key, creator, perms, size)
+ if err != nil {
+ return nil, err
+ }
+ // The initial reference is held by s itself. Take another to return to
+ // the caller.
+ s.IncRef()
+ return s, nil
}
// newShm creates a new segment in the registry.
@@ -296,22 +325,26 @@ func (r *Registry) remove(s *Shm) {
// Shm represents a single shared memory segment.
//
-// Shm segment are backed directly by an allocation from platform
-// memory. Segments are always mapped as a whole, greatly simplifying how
-// mappings are tracked. However note that mremap and munmap calls may cause the
-// vma for a segment to become fragmented; which requires special care when
-// unmapping a segment. See mm/shm.go.
+// Shm segment are backed directly by an allocation from platform memory.
+// Segments are always mapped as a whole, greatly simplifying how mappings are
+// tracked. However note that mremap and munmap calls may cause the vma for a
+// segment to become fragmented; which requires special care when unmapping a
+// segment. See mm/shm.go.
//
// Segments persist until they are explicitly marked for destruction via
-// shmctl(SHM_RMID).
+// MarkDestroyed().
//
// Shm implements memmap.Mappable and memmap.MappingIdentity.
//
// +stateify savable
type Shm struct {
- // AtomicRefCount tracks the number of references to this segment from
- // maps. A segment always holds a reference to itself, until it's marked for
+ // AtomicRefCount tracks the number of references to this segment.
+ //
+ // A segment holds a reference to itself until it is marked for
// destruction.
+ //
+ // In addition to direct users, the MemoryManager will hold references
+ // via MappingIdentity.
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
@@ -484,9 +517,8 @@ type AttachOpts struct {
// ConfigureAttach creates an mmap configuration for the segment with the
// requested attach options.
//
-// ConfigureAttach returns with a ref on s on success. The caller should drop
-// this once the map is installed. This reference prevents s from being
-// destroyed before the returned configuration is used.
+// Postconditions: The returned MMapOpts are valid only as long as a reference
+// continues to be held on s.
func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts AttachOpts) (memmap.MMapOpts, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -504,7 +536,6 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts Attac
// in the user namespace that governs its IPC namespace." - man shmat(2)
return memmap.MMapOpts{}, syserror.EACCES
}
- s.IncRef()
return memmap.MMapOpts{
Length: s.size,
Offset: 0,
@@ -549,10 +580,15 @@ func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) {
}
creds := auth.CredentialsFromContext(ctx)
- nattach := uint64(s.ReadRefs())
- // Don't report the self-reference we keep prior to being marked for
- // destruction. However, also don't report a count of -1 for segments marked
- // as destroyed, with no mappings.
+ // Use the reference count as a rudimentary count of the number of
+ // attaches. We exclude:
+ //
+ // 1. The reference the caller holds.
+ // 2. The self-reference held by s prior to destruction.
+ //
+ // Note that this may still overcount by including transient references
+ // used in concurrent calls.
+ nattach := uint64(s.ReadRefs()) - 1
if !s.pendingDestruction {
nattach--
}
@@ -620,18 +656,17 @@ func (s *Shm) MarkDestroyed() {
s.registry.dissociateKey(s)
s.mu.Lock()
- // Only drop the segment's self-reference once, when destruction is
- // requested. Otherwise, repeated calls to shmctl(IPC_RMID) would force a
- // segment to be destroyed prematurely, potentially with active maps to the
- // segment's address range. Remaining references are dropped when the
- // segment is detached or unmaped.
+ defer s.mu.Unlock()
if !s.pendingDestruction {
s.pendingDestruction = true
- s.mu.Unlock() // Must release s.mu before calling s.DecRef.
+ // Drop the self-reference so destruction occurs when all
+ // external references are gone.
+ //
+ // N.B. This cannot be the final DecRef, as the caller also
+ // holds a reference.
s.DecRef()
return
}
- s.mu.Unlock()
}
// checkOwnership verifies whether a segment may be accessed by ctx as an
diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go
index a16f3d57f..768fda220 100644
--- a/pkg/sentry/kernel/signal_handlers.go
+++ b/pkg/sentry/kernel/signal_handlers.go
@@ -15,10 +15,9 @@
package kernel
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
)
// SignalHandlers holds information about signal actions.
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
index 9f7e19b4d..89e4d84b1 100644
--- a/pkg/sentry/kernel/signalfd/BUILD
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index 4b08d7d72..28be4a939 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -16,8 +16,6 @@
package signalfd
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -26,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 2fdee0282..d2d01add4 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -16,13 +16,13 @@ package kernel
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// maxSyscallNum is the highest supported syscall number.
diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go
index 8227ecf1d..4607cde2f 100644
--- a/pkg/sentry/kernel/syslog.go
+++ b/pkg/sentry/kernel/syslog.go
@@ -17,7 +17,8 @@ package kernel
import (
"fmt"
"math/rand"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// syslog represents a sentry-global kernel log.
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index ab0c6c4aa..978d66da8 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -17,7 +17,6 @@ package kernel
import (
gocontext "context"
"runtime/trace"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -37,7 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -85,7 +84,7 @@ type Task struct {
//
// gosched is protected by goschedSeq. gosched is owned by the task
// goroutine.
- goschedSeq syncutil.SeqCount `state:"nosave"`
+ goschedSeq sync.SeqCount `state:"nosave"`
gosched TaskGoroutineSchedInfo
// yieldCount is the number of times the task goroutine has called
@@ -489,18 +488,43 @@ type Task struct {
// netns is protected by mu. netns is owned by the task goroutine.
netns bool
- // If rseqPreempted is true, before the next call to p.Switch(), interrupt
- // RSEQ critical regions as defined by tg.rseq and write the task
- // goroutine's CPU number to rseqCPUAddr. rseqCPU is the last CPU number
- // written to rseqCPUAddr.
+ // If rseqPreempted is true, before the next call to p.Switch(),
+ // interrupt rseq critical regions as defined by rseqAddr and
+ // tg.oldRSeqCritical and write the task goroutine's CPU number to
+ // rseqAddr/oldRSeqCPUAddr.
//
- // If rseqCPUAddr is 0, rseqCPU is -1.
+ // We support two ABIs for restartable sequences:
//
- // rseqCPUAddr, rseqCPU, and rseqPreempted are exclusive to the task
- // goroutine.
+ // 1. The upstream interface added in v4.18,
+ // 2. An "old" interface never merged upstream. In the implementation,
+ // this is referred to as "old rseq".
+ //
+ // rseqPreempted is exclusive to the task goroutine.
rseqPreempted bool `state:"nosave"`
- rseqCPUAddr usermem.Addr
- rseqCPU int32
+
+ // rseqCPU is the last CPU number written to rseqAddr/oldRSeqCPUAddr.
+ //
+ // If rseq is unused, rseqCPU is -1 for convenient use in
+ // platform.Context.Switch.
+ //
+ // rseqCPU is exclusive to the task goroutine.
+ rseqCPU int32
+
+ // oldRSeqCPUAddr is a pointer to the userspace old rseq CPU variable.
+ //
+ // oldRSeqCPUAddr is exclusive to the task goroutine.
+ oldRSeqCPUAddr usermem.Addr
+
+ // rseqAddr is a pointer to the userspace linux.RSeq structure.
+ //
+ // rseqAddr is exclusive to the task goroutine.
+ rseqAddr usermem.Addr
+
+ // rseqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ //
+ // rseqSignature is exclusive to the task goroutine.
+ rseqSignature uint32
// copyScratchBuffer is a buffer available to CopyIn/CopyOut
// implementations that require an intermediate buffer to copy data
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 3eadfedb4..247bd4aba 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -236,14 +236,19 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else if opts.NewPIDNamespace {
pidns = pidns.NewChild(userns)
}
+
tg := t.tg
+ rseqAddr := usermem.Addr(0)
+ rseqSignature := uint32(0)
if opts.NewThreadGroup {
tg.mounts.IncRef()
sh := t.tg.signalHandlers
if opts.NewSignalHandlers {
sh = sh.Fork()
}
- tg = t.k.newThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy(), t.k.monotonicClock)
+ tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy())
+ rseqAddr = t.rseqAddr
+ rseqSignature = t.rseqSignature
}
cfg := &TaskConfig{
@@ -260,6 +265,8 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
UTSNamespace: utsns,
IPCNamespace: ipcns,
AbstractSocketNamespace: t.abstractSockets,
+ RSeqAddr: rseqAddr,
+ RSeqSignature: rseqSignature,
ContainerID: t.ContainerID(),
}
if opts.NewThreadGroup {
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 90a6190f1..fa6528386 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -190,9 +190,11 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.updateRSSLocked()
// Restartable sequence state is discarded.
t.rseqPreempted = false
- t.rseqCPUAddr = 0
t.rseqCPU = -1
- t.tg.rscr.Store(&RSEQCriticalRegion{})
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+ t.oldRSeqCPUAddr = 0
+ t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
t.tg.pidns.owner.mu.Unlock()
// Remove FDs with the CloseOnExec flag set.
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index d97f8c189..6357273d3 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -169,12 +169,22 @@ func (*runApp) execute(t *Task) taskRunState {
// Apply restartable sequences.
if t.rseqPreempted {
t.rseqPreempted = false
- if t.rseqCPUAddr != 0 {
+ if t.rseqAddr != 0 || t.oldRSeqCPUAddr != 0 {
+ // Linux writes the CPU on every preemption. We only do
+ // so if it changed. Thus we may delay delivery of
+ // SIGSEGV if rseqAddr/oldRSeqCPUAddr is invalid.
cpu := int32(hostcpu.GetCPU())
if t.rseqCPU != cpu {
t.rseqCPU = cpu
if err := t.rseqCopyOutCPU(); err != nil {
- t.Warningf("Failed to copy CPU to %#x for RSEQ: %v", t.rseqCPUAddr, err)
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err)
+ t.forceSignal(linux.SIGSEGV, false)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ // Re-enter the task run loop for signal delivery.
+ return (*runApp)(nil)
+ }
+ if err := t.oldRSeqCopyOutCPU(); err != nil {
+ t.Debugf("Failed to copy CPU to %#x for old rseq: %v", t.oldRSeqCPUAddr, err)
t.forceSignal(linux.SIGSEGV, false)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
// Re-enter the task run loop for signal delivery.
@@ -320,7 +330,7 @@ func (*runApp) execute(t *Task) taskRunState {
return (*runApp)(nil)
case platform.ErrContextCPUPreempted:
- // Ensure that RSEQ critical sections are interrupted and per-thread
+ // Ensure that rseq critical sections are interrupted and per-thread
// CPU values are updated before the next platform.Context.Switch().
t.rseqPreempted = true
return (*runApp)(nil)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 3522a4ae5..58af16ee2 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -79,6 +80,13 @@ type TaskConfig struct {
// AbstractSocketNamespace is the AbstractSocketNamespace of the new task.
AbstractSocketNamespace *AbstractSocketNamespace
+ // RSeqAddr is a pointer to the the userspace linux.RSeq structure.
+ RSeqAddr usermem.Addr
+
+ // RSeqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ RSeqSignature uint32
+
// ContainerID is the container the new task belongs to.
ContainerID string
}
@@ -126,6 +134,8 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
ipcns: cfg.IPCNamespace,
abstractSockets: cfg.AbstractSocketNamespace,
rseqCPU: -1,
+ rseqAddr: cfg.RSeqAddr,
+ rseqSignature: cfg.RSeqSignature,
futexWaiter: futex.NewWaiter(),
containerID: cfg.ContainerID,
}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 72568d296..768e958d2 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -15,7 +15,6 @@
package kernel
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +24,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -238,8 +238,8 @@ type ThreadGroup struct {
// execed is protected by the TaskSet mutex.
execed bool
- // rscr is the thread group's RSEQ critical region.
- rscr atomic.Value `state:".(*RSEQCriticalRegion)"`
+ // oldRSeqCritical is the thread group's old rseq critical region.
+ oldRSeqCritical atomic.Value `state:".(*OldRSeqCriticalRegion)"`
// mounts is the thread group's mount namespace. This does not really
// correspond to a "mount namespace" in Linux, but is more like a
@@ -256,35 +256,35 @@ type ThreadGroup struct {
tty *TTY
}
-// newThreadGroup returns a new, empty thread group in PID namespace ns. The
+// NewThreadGroup returns a new, empty thread group in PID namespace ns. The
// thread group leader will send its parent terminationSignal when it exits.
// The new thread group isn't visible to the system until a task has been
// created inside of it by a successful call to TaskSet.NewTask.
-func (k *Kernel) newThreadGroup(mounts *fs.MountNamespace, ns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet, monotonicClock *timekeeperClock) *ThreadGroup {
+func (k *Kernel) NewThreadGroup(mntns *fs.MountNamespace, pidns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet) *ThreadGroup {
tg := &ThreadGroup{
threadGroupNode: threadGroupNode{
- pidns: ns,
+ pidns: pidns,
},
signalHandlers: sh,
terminationSignal: terminationSignal,
ioUsage: &usage.IO{},
limits: limits,
- mounts: mounts,
+ mounts: mntns,
}
tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg})
tg.timers = make(map[linux.TimerID]*IntervalTimer)
- tg.rscr.Store(&RSEQCriticalRegion{})
+ tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
return tg
}
-// saveRscr is invoked by stateify.
-func (tg *ThreadGroup) saveRscr() *RSEQCriticalRegion {
- return tg.rscr.Load().(*RSEQCriticalRegion)
+// saveOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) saveOldRSeqCritical() *OldRSeqCriticalRegion {
+ return tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
}
-// loadRscr is invoked by stateify.
-func (tg *ThreadGroup) loadRscr(rscr *RSEQCriticalRegion) {
- tg.rscr.Store(rscr)
+// loadOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) loadOldRSeqCritical(r *OldRSeqCriticalRegion) {
+ tg.oldRSeqCritical.Store(r)
}
// SignalHandlers returns the signal handlers used by tg.
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 8267929a6..bf2dabb6e 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -16,9 +16,9 @@ package kernel
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 31847e1df..4e4de0512 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -13,6 +13,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/sentry/context",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index 107394183..706de83ef 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -19,10 +19,10 @@ package time
import (
"fmt"
"math"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
index 76417342a..dc99301de 100644
--- a/pkg/sentry/kernel/timekeeper.go
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -16,7 +16,6 @@ package kernel
import (
"fmt"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/log"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Timekeeper manages all of the kernel clocks.
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
index 048de26dc..464d2306a 100644
--- a/pkg/sentry/kernel/tty.go
+++ b/pkg/sentry/kernel/tty.go
@@ -14,7 +14,7 @@
package kernel
-import "sync"
+import "gvisor.dev/gvisor/pkg/sync"
// TTY defines the relationship between a thread group and its controlling
// terminal.
diff --git a/pkg/sentry/kernel/uts_namespace.go b/pkg/sentry/kernel/uts_namespace.go
index 0a563e715..8ccf04bd1 100644
--- a/pkg/sentry/kernel/uts_namespace.go
+++ b/pkg/sentry/kernel/uts_namespace.go
@@ -15,9 +15,8 @@
package kernel
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
)
// UTSNamespace represents a UTS namespace, a holder of two system identifiers:
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
index 156e67bf8..9fa841e8b 100644
--- a/pkg/sentry/limits/BUILD
+++ b/pkg/sentry/limits/BUILD
@@ -15,6 +15,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/sentry/context",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/limits/limits.go b/pkg/sentry/limits/limits.go
index b6c22656b..31b9e9ff6 100644
--- a/pkg/sentry/limits/limits.go
+++ b/pkg/sentry/limits/limits.go
@@ -16,8 +16,9 @@
package limits
import (
- "sync"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// LimitType defines a type of resource limit.
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 839931f67..83e248431 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -118,7 +118,7 @@ go_library(
"//pkg/sentry/safemem",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip/buffer",
],
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 1b746d030..4b48866ad 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -15,8 +15,6 @@
package mm
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -25,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 58a5c186d..fa86ebced 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -35,8 +35,6 @@
package mm
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -44,7 +42,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// MemoryManager implements a virtual address space.
@@ -82,7 +80,7 @@ type MemoryManager struct {
users int32
// mappingMu is analogous to Linux's struct mm_struct::mmap_sem.
- mappingMu syncutil.DowngradableRWMutex `state:"nosave"`
+ mappingMu sync.DowngradableRWMutex `state:"nosave"`
// vmas stores virtual memory areas. Since vmas are stored by value,
// clients should usually use vmaIterator.ValuePtr() instead of
@@ -125,7 +123,7 @@ type MemoryManager struct {
// activeMu is loosely analogous to Linux's struct
// mm_struct::page_table_lock.
- activeMu syncutil.DowngradableRWMutex `state:"nosave"`
+ activeMu sync.DowngradableRWMutex `state:"nosave"`
// pmas stores platform mapping areas used to implement vmas. Since pmas
// are stored by value, clients should usually use pmaIterator.ValuePtr()
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
index 8c2246bb4..79610acb7 100644
--- a/pkg/sentry/mm/procfs.go
+++ b/pkg/sentry/mm/procfs.go
@@ -66,8 +66,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer
var start usermem.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
mm.appendVMAMapsEntryLocked(ctx, vseg, buf)
}
@@ -81,7 +79,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer
//
// Artifically adjust the seqfile handle so we only output vsyscall entry once.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
buf.WriteString(vsyscallMapsEntry)
}
}
@@ -97,8 +94,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile
start = *handle.(*usermem.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
vmaAddr := vseg.End()
data = append(data, seqfile.SeqData{
Buf: mm.vmaMapsEntryLocked(ctx, vseg),
@@ -116,7 +111,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile
//
// Artifically adjust the seqfile handle so we only output vsyscall entry once.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
vmaAddr := vsyscallEnd
data = append(data, seqfile.SeqData{
Buf: []byte(vsyscallMapsEntry),
@@ -187,15 +181,12 @@ func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffe
var start usermem.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf)
}
// We always emulate vsyscall, so advertise it here. See
// ReadMapsSeqFileData for additional commentary.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
buf.WriteString(vsyscallSmapsEntry)
}
}
@@ -211,8 +202,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
start = *handle.(*usermem.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
vmaAddr := vseg.End()
data = append(data, seqfile.SeqData{
Buf: mm.vmaSmapsEntryLocked(ctx, vseg),
@@ -223,7 +212,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
// We always emulate vsyscall, so advertise it here. See
// ReadMapsSeqFileData for additional commentary.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
vmaAddr := vsyscallEnd
data = append(data, seqfile.SeqData{
Buf: []byte(vsyscallSmapsEntry),
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index f404107af..a9a2642c5 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -73,6 +73,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/state",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index f7f7298c4..c99e023d9 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -25,7 +25,6 @@ import (
"fmt"
"math"
"os"
- "sync"
"sync/atomic"
"syscall"
"time"
@@ -37,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD
index b6d008dbe..85e882df9 100644
--- a/pkg/sentry/platform/interrupt/BUILD
+++ b/pkg/sentry/platform/interrupt/BUILD
@@ -10,6 +10,7 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt",
visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go
index a4651f500..57be41647 100644
--- a/pkg/sentry/platform/interrupt/interrupt.go
+++ b/pkg/sentry/platform/interrupt/interrupt.go
@@ -17,7 +17,8 @@ package interrupt
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// Receiver receives interrupt notifications from a Forwarder.
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index f3afd98da..6a358d1d4 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -55,6 +55,7 @@ go_library(
"//pkg/sentry/platform/safecopy",
"//pkg/sentry/time",
"//pkg/sentry/usermem",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index ea8b9632e..a25f3c449 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -15,13 +15,13 @@
package kvm
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// dirtySet tracks vCPUs for invalidation.
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index e5fac0d6a..2f02c03cf 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -17,8 +17,6 @@
package kvm
import (
- "unsafe"
-
"gvisor.dev/gvisor/pkg/sentry/arch"
)
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index f2c2c059e..a7850faed 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -18,13 +18,13 @@ package kvm
import (
"fmt"
"os"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// KVM represents a lightweight VM context.
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 7d02ebf19..e6d912168 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -17,7 +17,6 @@ package kvm
import (
"fmt"
"runtime"
- "sync"
"sync/atomic"
"syscall"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// machine contains state associated with the VM as a whole.
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index b99fe425e..873e39dc7 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -90,7 +90,9 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
// Clear from all PCIDs.
for _, c := range m.vCPUs {
- c.PCIDs.Drop(pt)
+ if c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
}
}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 7ae47f291..3b1f20219 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -97,7 +97,9 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
// Clear from all PCIDs.
for _, c := range m.vCPUs {
- c.PCIDs.Drop(pt)
+ if c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
}
}
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index 0df8cfa0f..cd13390c3 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -33,6 +33,7 @@ go_library(
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/platform/safecopy",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index 7b120a15d..bb0e03880 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -46,13 +46,13 @@ package ptrace
import (
"os"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/interrupt"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
var (
diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s
index 64c718d21..16f9c523e 100644
--- a/pkg/sentry/platform/ptrace/stub_amd64.s
+++ b/pkg/sentry/platform/ptrace/stub_amd64.s
@@ -64,6 +64,8 @@ begin:
CMPQ AX, $0
JL error
+ MOVQ $0, BX
+
// SIGSTOP to wait for attach.
//
// The SYSCALL instruction will be used for future syscall injection by
@@ -73,23 +75,26 @@ begin:
MOVQ $SIGSTOP, SI
SYSCALL
- // The tracer may "detach" and/or allow code execution here in three cases:
- //
- // 1. New (traced) stub threads are explicitly detached by the
- // goroutine in newSubprocess. However, they are detached while in
- // group-stop, so they do not execute code here.
- //
- // 2. If a tracer thread exits, it implicitly detaches from the stub,
- // potentially allowing code execution here. However, the Go runtime
- // never exits individual threads, so this case never occurs.
- //
- // 3. subprocess.createStub clones a new stub process that is untraced,
+ // The sentry sets BX to 1 when creating stub process.
+ CMPQ BX, $1
+ JE clone
+
+ // Notify the Sentry that syscall exited.
+done:
+ INT $3
+ // Be paranoid.
+ JMP done
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
// thus executing this code. We setup the PDEATHSIG before SIGSTOPing
// ourselves for attach by the tracer.
//
// R15 has been updated with the expected PPID.
- JMP begin
+ CMPQ AX, $0
+ JE begin
+ // The clone syscall returns a non-zero value.
+ JMP done
error:
// Exit with -errno.
MOVQ AX, DI
diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s
index 2c5e4d5cb..6162df02a 100644
--- a/pkg/sentry/platform/ptrace/stub_arm64.s
+++ b/pkg/sentry/platform/ptrace/stub_arm64.s
@@ -59,6 +59,8 @@ begin:
CMP $0x0, R0
BLT error
+ MOVD $0, R9
+
// SIGSTOP to wait for attach.
//
// The SYSCALL instruction will be used for future syscall injection by
@@ -66,22 +68,26 @@ begin:
MOVD $SYS_KILL, R8
MOVD $SIGSTOP, R1
SVC
- // The tracer may "detach" and/or allow code execution here in three cases:
- //
- // 1. New (traced) stub threads are explicitly detached by the
- // goroutine in newSubprocess. However, they are detached while in
- // group-stop, so they do not execute code here.
- //
- // 2. If a tracer thread exits, it implicitly detaches from the stub,
- // potentially allowing code execution here. However, the Go runtime
- // never exits individual threads, so this case never occurs.
- //
- // 3. subprocess.createStub clones a new stub process that is untraced,
+
+ // The sentry sets R9 to 1 when creating stub process.
+ CMP $1, R9
+ BEQ clone
+
+done:
+ // Notify the Sentry that syscall exited.
+ BRK $3
+ B done // Be paranoid.
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
// thus executing this code. We setup the PDEATHSIG before SIGSTOPing
// ourselves for attach by the tracer.
//
// R7 has been updated with the expected PPID.
- B begin
+ CMP $0, R0
+ BEQ begin
+
+ // The clone system call returned a non-zero value.
+ B done
error:
// Exit with -errno.
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index ddb1f41e3..15dc46a5b 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -18,14 +18,15 @@ import (
"fmt"
"os"
"runtime"
- "sync"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Linux kernel errnos which "should never be seen by user programs", but will
@@ -429,13 +430,15 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
}
for {
- // Execute the syscall instruction.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ // Execute the syscall instruction. The task has to stop on the
+ // trap instruction which is right after the syscall
+ // instruction.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_CONT, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
sig := t.wait(stopped)
- if sig == (syscallEvent | syscall.SIGTRAP) {
+ if sig == syscall.SIGTRAP {
// Reached syscall-enter-stop.
break
} else {
@@ -447,18 +450,6 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
}
}
- // Complete the actual system call.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
- panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
- }
-
- // Wait for syscall-exit-stop. "[Signal-delivery-stop] never happens
- // between syscall-enter-stop and syscall-exit-stop; it happens *after*
- // syscall-exit-stop.)" - ptrace(2), "Syscall-stops"
- if sig := t.wait(stopped); sig != (syscallEvent | syscall.SIGTRAP) {
- t.dumpAndPanic(fmt.Sprintf("wait failed: expected SIGTRAP, got %v [%d]", sig, sig))
- }
-
// Grab registers.
if err := t.getRegs(regs); err != nil {
panic(fmt.Sprintf("ptrace get regs failed: %v", err))
@@ -541,14 +532,14 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
if isSingleStepping(regs) {
if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
- syscall.PTRACE_SYSEMU_SINGLESTEP,
+ unix.PTRACE_SYSEMU_SINGLESTEP,
uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
} else {
if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
- syscall.PTRACE_SYSEMU,
+ unix.PTRACE_SYSEMU,
uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 606dc2b1d..e99798c56 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -141,9 +141,11 @@ func (t *thread) adjustInitRegsRip() {
t.initRegs.Rip -= initRegsRipAdjustment
}
-// Pass the expected PPID to the child via R15 when creating stub process
+// Pass the expected PPID to the child via R15 when creating stub process.
func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
initregs.R15 = uint64(ppid)
+ // Rbx has to be set to 1 when creating stub process.
+ initregs.Rbx = 1
}
// patchSignalInfo patches the signal info to account for hitting the seccomp
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index 62a686ee7..7b975137f 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -127,6 +127,8 @@ func (t *thread) adjustInitRegsRip() {
// Pass the expected PPID to the child via X7 when creating stub process
func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
initregs.Regs[7] = uint64(ppid)
+ // R9 has to be set to 1 when creating stub process.
+ initregs.Regs[9] = 1
}
// patchSignalInfo patches the signal info to account for hitting the seccomp
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index cf13ea5e4..74968dfdf 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -54,7 +54,7 @@ func probeSeccomp() bool {
for {
// Attempt an emulation.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
index 2e6fbe488..245b20722 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
@@ -18,7 +18,6 @@
package ptrace
import (
- "sync"
"sync/atomic"
"syscall"
"unsafe"
@@ -26,6 +25,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
+ "gvisor.dev/gvisor/pkg/sync"
)
// maskPool contains reusable CPU masks for setting affinity. Unfortunately,
diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go
index 3f094c2a7..86fd5ed58 100644
--- a/pkg/sentry/platform/ring0/defs.go
+++ b/pkg/sentry/platform/ring0/defs.go
@@ -17,7 +17,7 @@ package ring0
import (
"syscall"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
)
// Kernel is a global kernel object.
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 10dbd381f..9dae0dccb 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -18,6 +18,7 @@ package ring0
import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
var (
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
index dc0eeec01..a850ce6cf 100644
--- a/pkg/sentry/platform/ring0/defs_arm64.go
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -18,6 +18,7 @@ package ring0
import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
var (
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index add2c3e08..da07815ff 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -357,6 +357,73 @@ TEXT ·Current(SB),NOSPLIT,$0-8
#define STACK_FRAME_SIZE 16
TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
+ // Step1, save sentry context into memory.
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
+ MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+
+ WORD $0xd5384003 // MRS SPSR_EL1, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG)
+ MOVD R30, CPU_REGISTERS+PTRACE_PC(RSV_REG)
+ MOVD RSP, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_SP(RSV_REG)
+
+ MOVD CPU_REGISTERS+PTRACE_R3(RSV_REG), R3
+
+ // Step2, save SP_EL1, PSTATE into kernel temporary stack.
+ // switch to temporary stack.
+ LOAD_KERNEL_STACK(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ SUB $STACK_FRAME_SIZE, RSP, RSP
+ MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R11
+ MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R12
+ STP (R11, R12), 16*0(RSP)
+
+ MOVD CPU_REGISTERS+PTRACE_R11(RSV_REG), R11
+ MOVD CPU_REGISTERS+PTRACE_R12(RSV_REG), R12
+
+ // Step3, test user pagetable.
+ // If user pagetable is empty, trapped in el1_ia.
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ SWITCH_TO_APP_PAGETABLE(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ SWITCH_TO_KVM_PAGETABLE(RSV_REG)
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ // If pagetable is not empty, recovery kernel temporary stack.
+ ADD $STACK_FRAME_SIZE, RSP, RSP
+
+ // Step4, load app context pointer.
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP
+
+ // Step5, prepare the environment for container application.
+ // set sp_el0.
+ MOVD PTRACE_SP(RSV_REG_APP), R1
+ WORD $0xd5184101 //MSR R1, SP_EL0
+ // set pc.
+ MOVD PTRACE_PC(RSV_REG_APP), R1
+ MSR R1, ELR_EL1
+ // set pstate.
+ MOVD PTRACE_PSTATE(RSV_REG_APP), R1
+ WORD $0xd5184001 //MSR R1, SPSR_EL1
+
+ // RSV_REG & RSV_REG_APP will be loaded at the end.
+ REGISTERS_LOAD(RSV_REG_APP, 0)
+
+ // switch to user pagetable.
+ MOVD PTRACE_R18(RSV_REG_APP), RSV_REG
+ MOVD PTRACE_R9(RSV_REG_APP), RSV_REG_APP
+
+ SUB $STACK_FRAME_SIZE, RSP, RSP
+ STP (RSV_REG, RSV_REG_APP), 16*0(RSP)
+
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ SWITCH_TO_APP_PAGETABLE(RSV_REG)
+
+ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP)
+ ADD $STACK_FRAME_SIZE, RSP, RSP
+
ERET()
TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
@@ -407,6 +474,16 @@ TEXT ·El1_sync(SB),NOSPLIT,$0
B el1_invalid
el1_da:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $0, CPU_ERROR_TYPE(RSV_REG)
+
+ MOVD $PageFault, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
B ·Halt(SB)
el1_ia:
@@ -554,6 +631,10 @@ TEXT ·Vectors(SB),NOSPLIT,$0
B ·El0_error_invalid(SB)
nop31Instructions()
+ // The exception-vector-table is required to be 11-bits aligned.
+ // Please see Linux source code as reference: arch/arm64/kernel/entry.s.
+ // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs.
+ // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
WORD $0xd503201f //nop
nop31Instructions()
WORD $0xd503201f
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index 1c9171004..0e6a6235b 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "funcdata.h"
+#include "textflag.h"
+
TEXT ·CPACREL1(SB),NOSPLIT,$0-8
WORD $0xd5381041 // MRS CPACR_EL1, R1
MOVD R1, ret+0(FP)
RET
-TEXT ·FPCR(SB),NOSPLIT,$0-8
+TEXT ·GetFPCR(SB),NOSPLIT,$0-8
WORD $0xd53b4201 // MRS NZCV, R1
MOVD R1, ret+0(FP)
RET
@@ -27,7 +30,7 @@ TEXT ·GetFPSR(SB),NOSPLIT,$0-8
MOVD R1, ret+0(FP)
RET
-TEXT ·FPCR(SB),NOSPLIT,$0-8
+TEXT ·SetFPCR(SB),NOSPLIT,$0-8
MOVD addr+0(FP), R1
WORD $0xd51b4201 // MSR R1, NZCV
RET
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index e2e15ba5c..387a7f6c3 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -96,7 +96,10 @@ go_library(
"//pkg/sentry/platform/kvm:__subpackages__",
"//pkg/sentry/platform/ring0:__subpackages__",
],
- deps = ["//pkg/sentry/usermem"],
+ deps = [
+ "//pkg/sentry/usermem",
+ "//pkg/sync",
+ ],
)
go_test(
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
index 0f029f25d..e199bae18 100644
--- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
@@ -17,7 +17,7 @@
package pagetables
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// limitPCID is the number of valid PCIDs.
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index af1a4e95f..1684dfc24 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -327,7 +327,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
}
// PackTOS packs an IP_TOS socket control message.
-func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte {
+func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
@@ -471,6 +471,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_IP:
switch h.Type {
case linux.IP_TOS:
+ if length < linux.SizeOfControlMessageTOS {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
cmsgs.IP.HasTOS = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
i += AlignUp(length, width)
@@ -481,6 +484,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_IPV6:
switch h.Type {
case linux.IPV6_TCLASS:
+ if length < linux.SizeOfControlMessageTClass {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
cmsgs.IP.HasTClass = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += AlignUp(length, width)
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 5eb06bbf4..b70047d81 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -14,6 +14,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/log",
"//pkg/sentry/kernel",
"//pkg/sentry/usermem",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 9f87c32f1..a9cfc1749 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
@@ -35,6 +36,7 @@ const errorTargetName = "ERROR"
// metadata is opaque to netstack. It holds data that we need to translate
// between Linux's and netstack's iptables representations.
+// TODO(gvisor.dev/issue/170): This might be removable.
type metadata struct {
HookEntry [linux.NF_INET_NUMHOOKS]uint32
Underflow [linux.NF_INET_NUMHOOKS]uint32
@@ -51,7 +53,7 @@ func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTG
}
// Find the appropriate table.
- table, err := findTable(ep, info.TableName())
+ table, err := findTable(ep, info.Name)
if err != nil {
return linux.IPTGetinfo{}, err
}
@@ -82,30 +84,31 @@ func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen i
}
// Find the appropriate table.
- table, err := findTable(ep, userEntries.TableName())
+ table, err := findTable(ep, userEntries.Name)
if err != nil {
return linux.KernelIPTGetEntries{}, err
}
// Convert netstack's iptables rules to something that the iptables
// tool can understand.
- entries, _, err := convertNetstackToBinary(userEntries.TableName(), table)
+ entries, _, err := convertNetstackToBinary(userEntries.Name.String(), table)
if err != nil {
return linux.KernelIPTGetEntries{}, err
}
if binary.Size(entries) > uintptr(outLen) {
+ log.Warningf("Insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
return entries, nil
}
-func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) {
+func findTable(ep tcpip.Endpoint, tablename linux.TableName) (iptables.Table, *syserr.Error) {
ipt, err := ep.IPTables()
if err != nil {
return iptables.Table{}, syserr.FromError(err)
}
- table, ok := ipt.Tables[tableName]
+ table, ok := ipt.Tables[tablename.String()]
if !ok {
return iptables.Table{}, syserr.ErrInvalidArgument
}
@@ -135,110 +138,68 @@ func FillDefaultIPTables(stack *stack.Stack) {
// format expected by the iptables tool. Linux stores each table as a binary
// blob that can only be traversed by parsing a bit, reading some offsets,
// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
+func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
// Return values.
var entries linux.KernelIPTGetEntries
var meta metadata
// The table name has to fit in the struct.
- if linux.XT_TABLE_MAXNAMELEN < len(name) {
+ if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
+ log.Warningf("Table name %q too long.", tablename)
return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
}
- copy(entries.Name[:], name)
-
- // Deal with the built in chains first (INPUT, OUTPUT, etc.). Each of
- // these chains ends with an unconditional policy entry.
- for hook := iptables.Prerouting; hook < iptables.NumHooks; hook++ {
- chain, ok := table.BuiltinChains[hook]
- if !ok {
- // This table doesn't support this hook.
- continue
- }
-
- // Sanity check.
- if len(chain.Rules) < 1 {
- return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
- }
+ copy(entries.Name[:], tablename)
- for ruleIdx, rule := range chain.Rules {
- // If this is the first rule of a builtin chain, set
- // the metadata hook entry point.
- if ruleIdx == 0 {
+ for ruleIdx, rule := range table.Rules {
+ // Is this a chain entry point?
+ for hook, hookRuleIdx := range table.BuiltinChains {
+ if hookRuleIdx == ruleIdx {
meta.HookEntry[hook] = entries.Size
}
-
- // Each rule corresponds to an entry.
- entry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
- NextOffset: linux.SizeOfIPTEntry,
- TargetOffset: linux.SizeOfIPTEntry,
- },
+ }
+ // Is this a chain underflow point?
+ for underflow, underflowRuleIdx := range table.Underflows {
+ if underflowRuleIdx == ruleIdx {
+ meta.Underflow[underflow] = entries.Size
}
+ }
- for _, matcher := range rule.Matchers {
- // Serialize the matcher and add it to the
- // entry.
- serialized := marshalMatcher(matcher)
- entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
- entry.TargetOffset += uint16(len(serialized))
- }
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
- // Serialize and append the target.
- serialized := marshalTarget(rule.Target)
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
entry.Elems = append(entry.Elems, serialized...)
entry.NextOffset += uint16(len(serialized))
-
- // The underflow rule is the last rule in the chain,
- // and is an unconditional rule (i.e. it matches any
- // packet). This is enforced when saving iptables.
- if ruleIdx == len(chain.Rules)-1 {
- meta.Underflow[hook] = entries.Size
- }
-
- entries.Size += uint32(entry.NextOffset)
- entries.Entrytable = append(entries.Entrytable, entry)
- meta.NumEntries++
+ entry.TargetOffset += uint16(len(serialized))
}
- }
-
- // TODO(gvisor.dev/issue/170): Deal with the user chains here. Each of
- // these starts with an error node holding the chain's name and ends
- // with an unconditional return.
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
- // Lastly, each table ends with an unconditional error target rule as
- // its final entry.
- errorEntry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
- NextOffset: linux.SizeOfIPTEntry,
- TargetOffset: linux.SizeOfIPTEntry,
- },
+ entries.Size += uint32(entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ meta.NumEntries++
}
- var errorTarget linux.XTErrorTarget
- errorTarget.Target.TargetSize = linux.SizeOfXTErrorTarget
- copy(errorTarget.ErrorName[:], errorTargetName)
- copy(errorTarget.Target.Name[:], errorTargetName)
-
- // Serialize and add it to the list of entries.
- errorTargetBuf := make([]byte, 0, linux.SizeOfXTErrorTarget)
- serializedErrorTarget := binary.Marshal(errorTargetBuf, usermem.ByteOrder, errorTarget)
- errorEntry.Elems = append(errorEntry.Elems, serializedErrorTarget...)
- errorEntry.NextOffset += uint16(len(serializedErrorTarget))
-
- entries.Size += uint32(errorEntry.NextOffset)
- entries.Entrytable = append(entries.Entrytable, errorEntry)
- meta.NumEntries++
- meta.Size = entries.Size
+ meta.Size = entries.Size
return entries, meta, nil
}
func marshalMatcher(matcher iptables.Matcher) []byte {
switch matcher.(type) {
default:
- // TODO(gvisor.dev/issue/170): We don't support any matchers yet, so
- // any call to marshalMatcher will panic.
+ // TODO(gvisor.dev/issue/170): We don't support any matchers
+ // yet, so any call to marshalMatcher will panic.
panic(fmt.Errorf("unknown matcher of type %T", matcher))
}
}
@@ -246,28 +207,46 @@ func marshalMatcher(matcher iptables.Matcher) []byte {
func marshalTarget(target iptables.Target) []byte {
switch target.(type) {
case iptables.UnconditionalAcceptTarget:
- return marshalUnconditionalAcceptTarget()
+ return marshalStandardTarget(iptables.Accept)
+ case iptables.UnconditionalDropTarget:
+ return marshalStandardTarget(iptables.Drop)
+ case iptables.ErrorTarget:
+ return marshalErrorTarget()
default:
panic(fmt.Errorf("unknown target of type %T", target))
}
}
-func marshalUnconditionalAcceptTarget() []byte {
+func marshalStandardTarget(verdict iptables.Verdict) []byte {
// The target's name will be the empty string.
target := linux.XTStandardTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTStandardTarget,
},
- Verdict: translateStandardVerdict(iptables.Accept),
+ Verdict: translateFromStandardVerdict(verdict),
}
ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
return binary.Marshal(ret, usermem.ByteOrder, target)
}
-// translateStandardVerdict translates verdicts the same way as the iptables
+func marshalErrorTarget() []byte {
+ // This is an error target named error
+ target := linux.XTErrorTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTErrorTarget,
+ },
+ }
+ copy(target.Name[:], errorTargetName)
+ copy(target.Target.Name[:], errorTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+// translateFromStandardVerdict translates verdicts the same way as the iptables
// tool.
-func translateStandardVerdict(verdict iptables.Verdict) int32 {
+func translateFromStandardVerdict(verdict iptables.Verdict) int32 {
switch verdict {
case iptables.Accept:
return -linux.NF_ACCEPT - 1
@@ -280,7 +259,258 @@ func translateStandardVerdict(verdict iptables.Verdict) int32 {
case iptables.Jump:
// TODO(gvisor.dev/issue/170): Support Jump.
panic("Jump isn't supported yet")
+ }
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+}
+
+// translateToStandardVerdict translates from the value in a
+// linux.XTStandardTarget to an iptables.Verdict.
+func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) {
+ // TODO(gvisor.dev/issue/170): Support other verdicts.
+ switch val {
+ case -linux.NF_ACCEPT - 1:
+ return iptables.Accept, nil
+ case -linux.NF_DROP - 1:
+ return iptables.Drop, nil
+ case -linux.NF_QUEUE - 1:
+ log.Warningf("Unsupported iptables verdict QUEUE.")
+ case linux.NF_RETURN:
+ log.Warningf("Unsupported iptables verdict RETURN.")
+ default:
+ log.Warningf("Unknown iptables verdict %d.", val)
+ }
+ return iptables.Invalid, syserr.ErrInvalidArgument
+}
+
+// SetEntries sets iptables rules for a single table. See
+// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
+func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
+ printReplace(optVal)
+
+ // Get the basic rules data (struct ipt_replace).
+ if len(optVal) < linux.SizeOfIPTReplace {
+ log.Warningf("netfilter.SetEntries: optVal has insufficient size for replace %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var replace linux.IPTReplace
+ replaceBuf := optVal[:linux.SizeOfIPTReplace]
+ optVal = optVal[linux.SizeOfIPTReplace:]
+ binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
+
+ // TODO(gvisor.dev/issue/170): Support other tables.
+ var table iptables.Table
+ switch replace.Name.String() {
+ case iptables.TablenameFilter:
+ table = iptables.EmptyFilterTable()
default:
- panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+ log.Warningf("We don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
+ return syserr.ErrInvalidArgument
+ }
+
+ // Convert input into a list of rules and their offsets.
+ var offset uint32
+ var offsets []uint32
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ // Get the struct ipt_entry.
+ if len(optVal) < linux.SizeOfIPTEntry {
+ log.Warningf("netfilter: optVal has insufficient size for entry %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var entry linux.IPTEntry
+ buf := optVal[:linux.SizeOfIPTEntry]
+ optVal = optVal[linux.SizeOfIPTEntry:]
+ binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ if entry.TargetOffset != linux.SizeOfIPTEntry {
+ // TODO(gvisor.dev/issue/170): Support matchers.
+ return syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): We should support IPTIP
+ // filtering. We reject any nonzero IPTIP values for now.
+ emptyIPTIP := linux.IPTIP{}
+ if entry.IP != emptyIPTIP {
+ log.Warningf("netfilter: non-empty struct iptip found")
+ return syserr.ErrInvalidArgument
+ }
+
+ // Get the target of the rule.
+ target, consumed, err := parseTarget(optVal)
+ if err != nil {
+ return err
+ }
+ optVal = optVal[consumed:]
+
+ table.Rules = append(table.Rules, iptables.Rule{Target: target})
+ offsets = append(offsets, offset)
+ offset += linux.SizeOfIPTEntry + consumed
+ }
+
+ // Go through the list of supported hooks for this table and, for each
+ // one, set the rule it corresponds to.
+ for hook, _ := range replace.HookEntry {
+ if table.ValidHooks()&uint32(hook) != 0 {
+ hk := hookFromLinux(hook)
+ for ruleIdx, offset := range offsets {
+ if offset == replace.HookEntry[hook] {
+ table.BuiltinChains[hk] = ruleIdx
+ }
+ if offset == replace.Underflow[hook] {
+ table.Underflows[hk] = ruleIdx
+ }
+ }
+ if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset {
+ log.Warningf("Hook %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset {
+ log.Warningf("Underflow %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+
+ ipt := stack.IPTables()
+ table.SetMetadata(metadata{
+ HookEntry: replace.HookEntry,
+ Underflow: replace.Underflow,
+ NumEntries: replace.NumEntries,
+ Size: replace.Size,
+ })
+ ipt.Tables[replace.Name.String()] = table
+ stack.SetIPTables(ipt)
+
+ return nil
+}
+
+// parseTarget parses a target from the start of optVal and returns the target
+// along with the number of bytes it occupies in optVal.
+func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) {
+ if len(optVal) < linux.SizeOfXTEntryTarget {
+ log.Warningf("netfilter: optVal has insufficient size for entry target %d", len(optVal))
+ return nil, 0, syserr.ErrInvalidArgument
+ }
+ var target linux.XTEntryTarget
+ buf := optVal[:linux.SizeOfXTEntryTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &target)
+ switch target.Name.String() {
+ case "":
+ // Standard target.
+ if len(optVal) < linux.SizeOfXTStandardTarget {
+ log.Warningf("netfilter.SetEntries: optVal has insufficient size for standard target %d", len(optVal))
+ return nil, 0, syserr.ErrInvalidArgument
+ }
+ var standardTarget linux.XTStandardTarget
+ buf = optVal[:linux.SizeOfXTStandardTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
+
+ verdict, err := translateToStandardVerdict(standardTarget.Verdict)
+ if err != nil {
+ return nil, 0, err
+ }
+ switch verdict {
+ case iptables.Accept:
+ return iptables.UnconditionalAcceptTarget{}, linux.SizeOfXTStandardTarget, nil
+ case iptables.Drop:
+ // TODO(gvisor.dev/issue/170): Return an
+ // iptables.UnconditionalDropTarget to support DROP.
+ log.Infof("netfilter DROP is not supported yet.")
+ return nil, 0, syserr.ErrInvalidArgument
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %v", verdict))
+ }
+
+ case errorTargetName:
+ // Error target.
+ if len(optVal) < linux.SizeOfXTErrorTarget {
+ log.Infof("netfilter.SetEntries: optVal has insufficient size for error target %d", len(optVal))
+ return nil, 0, syserr.ErrInvalidArgument
+ }
+ var errorTarget linux.XTErrorTarget
+ buf = optVal[:linux.SizeOfXTErrorTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+
+ // Error targets are used in 2 cases:
+ // * An actual error case. These rules have an error
+ // named errorTargetName. The last entry of the table
+ // is usually an error case to catch any packets that
+ // somehow fall through every rule.
+ // * To mark the start of a user defined chain. These
+ // rules have an error with the name of the chain.
+ switch errorTarget.Name.String() {
+ case errorTargetName:
+ return iptables.ErrorTarget{}, linux.SizeOfXTErrorTarget, nil
+ default:
+ log.Infof("Unknown error target %q doesn't exist or isn't supported yet.", errorTarget.Name.String())
+ return nil, 0, syserr.ErrInvalidArgument
+ }
+ }
+
+ // Unknown target.
+ log.Infof("Unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
+ return nil, 0, syserr.ErrInvalidArgument
+}
+
+func hookFromLinux(hook int) iptables.Hook {
+ switch hook {
+ case linux.NF_INET_PRE_ROUTING:
+ return iptables.Prerouting
+ case linux.NF_INET_LOCAL_IN:
+ return iptables.Input
+ case linux.NF_INET_FORWARD:
+ return iptables.Forward
+ case linux.NF_INET_LOCAL_OUT:
+ return iptables.Output
+ case linux.NF_INET_POST_ROUTING:
+ return iptables.Postrouting
+ }
+ panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
+}
+
+// printReplace prints information about the struct ipt_replace in optVal. It
+// is only for debugging.
+func printReplace(optVal []byte) {
+ // Basic replace info.
+ var replace linux.IPTReplace
+ replaceBuf := optVal[:linux.SizeOfIPTReplace]
+ optVal = optVal[linux.SizeOfIPTReplace:]
+ binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
+ log.Infof("Replacing table %q: %+v", replace.Name.String(), replace)
+
+ // Read in the list of entries at the end of replace.
+ var totalOffset uint16
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ var entry linux.IPTEntry
+ entryBuf := optVal[:linux.SizeOfIPTEntry]
+ binary.Unmarshal(entryBuf, usermem.ByteOrder, &entry)
+ log.Infof("Entry %d (total offset %d): %+v", entryIdx, totalOffset, entry)
+
+ totalOffset += entry.NextOffset
+ if entry.TargetOffset == linux.SizeOfIPTEntry {
+ log.Infof("Entry has no matches.")
+ } else {
+ log.Infof("Entry has matches.")
+ }
+
+ var target linux.XTEntryTarget
+ targetBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTEntryTarget]
+ binary.Unmarshal(targetBuf, usermem.ByteOrder, &target)
+ log.Infof("Target named %q: %+v", target.Name.String(), target)
+
+ switch target.Name.String() {
+ case "":
+ var standardTarget linux.XTStandardTarget
+ stBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTStandardTarget]
+ binary.Unmarshal(stBuf, usermem.ByteOrder, &standardTarget)
+ log.Infof("Standard target with verdict %q (%d).", linux.VerdictStrings[standardTarget.Verdict], standardTarget.Verdict)
+ case errorTargetName:
+ var errorTarget linux.XTErrorTarget
+ etBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTErrorTarget]
+ binary.Unmarshal(etBuf, usermem.ByteOrder, &errorTarget)
+ log.Infof("Error target with name %q.", errorTarget.Name.String())
+ default:
+ log.Infof("Unknown target type.")
+ }
+
+ optVal = optVal[entry.NextOffset:]
}
}
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 79589e3c8..103933144 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -22,12 +22,12 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 463544c1a..2d9f4ba9b 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -8,6 +8,7 @@ go_library(
srcs = ["port.go"],
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port",
visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go
index e9d3275b1..2cd3afc22 100644
--- a/pkg/sentry/socket/netlink/port/port.go
+++ b/pkg/sentry/socket/netlink/port/port.go
@@ -24,7 +24,8 @@ import (
"fmt"
"math"
"math/rand"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// maxPorts is a sanity limit on the maximum number of ports to allocate per
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 4a1b87a9a..cea56f4ed 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -17,7 +17,6 @@ package netlink
import (
"math"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
@@ -29,12 +28,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -500,29 +499,29 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
trunc := flags&linux.MSG_TRUNC != 0
r := unix.EndpointReader{
+ Ctx: t,
Endpoint: s.ep,
Peek: flags&linux.MSG_PEEK != 0,
}
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
// If MSG_TRUNC is set with a zero byte destination then we still need
// to read the message and discard it, or in the case where MSG_PEEK is
// set, leave it be. In both cases the full message length must be
- // returned. However, the memory manager for the destination will not read
- // the endpoint if the destination is zero length.
- //
- // In order for the endpoint to be read when the destination size is zero,
- // we must cause a read of the endpoint by using a separate fake zero
- // length block sequence and calling the EndpointReader directly.
+ // returned.
if trunc && dst.Addrs.NumBytes() == 0 {
- // Perform a read to a zero byte block sequence. We can ignore the
- // original destination since it was zero bytes. The length returned by
- // ReadToBlocks is ignored and we return the full message length to comply
- // with MSG_TRUNC.
- _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0))))
- return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
}
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
@@ -540,7 +539,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index e414d8055..f78784569 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/sentry/socket/netfilter",
"//pkg/sentry/unimpl",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 140851c17..d2f7e987d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -29,7 +29,6 @@ import (
"io"
"math"
"reflect"
- "sync"
"syscall"
"time"
@@ -49,6 +48,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -222,17 +222,25 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
+ // transport.Endpoint.SetSockOptBool.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
+ // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
}
// SocketOperations encapsulates all the state needed to represent a network stack
@@ -316,22 +324,15 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
// AF_INET6, and AF_PACKET addresses.
//
-// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
-//
// AddressAndFamily returns an address and its family.
-func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
}
- family := usermem.ByteOrder.Uint16(addr)
- if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) {
- return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
- }
-
// Get the rest of the fields based on the address family.
- switch family {
+ switch family := usermem.ByteOrder.Uint16(addr); family {
case linux.AF_UNIX:
path := addr[2:]
if len(path) > linux.UnixPathMax {
@@ -630,10 +631,40 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
return r
}
+func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error {
+ if family == uint16(s.family) {
+ return nil
+ }
+ if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
+ v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+ if !v {
+ return nil
+ }
+ }
+ return syserr.ErrInvalidArgument
+}
+
+// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the
+// receiver's family is AF_INET6.
+//
+// This is a hack to work around the fact that both IPv4 and IPv6 ANY are
+// represented by the empty string.
+//
+// TODO(gvisor.dev/issues/1556): remove this function.
+func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress {
+ if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET {
+ addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ }
+ return addr
+}
+
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */)
+ addr, family, err := AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -645,6 +676,12 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
}
return syserr.TranslateNetstackError(err)
}
+
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return err
+ }
+ addr = s.mapFamily(addr, family)
+
// Always return right away in the non-blocking case.
if !blocking {
return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
@@ -673,10 +710,14 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */)
+ addr, family, err := AddressAndFamily(sockaddr)
if err != nil {
return err
}
+ if err := s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+ addr = s.mapFamily(addr, family)
// Issue the bind request to the endpoint.
return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
@@ -977,13 +1018,23 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- if len(v) == 0 {
+ if v == 0 {
return []byte{}, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
}
- return append([]byte(v), 0), nil
+ s := t.NetworkContext()
+ if s == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ nic, ok := s.Interfaces()[int32(v)]
+ if !ok {
+ // The NICID no longer indicates a valid interface, probably because that
+ // interface was removed.
+ return nil, syserr.ErrUnknownDevice
+ }
+ return append([]byte(nic.Name), 0), nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1213,12 +1264,15 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.V6OnlyOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1323,6 +1377,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
return int32(v), nil
+ case linux.IP_RECVTOS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1356,6 +1425,26 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
return nil
}
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIPTReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
+
+ case linux.IPT_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+ }
+ }
+
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
@@ -1427,7 +1516,20 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
if n == -1 {
n = len(optVal)
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+ name := string(optVal[:n])
+ if name == "" {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return syserr.ErrNoDevice
+ }
+ for nicID, nic := range s.Interfaces() {
+ if nic.Name == name {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ }
+ }
+ return syserr.ErrUnknownDevice
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
@@ -1621,7 +1723,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
case linux.IPV6_ADD_MEMBERSHIP,
linux.IPV6_DROP_MEMBERSHIP,
@@ -1808,6 +1910,13 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v)))
+ case linux.IP_RECVTOS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -1828,7 +1937,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
linux.IP_RECVORIGDSTADDR,
- linux.IP_RECVTOS,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2026,8 +2134,8 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
case linux.AF_INET6:
var out linux.SockAddrInet6
- if len(addr.Addr) == 4 {
- // Copy address is v4-mapped format.
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
copy(out.Addr[12:], addr.Addr)
out.Addr[10] = 0xff
out.Addr[11] = 0xff
@@ -2248,7 +2356,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
}
func (s *SocketOperations) controlMessages() socket.ControlMessages {
- return socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp}}
+ return socket.ControlMessages{
+ IP: tcpip.ControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ },
+ }
}
// updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after
@@ -2341,10 +2456,14 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */)
+ addrBuf, family, err := AddressAndFamily(to)
if err != nil {
return 0, err
}
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return 0, err
+ }
+ addrBuf = s.mapFamily(addrBuf, family)
addr = &addrBuf
}
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
deleted file mode 100644
index 4668b87d1..000000000
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ /dev/null
@@ -1,69 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "rpcinet",
- srcs = [
- "device.go",
- "rpcinet.go",
- "socket.go",
- "stack.go",
- "stack_unsafe.go",
- ],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet",
- visibility = ["//pkg/sentry:internal"],
- deps = [
- ":syscall_rpc_go_proto",
- "//pkg/abi/linux",
- "//pkg/binary",
- "//pkg/sentry/arch",
- "//pkg/sentry/context",
- "//pkg/sentry/device",
- "//pkg/sentry/fs",
- "//pkg/sentry/fs/fsutil",
- "//pkg/sentry/inet",
- "//pkg/sentry/kernel",
- "//pkg/sentry/kernel/time",
- "//pkg/sentry/socket",
- "//pkg/sentry/socket/hostinet",
- "//pkg/sentry/socket/rpcinet/conn",
- "//pkg/sentry/socket/rpcinet/notifier",
- "//pkg/sentry/unimpl",
- "//pkg/sentry/usermem",
- "//pkg/syserr",
- "//pkg/syserror",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/stack",
- "//pkg/unet",
- "//pkg/waiter",
- ],
-)
-
-proto_library(
- name = "syscall_rpc_proto",
- srcs = ["syscall_rpc.proto"],
- visibility = [
- "//visibility:public",
- ],
-)
-
-cc_proto_library(
- name = "syscall_rpc_cc_proto",
- visibility = [
- "//visibility:public",
- ],
- deps = [":syscall_rpc_proto"],
-)
-
-go_proto_library(
- name = "syscall_rpc_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
- proto = ":syscall_rpc_proto",
- visibility = [
- "//visibility:public",
- ],
-)
diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD
deleted file mode 100644
index 23eadcb1b..000000000
--- a/pkg/sentry/socket/rpcinet/conn/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "conn",
- srcs = ["conn.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn",
- visibility = ["//pkg/sentry:internal"],
- deps = [
- "//pkg/binary",
- "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
- "//pkg/syserr",
- "//pkg/unet",
- "@com_github_golang_protobuf//proto:go_default_library",
- ],
-)
diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go
deleted file mode 100644
index 356adad99..000000000
--- a/pkg/sentry/socket/rpcinet/conn/conn.go
+++ /dev/null
@@ -1,187 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package conn is an RPC connection to a syscall RPC server.
-package conn
-
-import (
- "fmt"
- "sync"
- "sync/atomic"
- "syscall"
-
- "github.com/golang/protobuf/proto"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/unet"
-
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
-)
-
-type request struct {
- response []byte
- ready chan struct{}
- ignoreResult bool
-}
-
-// RPCConnection represents a single RPC connection to a syscall gofer.
-type RPCConnection struct {
- // reqID is the ID of the last request and must be accessed atomically.
- reqID uint64
-
- sendMu sync.Mutex
- socket *unet.Socket
-
- reqMu sync.Mutex
- requests map[uint64]request
-}
-
-// NewRPCConnection initializes a RPC connection to a socket gofer.
-func NewRPCConnection(s *unet.Socket) *RPCConnection {
- conn := &RPCConnection{socket: s, requests: map[uint64]request{}}
- go func() { // S/R-FIXME(b/77962828)
- var nums [16]byte
- for {
- for n := 0; n < len(nums); {
- nn, err := conn.socket.Read(nums[n:])
- if err != nil {
- panic(fmt.Sprint("error reading length from socket rpc gofer: ", err))
- }
- n += nn
- }
-
- b := make([]byte, binary.LittleEndian.Uint64(nums[:8]))
- id := binary.LittleEndian.Uint64(nums[8:])
-
- for n := 0; n < len(b); {
- nn, err := conn.socket.Read(b[n:])
- if err != nil {
- panic(fmt.Sprint("error reading request from socket rpc gofer: ", err))
- }
- n += nn
- }
-
- conn.reqMu.Lock()
- r := conn.requests[id]
- if r.ignoreResult {
- delete(conn.requests, id)
- } else {
- r.response = b
- conn.requests[id] = r
- }
- conn.reqMu.Unlock()
- close(r.ready)
- }
- }()
- return conn
-}
-
-// NewRequest makes a request to the RPC gofer and returns the request ID and a
-// channel which will be closed once the request completes.
-func (c *RPCConnection) NewRequest(req pb.SyscallRequest, ignoreResult bool) (uint64, chan struct{}) {
- b, err := proto.Marshal(&req)
- if err != nil {
- panic(fmt.Sprint("invalid proto: ", err))
- }
-
- id := atomic.AddUint64(&c.reqID, 1)
- ch := make(chan struct{})
-
- c.reqMu.Lock()
- c.requests[id] = request{ready: ch, ignoreResult: ignoreResult}
- c.reqMu.Unlock()
-
- c.sendMu.Lock()
- defer c.sendMu.Unlock()
-
- var nums [16]byte
- binary.LittleEndian.PutUint64(nums[:8], uint64(len(b)))
- binary.LittleEndian.PutUint64(nums[8:], id)
- for n := 0; n < len(nums); {
- nn, err := c.socket.Write(nums[n:])
- if err != nil {
- panic(fmt.Sprint("error writing length and ID to socket gofer: ", err))
- }
- n += nn
- }
-
- for n := 0; n < len(b); {
- nn, err := c.socket.Write(b[n:])
- if err != nil {
- panic(fmt.Sprint("error writing request to socket gofer: ", err))
- }
- n += nn
- }
-
- return id, ch
-}
-
-// RPCReadFile will execute the ReadFile helper RPC method which avoids the
-// common pattern of open(2), read(2), close(2) by doing all three operations
-// as a single RPC. It will read the entire file or return EFBIG if the file
-// was too large.
-func (c *RPCConnection) RPCReadFile(path string) ([]byte, *syserr.Error) {
- req := &pb.SyscallRequest_ReadFile{&pb.ReadFileRequest{
- Path: path,
- }}
-
- id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-ch
-
- res := c.Request(id).Result.(*pb.SyscallResponse_ReadFile).ReadFile.Result
- if e, ok := res.(*pb.ReadFileResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.ReadFileResponse_Data).Data, nil
-}
-
-// RPCWriteFile will execute the WriteFile helper RPC method which avoids the
-// common pattern of open(2), write(2), write(2), close(2) by doing all
-// operations as a single RPC.
-func (c *RPCConnection) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
- req := &pb.SyscallRequest_WriteFile{&pb.WriteFileRequest{
- Path: path,
- Content: data,
- }}
-
- id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-ch
-
- res := c.Request(id).Result.(*pb.SyscallResponse_WriteFile).WriteFile
- if e := res.ErrorNumber; e != 0 {
- return int64(res.Written), syserr.FromHost(syscall.Errno(e))
- }
-
- return int64(res.Written), nil
-}
-
-// Request retrieves the request corresponding to the given request ID.
-//
-// The channel returned by NewRequest must have been closed before Request can
-// be called. This will happen automatically, do not manually close the
-// channel.
-func (c *RPCConnection) Request(id uint64) pb.SyscallResponse {
- c.reqMu.Lock()
- r := c.requests[id]
- delete(c.requests, id)
- c.reqMu.Unlock()
-
- var resp pb.SyscallResponse
- if err := proto.Unmarshal(r.response, &resp); err != nil {
- panic(fmt.Sprint("invalid proto: ", err))
- }
-
- return resp
-}
diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD
deleted file mode 100644
index a3585e10d..000000000
--- a/pkg/sentry/socket/rpcinet/notifier/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "notifier",
- srcs = ["notifier.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier",
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
- "//pkg/sentry/socket/rpcinet/conn",
- "//pkg/waiter",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go
deleted file mode 100644
index 7efe4301f..000000000
--- a/pkg/sentry/socket/rpcinet/notifier/notifier.go
+++ /dev/null
@@ -1,231 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package notifier implements an FD notifier implementation over RPC.
-package notifier
-
-import (
- "fmt"
- "sync"
- "syscall"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-type fdInfo struct {
- queue *waiter.Queue
- waiting bool
-}
-
-// Notifier holds all the state necessary to issue notifications when IO events
-// occur in the observed FDs.
-type Notifier struct {
- // rpcConn is the connection that is used for sending RPCs.
- rpcConn *conn.RPCConnection
-
- // epFD is the epoll file descriptor used to register for io
- // notifications.
- epFD uint32
-
- // mu protects fdMap.
- mu sync.Mutex
-
- // fdMap maps file descriptors to their notification queues and waiting
- // status.
- fdMap map[uint32]*fdInfo
-}
-
-// NewRPCNotifier creates a new notifier object.
-func NewRPCNotifier(cn *conn.RPCConnection) (*Notifier, error) {
- id, c := cn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCreate1{&pb.EpollCreate1Request{}}}, false /* ignoreResult */)
- <-c
-
- res := cn.Request(id).Result.(*pb.SyscallResponse_EpollCreate1).EpollCreate1.Result
- if e, ok := res.(*pb.EpollCreate1Response_ErrorNumber); ok {
- return nil, syscall.Errno(e.ErrorNumber)
- }
-
- w := &Notifier{
- rpcConn: cn,
- epFD: res.(*pb.EpollCreate1Response_Fd).Fd,
- fdMap: make(map[uint32]*fdInfo),
- }
-
- go w.waitAndNotify() // S/R-FIXME(b/77962828)
-
- return w, nil
-}
-
-// waitFD waits on mask for fd. The fdMap mutex must be hold.
-func (n *Notifier) waitFD(fd uint32, fi *fdInfo, mask waiter.EventMask) error {
- if !fi.waiting && mask == 0 {
- return nil
- }
-
- e := pb.EpollEvent{
- Events: mask.ToLinux() | unix.EPOLLET,
- Fd: fd,
- }
-
- switch {
- case !fi.waiting && mask != 0:
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_ADD, Fd: fd, Event: &e}}}, false /* ignoreResult */)
- <-c
-
- e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber
- if e != 0 {
- return syscall.Errno(e)
- }
-
- fi.waiting = true
- case fi.waiting && mask == 0:
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_DEL, Fd: fd}}}, false /* ignoreResult */)
- <-c
- n.rpcConn.Request(id)
-
- fi.waiting = false
- case fi.waiting && mask != 0:
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_MOD, Fd: fd, Event: &e}}}, false /* ignoreResult */)
- <-c
-
- e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber
- if e != 0 {
- return syscall.Errno(e)
- }
- }
-
- return nil
-}
-
-// addFD adds an FD to the list of FDs observed by n.
-func (n *Notifier) addFD(fd uint32, queue *waiter.Queue) {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- // Panic if we're already notifying on this FD.
- if _, ok := n.fdMap[fd]; ok {
- panic(fmt.Sprintf("File descriptor %d added twice", fd))
- }
-
- // We have nothing to wait for at the moment. Just add it to the map.
- n.fdMap[fd] = &fdInfo{queue: queue}
-}
-
-// updateFD updates the set of events the FD needs to be notified on.
-func (n *Notifier) updateFD(fd uint32) error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- if fi, ok := n.fdMap[fd]; ok {
- return n.waitFD(fd, fi, fi.queue.Events())
- }
-
- return nil
-}
-
-// RemoveFD removes an FD from the list of FDs observed by n.
-func (n *Notifier) removeFD(fd uint32) {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- // Remove from map, then from epoll object.
- n.waitFD(fd, n.fdMap[fd], 0)
- delete(n.fdMap, fd)
-}
-
-// hasFD returns true if the FD is in the list of observed FDs.
-func (n *Notifier) hasFD(fd uint32) bool {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- _, ok := n.fdMap[fd]
- return ok
-}
-
-// waitAndNotify loops waiting for io event notifications from the epoll
-// object. Once notifications arrive, they are dispatched to the
-// registered queue.
-func (n *Notifier) waitAndNotify() error {
- for {
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollWait{&pb.EpollWaitRequest{Fd: n.epFD, NumEvents: 100, Msec: -1}}}, false /* ignoreResult */)
- <-c
-
- res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollWait).EpollWait.Result
- if e, ok := res.(*pb.EpollWaitResponse_ErrorNumber); ok {
- err := syscall.Errno(e.ErrorNumber)
- // NOTE(magi): I don't think epoll_wait can return EAGAIN but I'm being
- // conseratively careful here since exiting the notification thread
- // would be really bad.
- if err == syscall.EINTR || err == syscall.EAGAIN {
- continue
- }
- return err
- }
-
- n.mu.Lock()
- for _, e := range res.(*pb.EpollWaitResponse_Events).Events.Events {
- if fi, ok := n.fdMap[e.Fd]; ok {
- fi.queue.Notify(waiter.EventMaskFromLinux(e.Events))
- }
- }
- n.mu.Unlock()
- }
-}
-
-// AddFD adds an FD to the list of observed FDs.
-func (n *Notifier) AddFD(fd uint32, queue *waiter.Queue) error {
- n.addFD(fd, queue)
- return nil
-}
-
-// UpdateFD updates the set of events the FD needs to be notified on.
-func (n *Notifier) UpdateFD(fd uint32) error {
- return n.updateFD(fd)
-}
-
-// RemoveFD removes an FD from the list of observed FDs.
-func (n *Notifier) RemoveFD(fd uint32) {
- n.removeFD(fd)
-}
-
-// HasFD returns true if the FD is in the list of observed FDs.
-//
-// This should only be used by tests to assert that FDs are correctly
-// registered.
-func (n *Notifier) HasFD(fd uint32) bool {
- return n.hasFD(fd)
-}
-
-// NonBlockingPoll polls the given fd in non-blocking fashion. It is used just
-// to query the FD's current state; this method will block on the RPC response
-// although the syscall is non-blocking.
-func (n *Notifier) NonBlockingPoll(fd uint32, mask waiter.EventMask) waiter.EventMask {
- for {
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Poll{&pb.PollRequest{Fd: fd, Events: mask.ToLinux()}}}, false /* ignoreResult */)
- <-c
-
- res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_Poll).Poll.Result
- if e, ok := res.(*pb.PollResponse_ErrorNumber); ok {
- if syscall.Errno(e.ErrorNumber) == syscall.EINTR {
- continue
- }
- return mask
- }
-
- return waiter.EventMaskFromLinux(res.(*pb.PollResponse_Events).Events)
- }
-}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
deleted file mode 100644
index ddb76d9d4..000000000
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ /dev/null
@@ -1,909 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rpcinet
-
-import (
- "sync/atomic"
- "syscall"
- "time"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier"
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.dev/gvisor/pkg/sentry/unimpl"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// socketOperations implements fs.FileOperations and socket.Socket for a socket
-// implemented using a host socket.
-type socketOperations struct {
- fsutil.FilePipeSeek `state:"nosave"`
- fsutil.FileNotDirReaddir `state:"nosave"`
- fsutil.FileNoFsync `state:"nosave"`
- fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopFlush `state:"nosave"`
- fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- socket.SendReceiveTimeout
-
- family int // Read-only.
- stype linux.SockType // Read-only.
- protocol int // Read-only.
-
- fd uint32 // must be O_NONBLOCK
- wq *waiter.Queue
- rpcConn *conn.RPCConnection
- notifier *notifier.Notifier
-
- // shState is the state of the connection with respect to shutdown. Because
- // we're mixing non-blocking semantics on the other side we have to adapt for
- // some strange differences between blocking and non-blocking sockets.
- shState int32
-}
-
-// Verify that we actually implement socket.Socket.
-var _ = socket.Socket(&socketOperations{})
-
-// New creates a new RPC socket.
-func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
- if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
- fd := res.(*pb.SocketResponse_Fd).Fd
-
- var wq waiter.Queue
- stack.notifier.AddFD(fd, &wq)
-
- dirent := socket.NewDirent(ctx, socketDevice)
- defer dirent.DecRef()
- return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{
- family: family,
- stype: skType,
- protocol: protocol,
- wq: &wq,
- fd: fd,
- rpcConn: stack.rpcConn,
- notifier: stack.notifier,
- }), nil
-}
-
-func isBlockingErrno(err error) bool {
- return err == syscall.EAGAIN || err == syscall.EWOULDBLOCK
-}
-
-func translateIOSyscallError(err error) error {
- if isBlockingErrno(err) {
- return syserror.ErrWouldBlock
- }
- return err
-}
-
-// setShutdownFlags will set the shutdown flag so we can handle blocking reads
-// after a read shutdown.
-func (s *socketOperations) setShutdownFlags(how int) {
- var f tcpip.ShutdownFlags
- switch how {
- case linux.SHUT_RD:
- f = tcpip.ShutdownRead
- case linux.SHUT_WR:
- f = tcpip.ShutdownWrite
- case linux.SHUT_RDWR:
- f = tcpip.ShutdownWrite | tcpip.ShutdownRead
- }
-
- // Atomically update the flags.
- for {
- old := atomic.LoadInt32(&s.shState)
- if atomic.CompareAndSwapInt32(&s.shState, old, old|int32(f)) {
- break
- }
- }
-}
-
-func (s *socketOperations) resetShutdownFlags() {
- atomic.StoreInt32(&s.shState, 0)
-}
-
-func (s *socketOperations) isShutRdSet() bool {
- return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownRead) != 0
-}
-
-func (s *socketOperations) isShutWrSet() bool {
- return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownWrite) != 0
-}
-
-// Release implements fs.FileOperations.Release.
-func (s *socketOperations) Release() {
- s.notifier.RemoveFD(s.fd)
-
- // We always need to close the FD.
- _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: s.fd}}}, true /* ignoreResult */)
-}
-
-// Readiness implements waiter.Waitable.Readiness.
-func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
- return s.notifier.NonBlockingPoll(s.fd, mask)
-}
-
-// EventRegister implements waiter.Waitable.EventRegister.
-func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- s.wq.EventRegister(e, mask)
- s.notifier.UpdateFD(s.fd)
-}
-
-// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *socketOperations) EventUnregister(e *waiter.Entry) {
- s.wq.EventUnregister(e)
- s.notifier.UpdateFD(s.fd)
-}
-
-func rpcRead(t *kernel.Task, req *pb.SyscallRequest_Read) (*pb.ReadResponse_Data, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Read).Read.Result
- if e, ok := res.(*pb.ReadResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.ReadResponse_Data), nil
-}
-
-// Read implements fs.FileOperations.Read.
-func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- req := &pb.SyscallRequest_Read{&pb.ReadRequest{
- Fd: s.fd,
- Length: uint32(dst.NumBytes()),
- }}
-
- res, se := rpcRead(ctx.(*kernel.Task), req)
- if se == nil {
- n, e := dst.CopyOut(ctx, res.Data)
- return int64(n), e
- }
-
- return 0, se.ToError()
-}
-
-func rpcWrite(t *kernel.Task, req *pb.SyscallRequest_Write) (uint32, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Write).Write.Result
- if e, ok := res.(*pb.WriteResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.WriteResponse_Length).Length, nil
-}
-
-// Write implements fs.FileOperations.Write.
-func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- t := ctx.(*kernel.Task)
- v := buffer.NewView(int(src.NumBytes()))
-
- // Copy all the data into the buffer.
- if _, err := src.CopyIn(t, v); err != nil {
- return 0, err
- }
-
- n, err := rpcWrite(t, &pb.SyscallRequest_Write{&pb.WriteRequest{Fd: s.fd, Data: v}})
- if n > 0 && n < uint32(src.NumBytes()) {
- // The FileOperations.Write interface expects us to return ErrWouldBlock in
- // the event of a partial write.
- return int64(n), syserror.ErrWouldBlock
- }
- return int64(n), err.ToError()
-}
-
-func rpcConnect(t *kernel.Task, fd uint32, sockaddr []byte) *syserr.Error {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Connect{&pb.ConnectRequest{Fd: uint32(fd), Address: sockaddr}}}, false /* ignoreResult */)
- <-c
-
- if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Connect).Connect.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// Connect implements socket.Socket.Connect.
-func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- if !blocking {
- e := rpcConnect(t, s.fd, sockaddr)
- if e == nil {
- // Reset the shutdown state on new connects.
- s.resetShutdownFlags()
- }
- return e
- }
-
- // Register for notification when the endpoint becomes writable, then
- // initiate the connection.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventOut|waiter.EventIn|waiter.EventHUp)
- defer s.EventUnregister(&e)
- for {
- if err := rpcConnect(t, s.fd, sockaddr); err == nil || err != syserr.ErrInProgress && err != syserr.ErrAlreadyInProgress {
- if err == nil {
- // Reset the shutdown state on new connects.
- s.resetShutdownFlags()
- }
- return err
- }
-
- // It's pending, so we have to wait for a notification, and fetch the
- // result once the wait completes.
- if err := t.Block(ch); err != nil {
- return syserr.FromError(err)
- }
- }
-}
-
-func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultPayload, *syserr.Error) {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Accept{&pb.AcceptRequest{Fd: fd, Peer: peer, Flags: syscall.SOCK_NONBLOCK}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Accept).Accept.Result
- if e, ok := res.(*pb.AcceptResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
- return res.(*pb.AcceptResponse_Payload).Payload, nil
-}
-
-// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- payload, se := rpcAccept(t, s.fd, peerRequested)
-
- // Check if we need to block.
- if blocking && se == syserr.ErrTryAgain {
- // Register for notifications.
- e, ch := waiter.NewChannelEntry(nil)
- // FIXME(b/119878986): This waiter.EventHUp is a partial
- // measure, need to figure out how to translate linux events to
- // internal events.
- s.EventRegister(&e, waiter.EventIn|waiter.EventHUp)
- defer s.EventUnregister(&e)
-
- // Try to accept the connection again; if it fails, then wait until we
- // get a notification.
- for {
- if payload, se = rpcAccept(t, s.fd, peerRequested); se != syserr.ErrTryAgain {
- break
- }
-
- if err := t.Block(ch); err != nil {
- return 0, nil, 0, syserr.FromError(err)
- }
- }
- }
-
- // Handle any error from accept.
- if se != nil {
- return 0, nil, 0, se
- }
-
- var wq waiter.Queue
- s.notifier.AddFD(payload.Fd, &wq)
-
- dirent := socket.NewDirent(t, socketDevice)
- defer dirent.DecRef()
- fileFlags := fs.FileFlags{
- Read: true,
- Write: true,
- NonSeekable: true,
- NonBlocking: flags&linux.SOCK_NONBLOCK != 0,
- }
- file := fs.NewFile(t, dirent, fileFlags, &socketOperations{
- family: s.family,
- stype: s.stype,
- protocol: s.protocol,
- wq: &wq,
- fd: payload.Fd,
- rpcConn: s.rpcConn,
- notifier: s.notifier,
- })
- defer file.DecRef()
-
- fd, err := t.NewFDFrom(0, file, kernel.FDFlags{
- CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
- })
- if err != nil {
- return 0, nil, 0, syserr.FromError(err)
- }
- t.Kernel().RecordSocket(file)
-
- if peerRequested {
- return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil
- }
-
- return fd, nil, 0, nil
-}
-
-// Bind implements socket.Socket.Bind.
-func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: s.fd, Address: sockaddr}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// Listen implements socket.Socket.Listen.
-func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Listen{&pb.ListenRequest{Fd: s.fd, Backlog: int64(backlog)}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Listen).Listen.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// Shutdown implements socket.Socket.Shutdown.
-func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
- // We save the shutdown state because of strange differences on linux
- // related to recvs on blocking vs. non-blocking sockets after a SHUT_RD.
- // We need to emulate that behavior on the blocking side.
- // TODO(b/120096741): There is a possible race that can exist with loopback,
- // where data could possibly be lost.
- s.setShutdownFlags(how)
-
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Shutdown{&pb.ShutdownRequest{Fd: s.fd, How: int64(how)}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Shutdown).Shutdown.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
-
- return nil
-}
-
-// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
- // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed
- // within the sentry.
- if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
- if outLen < linux.SizeOfTimeval {
- return nil, syserr.ErrInvalidArgument
- }
-
- return linux.NsecToTimeval(s.RecvTimeout()), nil
- }
- if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
- if outLen < linux.SizeOfTimeval {
- return nil, syserr.ErrInvalidArgument
- }
-
- return linux.NsecToTimeval(s.SendTimeout()), nil
- }
-
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockOpt).GetSockOpt.Result
- if e, ok := res.(*pb.GetSockOptResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.GetSockOptResponse_Opt).Opt, nil
-}
-
-// SetSockOpt implements socket.Socket.SetSockOpt.
-func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
- // Because blocking actually happens within the sentry we need to inspect
- // this socket option to determine if it's a SO_RCVTIMEO or SO_SNDTIMEO,
- // and if so, we will save it and use it as the deadline for recv(2)
- // or send(2) related syscalls.
- if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
- if len(opt) < linux.SizeOfTimeval {
- return syserr.ErrInvalidArgument
- }
-
- var v linux.Timeval
- binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
- if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
- return syserr.ErrDomain
- }
- s.SetRecvTimeout(v.ToNsecCapped())
- return nil
- }
- if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
- if len(opt) < linux.SizeOfTimeval {
- return syserr.ErrInvalidArgument
- }
-
- var v linux.Timeval
- binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
- if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
- return syserr.ErrDomain
- }
- s.SetSendTimeout(v.ToNsecCapped())
- return nil
- }
-
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_SetSockOpt).SetSockOpt.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetPeerName).GetPeerName.Result
- if e, ok := res.(*pb.GetPeerNameResponse_ErrorNumber); ok {
- return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- addr := res.(*pb.GetPeerNameResponse_Address).Address
- return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
-}
-
-// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockName).GetSockName.Result
- if e, ok := res.(*pb.GetSockNameResponse_ErrorNumber); ok {
- return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- addr := res.(*pb.GetSockNameResponse_Address).Address
- return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
-}
-
-func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) {
- stack := t.NetworkContext().(*Stack)
-
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Ioctl{&pb.IOCtlRequest{Fd: fd, Cmd: cmd, Arg: arg}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Ioctl).Ioctl.Result
- if e, ok := res.(*pb.IOCtlResponse_ErrorNumber); ok {
- return nil, syscall.Errno(e.ErrorNumber)
- }
-
- return res.(*pb.IOCtlResponse_Value).Value, nil
-}
-
-// ifconfIoctlFromStack populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctlFromStack(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
- // If Ptr is NULL, return the necessary buffer size via Len.
- // Otherwise, write up to Len bytes starting at Ptr containing ifreq
- // structs.
- t := ctx.(*kernel.Task)
- s := t.NetworkContext().(*Stack)
- if s == nil {
- return syserr.ErrNoDevice.ToError()
- }
-
- if ifc.Ptr == 0 {
- ifc.Len = int32(len(s.Interfaces())) * int32(linux.SizeOfIFReq)
- return nil
- }
-
- max := ifc.Len
- ifc.Len = 0
- for key, ifaceAddrs := range s.InterfaceAddrs() {
- iface := s.Interfaces()[key]
- for _, ifaceAddr := range ifaceAddrs {
- // Don't write past the end of the buffer.
- if ifc.Len+int32(linux.SizeOfIFReq) > max {
- break
- }
- if ifaceAddr.Family != linux.AF_INET {
- continue
- }
-
- // Populate ifr.ifr_addr.
- ifr := linux.IFReq{}
- ifr.SetName(iface.Name)
- usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
- usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
- copy(ifr.Data[4:8], ifaceAddr.Addr[:4])
-
- // Copy the ifr to userspace.
- dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
- ifc.Len += int32(linux.SizeOfIFReq)
- if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-// Ioctl implements fs.FileOperations.Ioctl.
-func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- t := ctx.(*kernel.Task)
-
- cmd := uint32(args[1].Int())
- arg := args[2].Pointer()
-
- var buf []byte
- switch cmd {
- // The following ioctls take 4 byte argument parameters.
- case syscall.TIOCINQ,
- syscall.TIOCOUTQ:
- buf = make([]byte, 4)
- // The following ioctls have args which are sizeof(struct ifreq).
- case syscall.SIOCGIFADDR,
- syscall.SIOCGIFBRDADDR,
- syscall.SIOCGIFDSTADDR,
- syscall.SIOCGIFFLAGS,
- syscall.SIOCGIFHWADDR,
- syscall.SIOCGIFINDEX,
- syscall.SIOCGIFMAP,
- syscall.SIOCGIFMETRIC,
- syscall.SIOCGIFMTU,
- syscall.SIOCGIFNAME,
- syscall.SIOCGIFNETMASK,
- syscall.SIOCGIFTXQLEN:
- buf = make([]byte, linux.SizeOfIFReq)
- case syscall.SIOCGIFCONF:
- // SIOCGIFCONF has slightly different behavior than the others, in that it
- // will need to populate the array of ifreqs.
- var ifc linux.IFConf
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
- return 0, err
- }
-
- if err := ifconfIoctlFromStack(ctx, io, &ifc); err != nil {
- return 0, err
- }
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
- return 0, err
-
- case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
- unimpl.EmitUnimplementedEvent(ctx)
-
- default:
- return 0, syserror.ENOTTY
- }
-
- _, err := io.CopyIn(ctx, arg, buf, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
- if err != nil {
- return 0, err
- }
-
- v, err := rpcIoctl(t, s.fd, cmd, buf)
- if err != nil {
- return 0, err
- }
-
- if len(v) != len(buf) {
- return 0, syserror.EINVAL
- }
-
- _, err = io.CopyOut(ctx, arg, v, usermem.IOOpts{
- AddressSpaceActive: true,
- })
- return 0, err
-}
-
-func rpcRecvMsg(t *kernel.Task, req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
- if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.RecvmsgResponse_Payload).Payload, nil
-}
-
-// Because we only support SO_TIMESTAMP we will search control messages for
-// that value and set it if so, all other control messages will be ignored.
-func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_ResultPayload) socket.ControlMessages {
- c := socket.ControlMessages{}
- if len(payload.GetCmsgData()) > 0 {
- // Parse the control messages looking for SO_TIMESTAMP.
- msgs, e := syscall.ParseSocketControlMessage(payload.GetCmsgData())
- if e != nil {
- return socket.ControlMessages{}
- }
- for _, m := range msgs {
- if m.Header.Level != linux.SOL_SOCKET || m.Header.Type != linux.SO_TIMESTAMP {
- continue
- }
-
- // Let's parse the time stamp and set it.
- if len(m.Data) < linux.SizeOfTimeval {
- // Give up on locating the SO_TIMESTAMP option.
- return socket.ControlMessages{}
- }
-
- var v linux.Timeval
- binary.Unmarshal(m.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
- c.IP.HasTimestamp = true
- c.IP.Timestamp = v.ToNsecCapped()
- break
- }
- }
- return c
-}
-
-// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
- req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
- Fd: s.fd,
- Length: uint32(dst.NumBytes()),
- Sender: senderRequested,
- Trunc: flags&linux.MSG_TRUNC != 0,
- Peek: flags&linux.MSG_PEEK != 0,
- CmsgLength: uint32(controlDataLen),
- }}
-
- res, err := rpcRecvMsg(t, req)
- if err == nil {
- var e error
- var n int
- if len(res.Data) > 0 {
- n, e = dst.CopyOut(t, res.Data)
- if e == nil && n != len(res.Data) {
- panic("CopyOut failed to copy full buffer")
- }
- }
- c := s.extractControlMessages(res)
- return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
- }
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
- return 0, 0, nil, 0, socket.ControlMessages{}, err
- }
-
- // We'll have to block. Register for notifications and keep trying to
- // send all the data.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventIn)
- defer s.EventUnregister(&e)
-
- for {
- res, err := rpcRecvMsg(t, req)
- if err == nil {
- var e error
- var n int
- if len(res.Data) > 0 {
- n, e = dst.CopyOut(t, res.Data)
- if e == nil && n != len(res.Data) {
- panic("CopyOut failed to copy full buffer")
- }
- }
- c := s.extractControlMessages(res)
- return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
- }
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
- return 0, 0, nil, 0, socket.ControlMessages{}, err
- }
-
- if s.isShutRdSet() {
- // Blocking would have caused us to block indefinitely so we return 0,
- // this is the same behavior as Linux.
- return 0, 0, nil, 0, socket.ControlMessages{}, nil
- }
-
- if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
- }
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
- }
- }
-}
-
-func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
- if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.SendmsgResponse_Length).Length, nil
-}
-
-// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
- // Whitelist flags.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
- return 0, syserr.ErrInvalidArgument
- }
-
- // Reject Unix control messages.
- if !controlMessages.Unix.Empty() {
- return 0, syserr.ErrInvalidArgument
- }
-
- v := buffer.NewView(int(src.NumBytes()))
-
- // Copy all the data into the buffer.
- if _, err := src.CopyIn(t, v); err != nil {
- return 0, syserr.FromError(err)
- }
-
- // TODO(bgeffon): this needs to change to map directly to a SendMsg syscall
- // in the RPC.
- totalWritten := 0
- n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
- Fd: uint32(s.fd),
- Data: v,
- Address: to,
- More: flags&linux.MSG_MORE != 0,
- EndOfRecord: flags&linux.MSG_EOR != 0,
- }})
-
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
- return int(n), err
- }
-
- if n > 0 {
- totalWritten += int(n)
- v.TrimFront(int(n))
- }
-
- // We'll have to block. Register for notification and keep trying to
- // send all the data.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventOut)
- defer s.EventUnregister(&e)
-
- for {
- n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
- Fd: uint32(s.fd),
- Data: v,
- Address: to,
- More: flags&linux.MSG_MORE != 0,
- EndOfRecord: flags&linux.MSG_EOR != 0,
- }})
-
- if n > 0 {
- totalWritten += int(n)
- v.TrimFront(int(n))
-
- if err == nil && totalWritten < int(src.NumBytes()) {
- continue
- }
- }
-
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
- // We eat the error in this situation.
- return int(totalWritten), nil
- }
-
- if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
- return int(totalWritten), syserr.ErrTryAgain
- }
- return int(totalWritten), syserr.FromError(err)
- }
- }
-}
-
-// State implements socket.Socket.State.
-func (s *socketOperations) State() uint32 {
- // TODO(b/127845868): Define a new rpc to query the socket state.
- return 0
-}
-
-// Type implements socket.Socket.Type.
-func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
- return s.family, s.stype, s.protocol
-}
-
-type socketProvider struct {
- family int
-}
-
-// Socket implements socket.Provider.Socket.
-func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- // Check that we are using the RPC network stack.
- stack := t.NetworkContext()
- if stack == nil {
- return nil, nil
- }
-
- s, ok := stack.(*Stack)
- if !ok {
- return nil, nil
- }
-
- // Only accept TCP and UDP.
- //
- // Try to restrict the flags we will accept to minimize backwards
- // incompatibility with netstack.
- stype := stypeflags & linux.SOCK_TYPE_MASK
- switch stype {
- case syscall.SOCK_STREAM:
- switch protocol {
- case 0, syscall.IPPROTO_TCP:
- // ok
- default:
- return nil, nil
- }
- case syscall.SOCK_DGRAM:
- switch protocol {
- case 0, syscall.IPPROTO_UDP:
- // ok
- default:
- return nil, nil
- }
- default:
- return nil, nil
- }
-
- return newSocketFile(t, s, p.family, stype, protocol)
-}
-
-// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
- // Not supported by AF_INET/AF_INET6.
- return nil, nil, nil
-}
-
-func init() {
- for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
- socket.RegisterProvider(family, &socketProvider{family})
- }
-}
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
deleted file mode 100644
index f7878a760..000000000
--- a/pkg/sentry/socket/rpcinet/stack.go
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rpcinet
-
-import (
- "fmt"
- "syscall"
-
- "gvisor.dev/gvisor/pkg/sentry/inet"
- "gvisor.dev/gvisor/pkg/sentry/socket/hostinet"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier"
- "gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/unet"
-)
-
-// Stack implements inet.Stack for RPC backed sockets.
-type Stack struct {
- interfaces map[int32]inet.Interface
- interfaceAddrs map[int32][]inet.InterfaceAddr
- routes []inet.Route
- rpcConn *conn.RPCConnection
- notifier *notifier.Notifier
-}
-
-// NewStack returns a Stack containing the current state of the host network
-// stack.
-func NewStack(fd int32) (*Stack, error) {
- sock, err := unet.NewSocket(int(fd))
- if err != nil {
- return nil, err
- }
-
- stack := &Stack{
- interfaces: make(map[int32]inet.Interface),
- interfaceAddrs: make(map[int32][]inet.InterfaceAddr),
- rpcConn: conn.NewRPCConnection(sock),
- }
-
- var e error
- stack.notifier, e = notifier.NewRPCNotifier(stack.rpcConn)
- if e != nil {
- return nil, e
- }
-
- links, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETLINK)
- if err != nil {
- return nil, fmt.Errorf("RTM_GETLINK failed: %v", err)
- }
-
- addrs, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETADDR)
- if err != nil {
- return nil, fmt.Errorf("RTM_GETADDR failed: %v", err)
- }
-
- e = hostinet.ExtractHostInterfaces(links, addrs, stack.interfaces, stack.interfaceAddrs)
- if e != nil {
- return nil, e
- }
-
- routes, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETROUTE)
- if err != nil {
- return nil, fmt.Errorf("RTM_GETROUTE failed: %v", err)
- }
-
- stack.routes, e = hostinet.ExtractHostRoutes(routes)
- if e != nil {
- return nil, e
- }
-
- return stack, nil
-}
-
-// RPCReadFile will execute the ReadFile helper RPC method which avoids the
-// common pattern of open(2), read(2), close(2) by doing all three operations
-// as a single RPC. It will read the entire file or return EFBIG if the file
-// was too large.
-func (s *Stack) RPCReadFile(path string) ([]byte, *syserr.Error) {
- return s.rpcConn.RPCReadFile(path)
-}
-
-// RPCWriteFile will execute the WriteFile helper RPC method which avoids the
-// common pattern of open(2), write(2), write(2), close(2) by doing all
-// operations as a single RPC.
-func (s *Stack) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
- return s.rpcConn.RPCWriteFile(path, data)
-}
-
-// Interfaces implements inet.Stack.Interfaces.
-func (s *Stack) Interfaces() map[int32]inet.Interface {
- interfaces := make(map[int32]inet.Interface)
- for k, v := range s.interfaces {
- interfaces[k] = v
- }
- return interfaces
-}
-
-// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
-func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
- addrs := make(map[int32][]inet.InterfaceAddr)
- for k, v := range s.interfaceAddrs {
- addrs[k] = append([]inet.InterfaceAddr(nil), v...)
- }
- return addrs
-}
-
-// SupportsIPv6 implements inet.Stack.SupportsIPv6.
-func (s *Stack) SupportsIPv6() bool {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
-func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
-func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
- panic("rpcinet handles procfs directly this method should not be called")
-
-}
-
-// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
-func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
- panic("rpcinet handles procfs directly this method should not be called")
-
-}
-
-// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
-func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
-func (s *Stack) TCPSACKEnabled() (bool, error) {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
-func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// Statistics implements inet.Stack.Statistics.
-func (s *Stack) Statistics(stat interface{}, arg string) error {
- return syserr.ErrEndpointOperation.ToError()
-}
-
-// RouteTable implements inet.Stack.RouteTable.
-func (s *Stack) RouteTable() []inet.Route {
- return append([]inet.Route(nil), s.routes...)
-}
-
-// Resume implements inet.Stack.Resume.
-func (s *Stack) Resume() {}
-
-// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
-func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
-
-// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
-func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
-
-// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
-func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
diff --git a/pkg/sentry/socket/rpcinet/stack_unsafe.go b/pkg/sentry/socket/rpcinet/stack_unsafe.go
deleted file mode 100644
index a94bdad83..000000000
--- a/pkg/sentry/socket/rpcinet/stack_unsafe.go
+++ /dev/null
@@ -1,193 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rpcinet
-
-import (
- "syscall"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserr"
-)
-
-// NewNetlinkRouteRequest builds a netlink message for getting the RIB,
-// the routing information base.
-func newNetlinkRouteRequest(proto, seq, family int) []byte {
- rr := &syscall.NetlinkRouteRequest{}
- rr.Header.Len = uint32(syscall.NLMSG_HDRLEN + syscall.SizeofRtGenmsg)
- rr.Header.Type = uint16(proto)
- rr.Header.Flags = syscall.NLM_F_DUMP | syscall.NLM_F_REQUEST
- rr.Header.Seq = uint32(seq)
- rr.Data.Family = uint8(family)
- return netlinkRRtoWireFormat(rr)
-}
-
-func netlinkRRtoWireFormat(rr *syscall.NetlinkRouteRequest) []byte {
- b := make([]byte, rr.Header.Len)
- *(*uint32)(unsafe.Pointer(&b[0:4][0])) = rr.Header.Len
- *(*uint16)(unsafe.Pointer(&b[4:6][0])) = rr.Header.Type
- *(*uint16)(unsafe.Pointer(&b[6:8][0])) = rr.Header.Flags
- *(*uint32)(unsafe.Pointer(&b[8:12][0])) = rr.Header.Seq
- *(*uint32)(unsafe.Pointer(&b[12:16][0])) = rr.Header.Pid
- b[16] = byte(rr.Data.Family)
- return b
-}
-
-func (s *Stack) getNetlinkFd() (uint32, *syserr.Error) {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(syscall.AF_NETLINK), Type: int64(syscall.SOCK_RAW | syscall.SOCK_NONBLOCK), Protocol: int64(syscall.NETLINK_ROUTE)}}}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
- if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
- return res.(*pb.SocketResponse_Fd).Fd, nil
-}
-
-func (s *Stack) bindNetlinkFd(fd uint32, sockaddr []byte) *syserr.Error {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: fd, Address: sockaddr}}}, false /* ignoreResult */)
- <-c
-
- if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-func (s *Stack) closeNetlinkFd(fd uint32) {
- _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: fd}}}, true /* ignoreResult */)
-}
-
-func (s *Stack) rpcSendMsg(req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
- if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.SendmsgResponse_Length).Length, nil
-}
-
-func (s *Stack) sendMsg(fd uint32, buf []byte, to []byte, flags int) (int, *syserr.Error) {
- // Whitelist flags.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
- return 0, syserr.ErrInvalidArgument
- }
-
- req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
- Fd: fd,
- Data: buf,
- Address: to,
- More: flags&linux.MSG_MORE != 0,
- EndOfRecord: flags&linux.MSG_EOR != 0,
- }}
-
- n, err := s.rpcSendMsg(req)
- return int(n), err
-}
-
-func (s *Stack) rpcRecvMsg(req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
- if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.RecvmsgResponse_Payload).Payload, nil
-}
-
-func (s *Stack) recvMsg(fd, l, flags uint32) ([]byte, *syserr.Error) {
- req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
- Fd: fd,
- Length: l,
- Sender: false,
- Trunc: flags&linux.MSG_TRUNC != 0,
- Peek: flags&linux.MSG_PEEK != 0,
- }}
-
- res, err := s.rpcRecvMsg(req)
- if err != nil {
- return nil, err
- }
- return res.Data, nil
-}
-
-func (s *Stack) netlinkRequest(proto, family int) ([]byte, error) {
- fd, err := s.getNetlinkFd()
- if err != nil {
- return nil, err.ToError()
- }
- defer s.closeNetlinkFd(fd)
-
- lsa := syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
- b := binary.Marshal(nil, usermem.ByteOrder, &lsa)
- if err := s.bindNetlinkFd(fd, b); err != nil {
- return nil, err.ToError()
- }
-
- wb := newNetlinkRouteRequest(proto, 1, family)
- _, err = s.sendMsg(fd, wb, b, 0)
- if err != nil {
- return nil, err.ToError()
- }
-
- var tab []byte
-done:
- for {
- rb, err := s.recvMsg(fd, uint32(syscall.Getpagesize()), 0)
- nr := len(rb)
- if err != nil {
- return nil, err.ToError()
- }
-
- if nr < syscall.NLMSG_HDRLEN {
- return nil, syserr.ErrInvalidArgument.ToError()
- }
-
- tab = append(tab, rb...)
- msgs, e := syscall.ParseNetlinkMessage(rb)
- if e != nil {
- return nil, e
- }
-
- for _, m := range msgs {
- if m.Header.Type == syscall.NLMSG_DONE {
- break done
- }
- if m.Header.Type == syscall.NLMSG_ERROR {
- return nil, syserr.ErrInvalidArgument.ToError()
- }
- }
- }
-
- return tab, nil
-}
-
-// DoNetlinkRouteRequest returns routing information base, also known as RIB,
-// which consists of network facility information, states and parameters.
-func (s *Stack) DoNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
- data, err := s.netlinkRequest(req, syscall.AF_UNSPEC)
- if err != nil {
- return nil, err
- }
- return syscall.ParseNetlinkMessage(data)
-}
diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
deleted file mode 100644
index b677e9eb3..000000000
--- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto
+++ /dev/null
@@ -1,352 +0,0 @@
-syntax = "proto3";
-
-// package syscall_rpc is a set of networking related system calls that can be
-// forwarded to a socket gofer.
-//
-package syscall_rpc;
-
-message SendmsgRequest {
- uint32 fd = 1;
- bytes data = 2 [ctype = CORD];
- bytes address = 3;
- bool more = 4;
- bool end_of_record = 5;
-}
-
-message SendmsgResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 length = 2;
- }
-}
-
-message IOCtlRequest {
- uint32 fd = 1;
- uint32 cmd = 2;
- bytes arg = 3;
-}
-
-message IOCtlResponse {
- oneof result {
- uint32 error_number = 1;
- bytes value = 2;
- }
-}
-
-message RecvmsgRequest {
- uint32 fd = 1;
- uint32 length = 2;
- bool sender = 3;
- bool peek = 4;
- bool trunc = 5;
- uint32 cmsg_length = 6;
-}
-
-message OpenRequest {
- bytes path = 1;
- uint32 flags = 2;
- uint32 mode = 3;
-}
-
-message OpenResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 fd = 2;
- }
-}
-
-message ReadRequest {
- uint32 fd = 1;
- uint32 length = 2;
-}
-
-message ReadResponse {
- oneof result {
- uint32 error_number = 1;
- bytes data = 2 [ctype = CORD];
- }
-}
-
-message ReadFileRequest {
- string path = 1;
-}
-
-message ReadFileResponse {
- oneof result {
- uint32 error_number = 1;
- bytes data = 2 [ctype = CORD];
- }
-}
-
-message WriteRequest {
- uint32 fd = 1;
- bytes data = 2 [ctype = CORD];
-}
-
-message WriteResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 length = 2;
- }
-}
-
-message WriteFileRequest {
- string path = 1;
- bytes content = 2;
-}
-
-message WriteFileResponse {
- uint32 error_number = 1;
- uint32 written = 2;
-}
-
-message AddressResponse {
- bytes address = 1;
- uint32 length = 2;
-}
-
-message RecvmsgResponse {
- message ResultPayload {
- bytes data = 1 [ctype = CORD];
- AddressResponse address = 2;
- uint32 length = 3;
- bytes cmsg_data = 4;
- }
- oneof result {
- uint32 error_number = 1;
- ResultPayload payload = 2;
- }
-}
-
-message BindRequest {
- uint32 fd = 1;
- bytes address = 2;
-}
-
-message BindResponse {
- uint32 error_number = 1;
-}
-
-message AcceptRequest {
- uint32 fd = 1;
- bool peer = 2;
- int64 flags = 3;
-}
-
-message AcceptResponse {
- message ResultPayload {
- uint32 fd = 1;
- AddressResponse address = 2;
- }
- oneof result {
- uint32 error_number = 1;
- ResultPayload payload = 2;
- }
-}
-
-message ConnectRequest {
- uint32 fd = 1;
- bytes address = 2;
-}
-
-message ConnectResponse {
- uint32 error_number = 1;
-}
-
-message ListenRequest {
- uint32 fd = 1;
- int64 backlog = 2;
-}
-
-message ListenResponse {
- uint32 error_number = 1;
-}
-
-message ShutdownRequest {
- uint32 fd = 1;
- int64 how = 2;
-}
-
-message ShutdownResponse {
- uint32 error_number = 1;
-}
-
-message CloseRequest {
- uint32 fd = 1;
-}
-
-message CloseResponse {
- uint32 error_number = 1;
-}
-
-message GetSockOptRequest {
- uint32 fd = 1;
- int64 level = 2;
- int64 name = 3;
- uint32 length = 4;
-}
-
-message GetSockOptResponse {
- oneof result {
- uint32 error_number = 1;
- bytes opt = 2;
- }
-}
-
-message SetSockOptRequest {
- uint32 fd = 1;
- int64 level = 2;
- int64 name = 3;
- bytes opt = 4;
-}
-
-message SetSockOptResponse {
- uint32 error_number = 1;
-}
-
-message GetSockNameRequest {
- uint32 fd = 1;
-}
-
-message GetSockNameResponse {
- oneof result {
- uint32 error_number = 1;
- AddressResponse address = 2;
- }
-}
-
-message GetPeerNameRequest {
- uint32 fd = 1;
-}
-
-message GetPeerNameResponse {
- oneof result {
- uint32 error_number = 1;
- AddressResponse address = 2;
- }
-}
-
-message SocketRequest {
- int64 family = 1;
- int64 type = 2;
- int64 protocol = 3;
-}
-
-message SocketResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 fd = 2;
- }
-}
-
-message EpollWaitRequest {
- uint32 fd = 1;
- uint32 num_events = 2;
- sint64 msec = 3;
-}
-
-message EpollEvent {
- uint32 fd = 1;
- uint32 events = 2;
-}
-
-message EpollEvents {
- repeated EpollEvent events = 1;
-}
-
-message EpollWaitResponse {
- oneof result {
- uint32 error_number = 1;
- EpollEvents events = 2;
- }
-}
-
-message EpollCtlRequest {
- uint32 epfd = 1;
- int64 op = 2;
- uint32 fd = 3;
- EpollEvent event = 4;
-}
-
-message EpollCtlResponse {
- uint32 error_number = 1;
-}
-
-message EpollCreate1Request {
- int64 flag = 1;
-}
-
-message EpollCreate1Response {
- oneof result {
- uint32 error_number = 1;
- uint32 fd = 2;
- }
-}
-
-message PollRequest {
- uint32 fd = 1;
- uint32 events = 2;
-}
-
-message PollResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 events = 2;
- }
-}
-
-message SyscallRequest {
- oneof args {
- SocketRequest socket = 1;
- SendmsgRequest sendmsg = 2;
- RecvmsgRequest recvmsg = 3;
- BindRequest bind = 4;
- AcceptRequest accept = 5;
- ConnectRequest connect = 6;
- ListenRequest listen = 7;
- ShutdownRequest shutdown = 8;
- CloseRequest close = 9;
- GetSockOptRequest get_sock_opt = 10;
- SetSockOptRequest set_sock_opt = 11;
- GetSockNameRequest get_sock_name = 12;
- GetPeerNameRequest get_peer_name = 13;
- EpollWaitRequest epoll_wait = 14;
- EpollCtlRequest epoll_ctl = 15;
- EpollCreate1Request epoll_create1 = 16;
- PollRequest poll = 17;
- ReadRequest read = 18;
- WriteRequest write = 19;
- OpenRequest open = 20;
- IOCtlRequest ioctl = 21;
- WriteFileRequest write_file = 22;
- ReadFileRequest read_file = 23;
- }
-}
-
-message SyscallResponse {
- oneof result {
- SocketResponse socket = 1;
- SendmsgResponse sendmsg = 2;
- RecvmsgResponse recvmsg = 3;
- BindResponse bind = 4;
- AcceptResponse accept = 5;
- ConnectResponse connect = 6;
- ListenResponse listen = 7;
- ShutdownResponse shutdown = 8;
- CloseResponse close = 9;
- GetSockOptResponse get_sock_opt = 10;
- SetSockOptResponse set_sock_opt = 11;
- GetSockNameResponse get_sock_name = 12;
- GetPeerNameResponse get_peer_name = 13;
- EpollWaitResponse epoll_wait = 14;
- EpollCtlResponse epoll_ctl = 15;
- EpollCreate1Response epoll_create1 = 16;
- PollResponse poll = 17;
- ReadResponse read = 18;
- WriteResponse write = 19;
- OpenResponse open = 20;
- IOCtlResponse ioctl = 21;
- WriteFileResponse write_file = 22;
- ReadFileResponse read_file = 23;
- }
-}
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 2ec1a662d..2447f24ef 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -83,6 +83,19 @@ type EndpointReader struct {
ControlTrunc bool
}
+// Truncate calls RecvMsg on the endpoint without writing to a destination.
+func (r *EndpointReader) Truncate() error {
+ // Ignore bytes read since it will always be zero.
+ _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return err.ToError()
+ }
+ return nil
+}
+
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 788ad70d2..d7ba95dff 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/ilist",
"//pkg/refs",
"//pkg/sentry/context",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index dea11e253..9e6fbc111 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -15,10 +15,9 @@
package transport
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/waiter"
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index e27b1c714..5dcd3d95e 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -15,9 +15,8 @@
package transport
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 529a7a7a9..fcc0da332 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -16,11 +16,11 @@
package transport
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -175,17 +175,25 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptBool sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
+ // GetSockOptBool gets a socket option for simple cases when a return
+ // value has the int type.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
// State returns the current state of the socket, as represented by Linux in
// procfs.
@@ -851,11 +859,19 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return nil
+}
+
+func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
+func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 885758054..7f49ba864 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -116,13 +116,16 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, family, err := netstack.AddressAndFamily(sockaddr)
if err != nil {
if err == syserr.ErrAddressFamilyNotSupported {
err = syserr.ErrInvalidArgument
}
return "", err
}
+ if family != linux.AF_UNIX {
+ return "", syserr.ErrInvalidArgument
+ }
// The address is trimmed by GetAddress.
p := string(addr.Addr)
@@ -544,8 +547,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if senderRequested {
r.From = &tcpip.FullAddress{}
}
+
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
+
+ }
+
var total int64
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait {
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
@@ -580,7 +602,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index d46421199..aa1ac720c 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -10,7 +10,8 @@ go_library(
"capability.go",
"clone.go",
"futex.go",
- "linux64.go",
+ "linux64_amd64.go",
+ "linux64_arm64.go",
"open.go",
"poll.go",
"ptrace.go",
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64_amd64.go
index e603f858f..1e823b685 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,8 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package strace
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
// linuxAMD64 provides a mapping of the Linux amd64 syscalls and their argument
// types for display / formatting.
var linuxAMD64 = SyscallMap{
@@ -365,3 +372,13 @@ var linuxAMD64 = SyscallMap{
434: makeSyscallInfo("pidfd_open", Hex, Hex),
435: makeSyscallInfo("clone3", Hex, Hex),
}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.AMD64,
+ syscalls: linuxAMD64,
+ },
+ )
+}
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
new file mode 100644
index 000000000..c3ac5248d
--- /dev/null
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -0,0 +1,323 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// linuxARM64 provides a mapping of the Linux arm64 syscalls and their argument
+// types for display / formatting.
+var linuxARM64 = SyscallMap{
+ 0: makeSyscallInfo("io_setup", Hex, Hex),
+ 1: makeSyscallInfo("io_destroy", Hex),
+ 2: makeSyscallInfo("io_submit", Hex, Hex, Hex),
+ 3: makeSyscallInfo("io_cancel", Hex, Hex, Hex),
+ 4: makeSyscallInfo("io_getevents", Hex, Hex, Hex, Hex, Timespec),
+ 5: makeSyscallInfo("setxattr", Path, Path, Hex, Hex, Hex),
+ 6: makeSyscallInfo("lsetxattr", Path, Path, Hex, Hex, Hex),
+ 7: makeSyscallInfo("fsetxattr", FD, Path, Hex, Hex, Hex),
+ 8: makeSyscallInfo("getxattr", Path, Path, Hex, Hex),
+ 9: makeSyscallInfo("lgetxattr", Path, Path, Hex, Hex),
+ 10: makeSyscallInfo("fgetxattr", FD, Path, Hex, Hex),
+ 11: makeSyscallInfo("listxattr", Path, Path, Hex),
+ 12: makeSyscallInfo("llistxattr", Path, Path, Hex),
+ 13: makeSyscallInfo("flistxattr", FD, Path, Hex),
+ 14: makeSyscallInfo("removexattr", Path, Path),
+ 15: makeSyscallInfo("lremovexattr", Path, Path),
+ 16: makeSyscallInfo("fremovexattr", FD, Path),
+ 17: makeSyscallInfo("getcwd", PostPath, Hex),
+ 18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
+ 19: makeSyscallInfo("eventfd2", Hex, Hex),
+ 20: makeSyscallInfo("epoll_create1", Hex),
+ 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 23: makeSyscallInfo("dup", FD),
+ 24: makeSyscallInfo("dup3", FD, FD, Hex),
+ 25: makeSyscallInfo("fcntl", FD, Hex, Hex),
+ 26: makeSyscallInfo("inotify_init1", Hex),
+ 27: makeSyscallInfo("inotify_add_watch", Hex, Path, Hex),
+ 28: makeSyscallInfo("inotify_rm_watch", Hex, Hex),
+ 29: makeSyscallInfo("ioctl", FD, Hex, Hex),
+ 30: makeSyscallInfo("ioprio_set", Hex, Hex, Hex),
+ 31: makeSyscallInfo("ioprio_get", Hex, Hex),
+ 32: makeSyscallInfo("flock", FD, Hex),
+ 33: makeSyscallInfo("mknodat", FD, Path, Mode, Hex),
+ 34: makeSyscallInfo("mkdirat", FD, Path, Hex),
+ 35: makeSyscallInfo("unlinkat", FD, Path, Hex),
+ 36: makeSyscallInfo("symlinkat", Path, Hex, Path),
+ 37: makeSyscallInfo("linkat", FD, Path, Hex, Path, Hex),
+ 38: makeSyscallInfo("renameat", FD, Path, Hex, Path),
+ 39: makeSyscallInfo("umount2", Path, Hex),
+ 40: makeSyscallInfo("mount", Path, Path, Path, Hex, Path),
+ 41: makeSyscallInfo("pivot_root", Path, Path),
+ 42: makeSyscallInfo("nfsservctl", Hex, Hex, Hex),
+ 43: makeSyscallInfo("statfs", Path, Hex),
+ 44: makeSyscallInfo("fstatfs", FD, Hex),
+ 45: makeSyscallInfo("truncate", Path, Hex),
+ 46: makeSyscallInfo("ftruncate", FD, Hex),
+ 47: makeSyscallInfo("fallocate", FD, Hex, Hex, Hex),
+ 48: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
+ 49: makeSyscallInfo("chdir", Path),
+ 50: makeSyscallInfo("fchdir", FD),
+ 51: makeSyscallInfo("chroot", Path),
+ 52: makeSyscallInfo("fchmod", FD, Mode),
+ 53: makeSyscallInfo("fchmodat", FD, Path, Mode),
+ 54: makeSyscallInfo("fchownat", FD, Path, Hex, Hex, Hex),
+ 55: makeSyscallInfo("fchown", FD, Hex, Hex),
+ 56: makeSyscallInfo("openat", FD, Path, OpenFlags, Mode),
+ 57: makeSyscallInfo("close", FD),
+ 58: makeSyscallInfo("vhangup"),
+ 59: makeSyscallInfo("pipe2", PipeFDs, Hex),
+ 60: makeSyscallInfo("quotactl", Hex, Hex, Hex, Hex),
+ 61: makeSyscallInfo("getdents64", FD, Hex, Hex),
+ 62: makeSyscallInfo("lseek", Hex, Hex, Hex),
+ 63: makeSyscallInfo("read", FD, ReadBuffer, Hex),
+ 64: makeSyscallInfo("write", FD, WriteBuffer, Hex),
+ 65: makeSyscallInfo("readv", FD, ReadIOVec, Hex),
+ 66: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
+ 67: makeSyscallInfo("pread64", FD, ReadBuffer, Hex, Hex),
+ 68: makeSyscallInfo("pwrite64", FD, WriteBuffer, Hex, Hex),
+ 69: makeSyscallInfo("preadv", FD, ReadIOVec, Hex, Hex),
+ 70: makeSyscallInfo("pwritev", FD, WriteIOVec, Hex, Hex),
+ 71: makeSyscallInfo("sendfile", FD, FD, Hex, Hex),
+ 72: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex),
+ 73: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
+ 74: makeSyscallInfo("signalfd4", Hex, Hex, Hex, Hex),
+ 75: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
+ 76: makeSyscallInfo("splice", FD, Hex, FD, Hex, Hex, Hex),
+ 77: makeSyscallInfo("tee", FD, FD, Hex, Hex),
+ 78: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
+ 79: makeSyscallInfo("fstatat", FD, Path, Stat, Hex),
+ 80: makeSyscallInfo("fstat", FD, Stat),
+ 81: makeSyscallInfo("sync"),
+ 82: makeSyscallInfo("fsync", FD),
+ 83: makeSyscallInfo("fdatasync", FD),
+ 84: makeSyscallInfo("sync_file_range", FD, Hex, Hex, Hex),
+ 85: makeSyscallInfo("timerfd_create", Hex, Hex),
+ 86: makeSyscallInfo("timerfd_settime", FD, Hex, ItimerSpec, PostItimerSpec),
+ 87: makeSyscallInfo("timerfd_gettime", FD, PostItimerSpec),
+ 88: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
+ 89: makeSyscallInfo("acct", Hex),
+ 90: makeSyscallInfo("capget", CapHeader, PostCapData),
+ 91: makeSyscallInfo("capset", CapHeader, CapData),
+ 92: makeSyscallInfo("personality", Hex),
+ 93: makeSyscallInfo("exit", Hex),
+ 94: makeSyscallInfo("exit_group", Hex),
+ 95: makeSyscallInfo("waitid", Hex, Hex, Hex, Hex, Rusage),
+ 96: makeSyscallInfo("set_tid_address", Hex),
+ 97: makeSyscallInfo("unshare", CloneFlags),
+ 98: makeSyscallInfo("futex", Hex, FutexOp, Hex, Timespec, Hex, Hex),
+ 99: makeSyscallInfo("set_robust_list", Hex, Hex),
+ 100: makeSyscallInfo("get_robust_list", Hex, Hex, Hex),
+ 101: makeSyscallInfo("nanosleep", Timespec, PostTimespec),
+ 102: makeSyscallInfo("getitimer", ItimerType, PostItimerVal),
+ 103: makeSyscallInfo("setitimer", ItimerType, ItimerVal, PostItimerVal),
+ 104: makeSyscallInfo("kexec_load", Hex, Hex, Hex, Hex),
+ 105: makeSyscallInfo("init_module", Hex, Hex, Hex),
+ 106: makeSyscallInfo("delete_module", Hex, Hex),
+ 107: makeSyscallInfo("timer_create", Hex, Hex, Hex),
+ 108: makeSyscallInfo("timer_gettime", Hex, PostItimerSpec),
+ 109: makeSyscallInfo("timer_getoverrun", Hex),
+ 110: makeSyscallInfo("timer_settime", Hex, Hex, ItimerSpec, PostItimerSpec),
+ 111: makeSyscallInfo("timer_delete", Hex),
+ 112: makeSyscallInfo("clock_settime", Hex, Timespec),
+ 113: makeSyscallInfo("clock_gettime", Hex, PostTimespec),
+ 114: makeSyscallInfo("clock_getres", Hex, PostTimespec),
+ 115: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
+ 116: makeSyscallInfo("syslog", Hex, Hex, Hex),
+ 117: makeSyscallInfo("ptrace", PtraceRequest, Hex, Hex, Hex),
+ 118: makeSyscallInfo("sched_setparam", Hex, Hex),
+ 119: makeSyscallInfo("sched_setscheduler", Hex, Hex, Hex),
+ 120: makeSyscallInfo("sched_getscheduler", Hex),
+ 121: makeSyscallInfo("sched_getparam", Hex, Hex),
+ 122: makeSyscallInfo("sched_setaffinity", Hex, Hex, Hex),
+ 123: makeSyscallInfo("sched_getaffinity", Hex, Hex, Hex),
+ 124: makeSyscallInfo("sched_yield"),
+ 125: makeSyscallInfo("sched_get_priority_max", Hex),
+ 126: makeSyscallInfo("sched_get_priority_min", Hex),
+ 127: makeSyscallInfo("sched_rr_get_interval", Hex, Hex),
+ 128: makeSyscallInfo("restart_syscall"),
+ 129: makeSyscallInfo("kill", Hex, Signal),
+ 130: makeSyscallInfo("tkill", Hex, Signal),
+ 131: makeSyscallInfo("tgkill", Hex, Hex, Signal),
+ 132: makeSyscallInfo("sigaltstack", Hex, Hex),
+ 133: makeSyscallInfo("rt_sigsuspend", Hex),
+ 134: makeSyscallInfo("rt_sigaction", Signal, SigAction, PostSigAction),
+ 135: makeSyscallInfo("rt_sigprocmask", SignalMaskAction, SigSet, PostSigSet, Hex),
+ 136: makeSyscallInfo("rt_sigpending", Hex),
+ 137: makeSyscallInfo("rt_sigtimedwait", SigSet, Hex, Timespec, Hex),
+ 138: makeSyscallInfo("rt_sigqueueinfo", Hex, Signal, Hex),
+ 139: makeSyscallInfo("rt_sigreturn"),
+ 140: makeSyscallInfo("setpriority", Hex, Hex, Hex),
+ 141: makeSyscallInfo("getpriority", Hex, Hex),
+ 142: makeSyscallInfo("reboot", Hex, Hex, Hex, Hex),
+ 143: makeSyscallInfo("setregid", Hex, Hex),
+ 144: makeSyscallInfo("setgid", Hex),
+ 145: makeSyscallInfo("setreuid", Hex, Hex),
+ 146: makeSyscallInfo("setuid", Hex),
+ 147: makeSyscallInfo("setresuid", Hex, Hex, Hex),
+ 148: makeSyscallInfo("getresuid", Hex, Hex, Hex),
+ 149: makeSyscallInfo("setresgid", Hex, Hex, Hex),
+ 150: makeSyscallInfo("getresgid", Hex, Hex, Hex),
+ 151: makeSyscallInfo("setfsuid", Hex),
+ 152: makeSyscallInfo("setfsgid", Hex),
+ 153: makeSyscallInfo("times", Hex),
+ 154: makeSyscallInfo("setpgid", Hex, Hex),
+ 155: makeSyscallInfo("getpgid", Hex),
+ 156: makeSyscallInfo("getsid", Hex),
+ 157: makeSyscallInfo("setsid"),
+ 158: makeSyscallInfo("getgroups", Hex, Hex),
+ 159: makeSyscallInfo("setgroups", Hex, Hex),
+ 160: makeSyscallInfo("uname", Uname),
+ 161: makeSyscallInfo("sethostname", Hex, Hex),
+ 162: makeSyscallInfo("setdomainname", Hex, Hex),
+ 163: makeSyscallInfo("getrlimit", Hex, Hex),
+ 164: makeSyscallInfo("setrlimit", Hex, Hex),
+ 165: makeSyscallInfo("getrusage", Hex, Rusage),
+ 166: makeSyscallInfo("umask", Hex),
+ 167: makeSyscallInfo("prctl", Hex, Hex, Hex, Hex, Hex),
+ 168: makeSyscallInfo("getcpu", Hex, Hex, Hex),
+ 169: makeSyscallInfo("gettimeofday", Timeval, Hex),
+ 170: makeSyscallInfo("settimeofday", Timeval, Hex),
+ 171: makeSyscallInfo("adjtimex", Hex),
+ 172: makeSyscallInfo("getpid"),
+ 173: makeSyscallInfo("getppid"),
+ 174: makeSyscallInfo("getuid"),
+ 175: makeSyscallInfo("geteuid"),
+ 176: makeSyscallInfo("getgid"),
+ 177: makeSyscallInfo("getegid"),
+ 178: makeSyscallInfo("gettid"),
+ 179: makeSyscallInfo("sysinfo", Hex),
+ 180: makeSyscallInfo("mq_open", Hex, Hex, Hex, Hex),
+ 181: makeSyscallInfo("mq_unlink", Hex),
+ 182: makeSyscallInfo("mq_timedsend", Hex, Hex, Hex, Hex, Hex),
+ 183: makeSyscallInfo("mq_timedreceive", Hex, Hex, Hex, Hex, Hex),
+ 184: makeSyscallInfo("mq_notify", Hex, Hex),
+ 185: makeSyscallInfo("mq_getsetattr", Hex, Hex, Hex),
+ 186: makeSyscallInfo("msgget", Hex, Hex),
+ 187: makeSyscallInfo("msgctl", Hex, Hex, Hex),
+ 188: makeSyscallInfo("msgrcv", Hex, Hex, Hex, Hex, Hex),
+ 189: makeSyscallInfo("msgsnd", Hex, Hex, Hex, Hex),
+ 190: makeSyscallInfo("semget", Hex, Hex, Hex),
+ 191: makeSyscallInfo("semctl", Hex, Hex, Hex, Hex),
+ 192: makeSyscallInfo("semtimedop", Hex, Hex, Hex, Hex),
+ 193: makeSyscallInfo("semop", Hex, Hex, Hex),
+ 194: makeSyscallInfo("shmget", Hex, Hex, Hex),
+ 195: makeSyscallInfo("shmctl", Hex, Hex, Hex),
+ 196: makeSyscallInfo("shmat", Hex, Hex, Hex),
+ 197: makeSyscallInfo("shmdt", Hex),
+ 198: makeSyscallInfo("socket", SockFamily, SockType, SockProtocol),
+ 199: makeSyscallInfo("socketpair", SockFamily, SockType, SockProtocol, Hex),
+ 200: makeSyscallInfo("bind", FD, SockAddr, Hex),
+ 201: makeSyscallInfo("listen", FD, Hex),
+ 202: makeSyscallInfo("accept", FD, PostSockAddr, SockLen),
+ 203: makeSyscallInfo("connect", FD, SockAddr, Hex),
+ 204: makeSyscallInfo("getsockname", FD, PostSockAddr, SockLen),
+ 205: makeSyscallInfo("getpeername", FD, PostSockAddr, SockLen),
+ 206: makeSyscallInfo("sendto", FD, Hex, Hex, Hex, SockAddr, Hex),
+ 207: makeSyscallInfo("recvfrom", FD, Hex, Hex, Hex, PostSockAddr, SockLen),
+ 208: makeSyscallInfo("setsockopt", FD, Hex, Hex, Hex, Hex),
+ 209: makeSyscallInfo("getsockopt", FD, Hex, Hex, Hex, Hex),
+ 210: makeSyscallInfo("shutdown", FD, Hex),
+ 211: makeSyscallInfo("sendmsg", FD, SendMsgHdr, Hex),
+ 212: makeSyscallInfo("recvmsg", FD, RecvMsgHdr, Hex),
+ 213: makeSyscallInfo("readahead", Hex, Hex, Hex),
+ 214: makeSyscallInfo("brk", Hex),
+ 215: makeSyscallInfo("munmap", Hex, Hex),
+ 216: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
+ 217: makeSyscallInfo("add_key", Hex, Hex, Hex, Hex, Hex),
+ 218: makeSyscallInfo("request_key", Hex, Hex, Hex, Hex),
+ 219: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex),
+ 220: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
+ 221: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector),
+ 222: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 223: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex),
+ 224: makeSyscallInfo("swapon", Hex, Hex),
+ 225: makeSyscallInfo("swapoff", Hex),
+ 226: makeSyscallInfo("mprotect", Hex, Hex, Hex),
+ 227: makeSyscallInfo("msync", Hex, Hex, Hex),
+ 228: makeSyscallInfo("mlock", Hex, Hex),
+ 229: makeSyscallInfo("munlock", Hex, Hex),
+ 230: makeSyscallInfo("mlockall", Hex),
+ 231: makeSyscallInfo("munlockall"),
+ 232: makeSyscallInfo("mincore", Hex, Hex, Hex),
+ 233: makeSyscallInfo("madvise", Hex, Hex, Hex),
+ 234: makeSyscallInfo("remap_file_pages", Hex, Hex, Hex, Hex, Hex),
+ 235: makeSyscallInfo("mbind", Hex, Hex, Hex, Hex, Hex, Hex),
+ 236: makeSyscallInfo("get_mempolicy", Hex, Hex, Hex, Hex, Hex),
+ 237: makeSyscallInfo("set_mempolicy", Hex, Hex, Hex),
+ 238: makeSyscallInfo("migrate_pages", Hex, Hex, Hex, Hex),
+ 239: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
+ 240: makeSyscallInfo("rt_tgsigqueueinfo", Hex, Hex, Signal, Hex),
+ 241: makeSyscallInfo("perf_event_open", Hex, Hex, Hex, Hex, Hex),
+ 242: makeSyscallInfo("accept4", FD, PostSockAddr, SockLen, SockFlags),
+ 243: makeSyscallInfo("recvmmsg", FD, Hex, Hex, Hex, Hex),
+
+ 260: makeSyscallInfo("wait4", Hex, Hex, Hex, Rusage),
+ 261: makeSyscallInfo("prlimit64", Hex, Hex, Hex, Hex),
+ 262: makeSyscallInfo("fanotify_init", Hex, Hex),
+ 263: makeSyscallInfo("fanotify_mark", Hex, Hex, Hex, Hex, Hex),
+ 264: makeSyscallInfo("name_to_handle_at", FD, Hex, Hex, Hex, Hex),
+ 265: makeSyscallInfo("open_by_handle_at", FD, Hex, Hex),
+ 266: makeSyscallInfo("clock_adjtime", Hex, Hex),
+ 267: makeSyscallInfo("syncfs", FD),
+ 268: makeSyscallInfo("setns", FD, Hex),
+ 269: makeSyscallInfo("sendmmsg", FD, Hex, Hex, Hex),
+ 270: makeSyscallInfo("process_vm_readv", Hex, ReadIOVec, Hex, IOVec, Hex, Hex),
+ 271: makeSyscallInfo("process_vm_writev", Hex, IOVec, Hex, WriteIOVec, Hex, Hex),
+ 272: makeSyscallInfo("kcmp", Hex, Hex, Hex, Hex, Hex),
+ 273: makeSyscallInfo("finit_module", Hex, Hex, Hex),
+ 274: makeSyscallInfo("sched_setattr", Hex, Hex, Hex),
+ 275: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
+ 276: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
+ 277: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 278: makeSyscallInfo("getrandom", Hex, Hex, Hex),
+ 279: makeSyscallInfo("memfd_create", Path, Hex),
+ 280: makeSyscallInfo("bpf", Hex, Hex, Hex),
+ 281: makeSyscallInfo("execveat", FD, Path, Hex, Hex, Hex),
+ 282: makeSyscallInfo("userfaultfd", Hex),
+ 283: makeSyscallInfo("membarrier", Hex),
+ 284: makeSyscallInfo("mlock2", Hex, Hex, Hex),
+ 285: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex),
+ 286: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex),
+ 287: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex),
+ 291: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
+ 292: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet),
+ 293: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex),
+ 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex),
+ 425: makeSyscallInfo("io_uring_setup", Hex, Hex),
+ 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex),
+ 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex),
+ 428: makeSyscallInfo("open_tree", FD, Path, Hex),
+ 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex),
+ 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close.
+ 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex),
+ 432: makeSyscallInfo("fsmount", FD, Hex, Hex),
+ 433: makeSyscallInfo("fspick", FD, Path, Hex),
+ 434: makeSyscallInfo("pidfd_open", Hex, Hex),
+ 435: makeSyscallInfo("clone3", Hex, Hex),
+}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.ARM64,
+ syscalls: linuxARM64})
+}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 51f2efb39..b6d7177f4 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, _, err := netstack.AddressAndFamily(int(family), b, true /* strict */)
+ fa, _, err := netstack.AddressAndFamily(b)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index e5d486c4e..24e29a2ba 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -250,14 +250,7 @@ type syscallTable struct {
syscalls SyscallMap
}
-// syscallTables contains all syscall tables.
-var syscallTables = []syscallTable{
- {
- os: abi.Linux,
- arch: arch.AMD64,
- syscalls: linuxAMD64,
- },
-}
+var syscallTables []syscallTable
// Lookup returns the SyscallMap for the OS/Arch combination. The returned map
// must not be changed.
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 6766ba587..430d796ba 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -13,6 +13,8 @@ go_library(
"sigset.go",
"sys_aio.go",
"sys_capability.go",
+ "sys_clone_amd64.go",
+ "sys_clone_arm64.go",
"sys_epoll.go",
"sys_eventfd.go",
"sys_file.go",
@@ -30,6 +32,7 @@ go_library(
"sys_random.go",
"sys_read.go",
"sys_rlimit.go",
+ "sys_rseq.go",
"sys_rusage.go",
"sys_sched.go",
"sys_seccomp.go",
@@ -90,6 +93,7 @@ go_library(
"//pkg/sentry/syscalls",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/waiter",
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index 1d9018c96..60469549d 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -16,13 +16,13 @@ package linux
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
index 272ae9991..6b2920900 100644
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -228,10 +228,10 @@ var AMD64 = &kernel.SyscallTable{
185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
186: syscalls.Supported("gettid", Gettid),
187: syscalls.Supported("readahead", Readahead),
- 188: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil),
+ 188: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil),
189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 191: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil),
+ 191: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
@@ -377,7 +377,7 @@ var AMD64 = &kernel.SyscallTable{
331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
332: syscalls.Supported("statx", Statx),
333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
- 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+ 334: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
// Linux skips ahead to syscall 424 to sync numbers between arches.
424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
index 3b584eed9..8c1b20911 100644
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -41,10 +41,10 @@ var ARM64 = &kernel.SyscallTable{
2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 5: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil),
+ 5: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil),
6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 8: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil),
+ 8: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
@@ -307,7 +307,7 @@ var ARM64 = &kernel.SyscallTable{
290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
291: syscalls.Supported("statx", Statx),
292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
- 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+ 293: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
// Linux skips ahead to syscall 424 to sync numbers between arches.
424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_clone_amd64.go b/pkg/sentry/syscalls/linux/sys_clone_amd64.go
new file mode 100644
index 000000000..dd43cf18d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_clone_amd64.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Clone implements linux syscall clone(2).
+// sys_clone has so many flavors. We implement the default one in linux 3.11
+// x86_64:
+// sys_clone(clone_flags, newsp, parent_tidptr, child_tidptr, tls_val)
+func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+ stack := args[1].Pointer()
+ parentTID := args[2].Pointer()
+ childTID := args[3].Pointer()
+ tls := args[4].Pointer()
+ return clone(t, flags, stack, parentTID, childTID, tls)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_clone_arm64.go b/pkg/sentry/syscalls/linux/sys_clone_arm64.go
new file mode 100644
index 000000000..cf68a8949
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_clone_arm64.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Clone implements linux syscall clone(2).
+// sys_clone has so many flavors, and we implement the default one in linux 3.11
+// arm64(kernel/fork.c with CONFIG_CLONE_BACKWARDS defined in the config file):
+// sys_clone(clone_flags, newsp, parent_tidptr, tls_val, child_tidptr)
+func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+ stack := args[1].Pointer()
+ parentTID := args[2].Pointer()
+ tls := args[3].Pointer()
+ childTID := args[4].Pointer()
+ return clone(t, flags, stack, parentTID, childTID, tls)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go
new file mode 100644
index 000000000..90db10ea6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_rseq.go
@@ -0,0 +1,48 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// RSeq implements syscall rseq(2).
+func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Uint()
+ flags := args[2].Int()
+ signature := args[3].Uint()
+
+ if !t.RSeqAvailable() {
+ // Event for applications that want rseq on a configuration
+ // that doesn't support them.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+ switch flags {
+ case 0:
+ // Register.
+ return 0, nil, t.SetRSeq(addr, length, signature)
+ case linux.RSEQ_FLAG_UNREGISTER:
+ return 0, nil, t.ClearRSeq(addr, length, signature)
+ default:
+ // Unknown flag.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index d57ffb3a1..4a8bc24a2 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -39,10 +39,13 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
+ defer segment.DecRef()
return uintptr(segment.ID), nil, nil
}
// findSegment retrives a shm segment by the given id.
+//
+// findSegment returns a reference on Shm.
func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
r := t.IPCNamespace().ShmRegistry()
segment := r.FindByID(id)
@@ -63,6 +66,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef()
opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{
Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC,
@@ -72,7 +76,6 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- defer segment.DecRef()
addr, err = t.MemoryManager().MMap(t, opts)
return uintptr(addr), nil, err
}
@@ -105,6 +108,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef()
stat, err := segment.IPCStat(t)
if err == nil {
@@ -128,6 +132,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef()
switch cmd {
case linux.IPC_SET:
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 4b5aafcc0..cda517a81 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -41,7 +41,7 @@ const maxListenBacklog = 1024
const maxAddrLen = 200
// maxOptLen is the maximum sockopt parameter length we're willing to accept.
-const maxOptLen = 1024
+const maxOptLen = 1024 * 8
// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
// willing to accept. Note that this limit is smaller than Linux, which allows
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 4115116ff..b47c3b5c4 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -220,19 +220,6 @@ func clone(t *kernel.Task, flags int, stack usermem.Addr, parentTID usermem.Addr
return uintptr(ntid), ctrl, err
}
-// Clone implements linux syscall clone(2).
-// sys_clone has so many flavors. We implement the default one in linux 3.11
-// x86_64:
-// sys_clone(clone_flags, newsp, parent_tidptr, child_tidptr, tls_val)
-func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- flags := int(args[0].Int())
- stack := args[1].Pointer()
- parentTID := args[2].Pointer()
- childTID := args[3].Pointer()
- tls := args[4].Pointer()
- return clone(t, flags, stack, parentTID, childTID, tls)
-}
-
// Fork implements Linux syscall fork(2).
func Fork(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
// "A call to fork() is equivalent to a call to clone(2) specifying flags
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index 97d9a65ea..23d20da6f 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -25,12 +25,12 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-// Getxattr implements linux syscall getxattr(2).
-func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+// GetXattr implements linux syscall getxattr(2).
+func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
pathAddr := args[0].Pointer()
nameAddr := args[1].Pointer()
valueAddr := args[2].Pointer()
- size := args[3].SizeT()
+ size := uint64(args[3].SizeT())
path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
if err != nil {
@@ -39,22 +39,28 @@ func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
valueLen := 0
err = fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
- value, err := getxattr(t, d, dirPath, nameAddr)
+ // If getxattr(2) is called with size 0, the size of the value will be
+ // returned successfully even if it is nonzero. In that case, we need to
+ // retrieve the entire attribute value so we can return the correct size.
+ requestedSize := size
+ if size == 0 || size > linux.XATTR_SIZE_MAX {
+ requestedSize = linux.XATTR_SIZE_MAX
+ }
+
+ value, err := getXattr(t, d, dirPath, nameAddr, uint64(requestedSize))
if err != nil {
return err
}
valueLen = len(value)
- if size == 0 {
- return nil
- }
- if size > linux.XATTR_SIZE_MAX {
- size = linux.XATTR_SIZE_MAX
- }
- if valueLen > int(size) {
+ if uint64(valueLen) > requestedSize {
return syserror.ERANGE
}
+ // Skip copying out the attribute value if size is 0.
+ if size == 0 {
+ return nil
+ }
_, err = t.CopyOutBytes(valueAddr, []byte(value))
return err
})
@@ -64,8 +70,8 @@ func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return uintptr(valueLen), nil, nil
}
-// getxattr implements getxattr from the given *fs.Dirent.
-func getxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr) (string, error) {
+// getXattr implements getxattr(2) from the given *fs.Dirent.
+func getXattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr, size uint64) (string, error) {
if dirPath && !fs.IsDir(d.Inode.StableAttr) {
return "", syserror.ENOTDIR
}
@@ -83,15 +89,15 @@ func getxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr)
return "", syserror.EOPNOTSUPP
}
- return d.Inode.Getxattr(name)
+ return d.Inode.GetXattr(t, name, size)
}
-// Setxattr implements linux syscall setxattr(2).
-func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+// SetXattr implements linux syscall setxattr(2).
+func SetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
pathAddr := args[0].Pointer()
nameAddr := args[1].Pointer()
valueAddr := args[2].Pointer()
- size := args[3].SizeT()
+ size := uint64(args[3].SizeT())
flags := args[4].Uint()
path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
@@ -104,12 +110,12 @@ func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
- return setxattr(t, d, dirPath, nameAddr, valueAddr, size, flags)
+ return setXattr(t, d, dirPath, nameAddr, valueAddr, uint64(size), flags)
})
}
-// setxattr implements setxattr from the given *fs.Dirent.
-func setxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr usermem.Addr, size uint, flags uint32) error {
+// setXattr implements setxattr(2) from the given *fs.Dirent.
+func setXattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr usermem.Addr, size uint64, flags uint32) error {
if dirPath && !fs.IsDir(d.Inode.StableAttr) {
return syserror.ENOTDIR
}
@@ -136,7 +142,7 @@ func setxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr us
return syserror.EOPNOTSUPP
}
- return d.Inode.Setxattr(name, value)
+ return d.Inode.SetXattr(t, d, name, value, flags)
}
func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 18e212dff..3cde3a0be 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "seqatomic_parameters_unsafe.go",
package = "time",
suffix = "Parameters",
- template = "//pkg/syncutil:generic_seqatomic",
+ template = "//pkg/sync:generic_seqatomic",
types = {
"Value": "Parameters",
},
@@ -36,7 +36,7 @@ go_library(
deps = [
"//pkg/log",
"//pkg/metric",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go
index 318503277..f9a93115d 100644
--- a/pkg/sentry/time/calibrated_clock.go
+++ b/pkg/sentry/time/calibrated_clock.go
@@ -17,11 +17,11 @@
package time
import (
- "sync"
"time"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD
index c32fe3241..5518ac3d0 100644
--- a/pkg/sentry/usage/BUILD
+++ b/pkg/sentry/usage/BUILD
@@ -18,5 +18,6 @@ go_library(
deps = [
"//pkg/bits",
"//pkg/memutil",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index d6ef644d8..538c645eb 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -17,12 +17,12 @@ package usage
import (
"fmt"
"os"
- "sync"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// MemoryKind represents a type of memory used by the application.
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index e3e554b88..35c7be259 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -9,6 +9,7 @@ go_library(
"context.go",
"debug.go",
"dentry.go",
+ "device.go",
"file_description.go",
"file_description_impl_util.go",
"filesystem.go",
@@ -33,7 +34,7 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
"//pkg/sentry/usermem",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
@@ -53,6 +54,7 @@ go_test(
"//pkg/sentry/context/contexttest",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 1bc9c4a38..486a76475 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -16,9 +16,9 @@ package vfs
import (
"fmt"
- "sync"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go
new file mode 100644
index 000000000..cb672e36f
--- /dev/null
+++ b/pkg/sentry/vfs/device.go
@@ -0,0 +1,100 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// DeviceKind indicates whether a device is a block or character device.
+type DeviceKind uint32
+
+const (
+ // BlockDevice indicates a block device.
+ BlockDevice DeviceKind = iota
+
+ // CharDevice indicates a character device.
+ CharDevice
+)
+
+// String implements fmt.Stringer.String.
+func (kind DeviceKind) String() string {
+ switch kind {
+ case BlockDevice:
+ return "block"
+ case CharDevice:
+ return "character"
+ default:
+ return fmt.Sprintf("invalid device kind %d", kind)
+ }
+}
+
+type devTuple struct {
+ kind DeviceKind
+ major uint32
+ minor uint32
+}
+
+// A Device backs device special files.
+type Device interface {
+ // Open returns a FileDescription representing this device.
+ Open(ctx context.Context, mnt *Mount, d *Dentry, opts OpenOptions) (*FileDescription, error)
+}
+
+type registeredDevice struct {
+ dev Device
+ opts RegisterDeviceOptions
+}
+
+// RegisterDeviceOptions contains options to
+// VirtualFilesystem.RegisterDevice().
+type RegisterDeviceOptions struct {
+ // GroupName is the name shown for this device registration in
+ // /proc/devices. If GroupName is empty, this registration will not be
+ // shown in /proc/devices.
+ GroupName string
+}
+
+// RegisterDevice registers the given Device in vfs with the given major and
+// minor device numbers.
+func (vfs *VirtualFilesystem) RegisterDevice(kind DeviceKind, major, minor uint32, dev Device, opts *RegisterDeviceOptions) error {
+ tup := devTuple{kind, major, minor}
+ vfs.devicesMu.Lock()
+ defer vfs.devicesMu.Unlock()
+ if existing, ok := vfs.devices[tup]; ok {
+ return fmt.Errorf("%s device number (%d, %d) is already registered to device type %T", kind, major, minor, existing.dev)
+ }
+ vfs.devices[tup] = &registeredDevice{
+ dev: dev,
+ opts: *opts,
+ }
+ return nil
+}
+
+// OpenDeviceSpecialFile returns a FileDescription representing the given
+// device.
+func (vfs *VirtualFilesystem) OpenDeviceSpecialFile(ctx context.Context, mnt *Mount, d *Dentry, kind DeviceKind, major, minor uint32, opts *OpenOptions) (*FileDescription, error) {
+ tup := devTuple{kind, major, minor}
+ vfs.devicesMu.RLock()
+ defer vfs.devicesMu.RUnlock()
+ rd, ok := vfs.devices[tup]
+ if !ok {
+ return nil, syserror.ENXIO
+ }
+ return rd.dev.Open(ctx, mnt, d, *opts)
+}
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 0b053201a..6afe280bc 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -61,11 +61,25 @@ type FileDescriptionOptions struct {
// If AllowDirectIO is true, allow O_DIRECT to be set on the file. This is
// usually only the case if O_DIRECT would actually have an effect.
AllowDirectIO bool
+
+ // If UseDentryMetadata is true, calls to FileDescription methods that
+ // interact with file and filesystem metadata (Stat, SetStat, StatFS,
+ // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling
+ // the corresponding FilesystemImpl methods instead of the corresponding
+ // FileDescriptionImpl methods.
+ //
+ // UseDentryMetadata is intended for file descriptions that are implemented
+ // outside of individual filesystems, such as pipes, sockets, and device
+ // special files. FileDescriptions for which UseDentryMetadata is true may
+ // embed DentryMetadataFileDescriptionImpl to obtain appropriate
+ // implementations of FileDescriptionImpl methods that should not be
+ // called.
+ UseDentryMetadata bool
}
-// Init must be called before first use of fd. It takes ownership of references
-// on mnt and d held by the caller. statusFlags is the initial file description
-// status flags, which is usually the full set of flags passed to open(2).
+// Init must be called before first use of fd. It takes references on mnt and
+// d. statusFlags is the initial file description status flags, which is
+// usually the full set of flags passed to open(2).
func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) {
fd.refs = 1
fd.statusFlags = statusFlags | linux.O_LARGEFILE
@@ -73,6 +87,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mn
mount: mnt,
dentry: d,
}
+ fd.vd.IncRef()
fd.opts = *opts
fd.impl = impl
}
@@ -140,7 +155,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede
// sense. However, the check as actually implemented seems to be "O_APPEND
// cannot be changed if the file is marked as append-only".
if (flags^oldFlags)&linux.O_APPEND != 0 {
- stat, err := fd.impl.Stat(ctx, StatOptions{
+ stat, err := fd.Stat(ctx, StatOptions{
// There is no mask bit for stx_attributes.
Mask: 0,
// Linux just reads inode::i_flags directly.
@@ -154,7 +169,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede
}
}
if (flags&linux.O_NOATIME != 0) && (oldFlags&linux.O_NOATIME == 0) {
- stat, err := fd.impl.Stat(ctx, StatOptions{
+ stat, err := fd.Stat(ctx, StatOptions{
Mask: linux.STATX_UID,
// Linux's inode_owner_or_capable() just reads inode::i_uid
// directly.
@@ -348,17 +363,47 @@ func (fd *FileDescription) OnClose(ctx context.Context) error {
// Stat returns metadata for the file represented by fd.
func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(rp)
+ return stat, err
+ }
return fd.impl.Stat(ctx, opts)
}
// SetStat updates metadata for the file represented by fd.
func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(rp)
+ return err
+ }
return fd.impl.SetStat(ctx, opts)
}
// StatFS returns metadata for the filesystem containing the file represented
// by fd.
func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp)
+ vfsObj.putResolvingPath(rp)
+ return statfs, err
+ }
return fd.impl.StatFS(ctx)
}
@@ -417,6 +462,16 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
// Listxattr returns all extended attribute names for the file represented by
// fd.
func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp)
+ vfsObj.putResolvingPath(rp)
+ return names, err
+ }
names, err := fd.impl.Listxattr(ctx)
if err == syserror.ENOTSUP {
// Linux doesn't actually return ENOTSUP in this case; instead,
@@ -431,18 +486,48 @@ func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) {
// Getxattr returns the value associated with the given extended attribute for
// the file represented by fd.
func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, name)
+ vfsObj.putResolvingPath(rp)
+ return val, err
+ }
return fd.impl.Getxattr(ctx, name)
}
// Setxattr changes the value associated with the given extended attribute for
// the file represented by fd.
func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(rp)
+ return err
+ }
return fd.impl.Setxattr(ctx, opts)
}
// Removexattr removes the given extended attribute from the file represented
// by fd.
func (fd *FileDescription) Removexattr(ctx context.Context, name string) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ vfsObj.putResolvingPath(rp)
+ return err
+ }
return fd.impl.Removexattr(ctx, name)
}
@@ -464,7 +549,7 @@ func (fd *FileDescription) MappedName(ctx context.Context) string {
// DeviceID implements memmap.MappingIdentity.DeviceID.
func (fd *FileDescription) DeviceID() uint64 {
- stat, err := fd.impl.Stat(context.Background(), StatOptions{
+ stat, err := fd.Stat(context.Background(), StatOptions{
// There is no STATX_DEV; we assume that Stat will return it if it's
// available regardless of mask.
Mask: 0,
@@ -480,7 +565,7 @@ func (fd *FileDescription) DeviceID() uint64 {
// InodeID implements memmap.MappingIdentity.InodeID.
func (fd *FileDescription) InodeID() uint64 {
- stat, err := fd.impl.Stat(context.Background(), StatOptions{
+ stat, err := fd.Stat(context.Background(), StatOptions{
Mask: linux.STATX_INO,
// fs/proc/task_mmu.c:show_map_vma() just reads inode::i_ino directly.
Sync: linux.AT_STATX_DONT_SYNC,
@@ -493,5 +578,5 @@ func (fd *FileDescription) InodeID() uint64 {
// Msync implements memmap.MappingIdentity.Msync.
func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error {
- return fd.impl.Sync(ctx)
+ return fd.Sync(ctx)
}
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 3df49991c..c00b3c84b 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -17,13 +17,13 @@ package vfs
import (
"bytes"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -177,6 +177,21 @@ func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src userme
return 0, syserror.EISDIR
}
+// DentryMetadataFileDescriptionImpl may be embedded by implementations of
+// FileDescriptionImpl for which FileDescriptionOptions.UseDentryMetadata is
+// true to obtain implementations of Stat and SetStat that panic.
+type DentryMetadataFileDescriptionImpl struct{}
+
+// Stat implements FileDescriptionImpl.Stat.
+func (DentryMetadataFileDescriptionImpl) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ panic("illegal call to DentryMetadataFileDescriptionImpl.Stat")
+}
+
+// SetStat implements FileDescriptionImpl.SetStat.
+func (DentryMetadataFileDescriptionImpl) SetStat(ctx context.Context, opts SetStatOptions) error {
+ panic("illegal call to DentryMetadataFileDescriptionImpl.SetStat")
+}
+
// DynamicBytesFileDescriptionImpl may be embedded by implementations of
// FileDescriptionImpl that represent read-only regular files whose contents
// are backed by a bytes.Buffer that is regenerated when necessary, consistent
@@ -199,6 +214,17 @@ type DynamicBytesSource interface {
Generate(ctx context.Context, buf *bytes.Buffer) error
}
+// StaticData implements DynamicBytesSource over a static string.
+type StaticData struct {
+ Data string
+}
+
+// Generate implements DynamicBytesSource.
+func (s *StaticData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(s.Data)
+ return nil
+}
+
// SetDataSource must be called exactly once on fd before first use.
func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) {
fd.data = data
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 678be07fe..9ed58512f 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -89,7 +89,7 @@ func TestGenCountFD(t *testing.T) {
creds := auth.CredentialsFromContext(ctx)
vfsObj := New() // vfs.New()
- vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{})
+ vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{}, &RegisterFilesystemTypeOptions{})
mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &GetFilesystemOptions{})
if err != nil {
t.Fatalf("failed to create testfs root mount: %v", err)
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 89bd58864..ea78f555b 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -418,17 +418,38 @@ type FilesystemImpl interface {
UnlinkAt(ctx context.Context, rp *ResolvingPath) error
// ListxattrAt returns all extended attribute names for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem,
+ // ListxattrAt returns nil. (See FileDescription.Listxattr for an
+ // explanation.)
ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error)
// GetxattrAt returns the value associated with the given extended
// attribute for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem, GetxattrAt
+ // returns ENOTSUP.
GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error)
// SetxattrAt changes the value associated with the given extended
// attribute for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem, SetxattrAt
+ // returns ENOTSUP.
SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error
// RemovexattrAt removes the given extended attribute from the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem,
+ // RemovexattrAt returns ENOTSUP.
RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error
// PrependPath prepends a path from vd to vd.Mount().Root() to b.
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index c335e206d..023301780 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -15,6 +15,7 @@
package vfs
import (
+ "bytes"
"fmt"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -43,28 +44,70 @@ type GetFilesystemOptions struct {
InternalData interface{}
}
+type registeredFilesystemType struct {
+ fsType FilesystemType
+ opts RegisterFilesystemTypeOptions
+}
+
+// RegisterFilesystemTypeOptions contains options to
+// VirtualFilesystem.RegisterFilesystem().
+type RegisterFilesystemTypeOptions struct {
+ // If AllowUserMount is true, allow calls to VirtualFilesystem.MountAt()
+ // for which MountOptions.InternalMount == false to use this filesystem
+ // type.
+ AllowUserMount bool
+
+ // If AllowUserList is true, make this filesystem type visible in
+ // /proc/filesystems.
+ AllowUserList bool
+
+ // If RequiresDevice is true, indicate that mounting this filesystem
+ // requires a block device as the mount source in /proc/filesystems.
+ RequiresDevice bool
+}
+
// RegisterFilesystemType registers the given FilesystemType in vfs with the
// given name.
-func (vfs *VirtualFilesystem) RegisterFilesystemType(name string, fsType FilesystemType) error {
+func (vfs *VirtualFilesystem) RegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) error {
vfs.fsTypesMu.Lock()
defer vfs.fsTypesMu.Unlock()
if existing, ok := vfs.fsTypes[name]; ok {
- return fmt.Errorf("name %q is already registered to filesystem type %T", name, existing)
+ return fmt.Errorf("name %q is already registered to filesystem type %T", name, existing.fsType)
+ }
+ vfs.fsTypes[name] = &registeredFilesystemType{
+ fsType: fsType,
+ opts: *opts,
}
- vfs.fsTypes[name] = fsType
return nil
}
// MustRegisterFilesystemType is equivalent to RegisterFilesystemType but
// panics on failure.
-func (vfs *VirtualFilesystem) MustRegisterFilesystemType(name string, fsType FilesystemType) {
- if err := vfs.RegisterFilesystemType(name, fsType); err != nil {
+func (vfs *VirtualFilesystem) MustRegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) {
+ if err := vfs.RegisterFilesystemType(name, fsType, opts); err != nil {
panic(fmt.Sprintf("failed to register filesystem type %T: %v", fsType, err))
}
}
-func (vfs *VirtualFilesystem) getFilesystemType(name string) FilesystemType {
+func (vfs *VirtualFilesystem) getFilesystemType(name string) *registeredFilesystemType {
vfs.fsTypesMu.RLock()
defer vfs.fsTypesMu.RUnlock()
return vfs.fsTypes[name]
}
+
+// GenerateProcFilesystems emits the contents of /proc/filesystems for vfs to
+// buf.
+func (vfs *VirtualFilesystem) GenerateProcFilesystems(buf *bytes.Buffer) {
+ vfs.fsTypesMu.RLock()
+ defer vfs.fsTypesMu.RUnlock()
+ for name, rft := range vfs.fsTypes {
+ if !rft.opts.AllowUserList {
+ continue
+ }
+ var nodev string
+ if !rft.opts.RequiresDevice {
+ nodev = "nodev"
+ }
+ fmt.Fprintf(buf, "%s\t%s\n", nodev, name)
+ }
+}
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index ec23ab0dd..00177b371 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -112,11 +112,11 @@ type MountNamespace struct {
// configured by the given arguments. A reference is taken on the returned
// MountNamespace.
func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) {
- fsType := vfs.getFilesystemType(fsTypeName)
- if fsType == nil {
+ rft := vfs.getFilesystemType(fsTypeName)
+ if rft == nil {
return nil, syserror.ENODEV
}
- fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
+ fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
if err != nil {
return nil, err
}
@@ -136,11 +136,14 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
// MountAt creates and mounts a Filesystem configured by the given arguments.
func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
- fsType := vfs.getFilesystemType(fsTypeName)
- if fsType == nil {
+ rft := vfs.getFilesystemType(fsTypeName)
+ if rft == nil {
return syserror.ENODEV
}
- fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
+ if !opts.InternalMount && !rft.opts.AllowUserMount {
+ return syserror.ENODEV
+ }
+ fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
if err != nil {
return err
}
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index adff0b94b..3b933468d 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -17,8 +17,9 @@ package vfs
import (
"fmt"
"runtime"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestMountTableLookupEmpty(t *testing.T) {
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index ab13fa461..bd90d36c4 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -26,7 +26,7 @@ import (
"sync/atomic"
"unsafe"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// mountKey represents the location at which a Mount is mounted. It is
@@ -75,7 +75,7 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq syncutil.SeqCount
+ seq sync.SeqCount
seed uint32 // for hashing keys
// size holds both length (number of elements) and capacity (number of
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 87d2b0d1c..b7774bf28 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -50,6 +50,10 @@ type MknodOptions struct {
type MountOptions struct {
// GetFilesystemOptions contains options to FilesystemType.GetFilesystem().
GetFilesystemOptions GetFilesystemOptions
+
+ // If InternalMount is true, allow the use of filesystem types for which
+ // RegisterFilesystemTypeOptions.AllowUserMount == false.
+ InternalMount bool
}
// OpenOptions contains options to VirtualFilesystem.OpenAt() and
diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go
index 8e155654f..cf80df90e 100644
--- a/pkg/sentry/vfs/pathname.go
+++ b/pkg/sentry/vfs/pathname.go
@@ -15,10 +15,9 @@
package vfs
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index f1edb0680..d279d05ca 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -30,6 +30,26 @@ const (
MayExec = 1
)
+// OnlyRead returns true if access _only_ allows read.
+func (a AccessTypes) OnlyRead() bool {
+ return a == MayRead
+}
+
+// MayRead returns true if access allows read.
+func (a AccessTypes) MayRead() bool {
+ return a&MayRead != 0
+}
+
+// MayWrite returns true if access allows write.
+func (a AccessTypes) MayWrite() bool {
+ return a&MayWrite != 0
+}
+
+// MayExec returns true if access allows exec.
+func (a AccessTypes) MayExec() bool {
+ return a&MayExec != 0
+}
+
// GenericCheckPermissions checks that creds has the given access rights on a
// file with the given permissions, UID, and GID, subject to the rules of
// fs/namei.c:generic_permission(). isDir is true if the file is a directory.
@@ -53,7 +73,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
}
// CAP_DAC_READ_SEARCH allows the caller to read and search arbitrary
// directories, and read arbitrary non-directory files.
- if (isDir && (ats&MayWrite == 0)) || ats == MayRead {
+ if (isDir && !ats.MayWrite()) || ats.OnlyRead() {
if creds.HasCapability(linux.CAP_DAC_READ_SEARCH) {
return nil
}
@@ -61,7 +81,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
// CAP_DAC_OVERRIDE allows arbitrary access to directories, read/write
// access to non-directory files, and execute access to non-directory files
// for which at least one execute bit is set.
- if isDir || (ats&MayExec == 0) || (mode&0111 != 0) {
+ if isDir || !ats.MayExec() || (mode&0111 != 0) {
if creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
return nil
}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index f0641d314..8a0b382f6 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -16,11 +16,11 @@ package vfs
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 3e4df8558..1f21b0b31 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -29,12 +29,12 @@ package vfs
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -75,23 +75,29 @@ type VirtualFilesystem struct {
// mountpoints is analogous to Linux's mountpoint_hashtable.
mountpoints map[*Dentry]map[*Mount]struct{}
+ // devices contains all registered Devices. devices is protected by
+ // devicesMu.
+ devicesMu sync.RWMutex
+ devices map[devTuple]*registeredDevice
+
+ // fsTypes contains all registered FilesystemTypes. fsTypes is protected by
+ // fsTypesMu.
+ fsTypesMu sync.RWMutex
+ fsTypes map[string]*registeredFilesystemType
+
// filesystems contains all Filesystems. filesystems is protected by
// filesystemsMu.
filesystemsMu sync.Mutex
filesystems map[*Filesystem]struct{}
-
- // fsTypes contains all FilesystemTypes that are usable in the
- // VirtualFilesystem. fsTypes is protected by fsTypesMu.
- fsTypesMu sync.RWMutex
- fsTypes map[string]FilesystemType
}
// New returns a new VirtualFilesystem with no mounts or FilesystemTypes.
func New() *VirtualFilesystem {
vfs := &VirtualFilesystem{
mountpoints: make(map[*Dentry]map[*Mount]struct{}),
+ devices: make(map[devTuple]*registeredDevice),
+ fsTypes: make(map[string]*registeredFilesystemType),
filesystems: make(map[*Filesystem]struct{}),
- fsTypes: make(map[string]FilesystemType),
}
vfs.mounts.Init()
return vfs
diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD
index 4d8435265..28f21f13d 100644
--- a/pkg/sentry/watchdog/BUILD
+++ b/pkg/sentry/watchdog/BUILD
@@ -13,5 +13,6 @@ go_library(
"//pkg/metric",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 5e4611333..bfb2fac26 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -32,7 +32,6 @@ package watchdog
import (
"bytes"
"fmt"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -40,6 +39,7 @@ import (
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Opts configures the watchdog.
diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go
index 130806c86..af47e2ba1 100644
--- a/pkg/sleep/sleep_test.go
+++ b/pkg/sleep/sleep_test.go
@@ -376,6 +376,37 @@ func TestRace(t *testing.T) {
}
}
+// TestRaceInOrder tests that multiple wakers can continuously send wake requests to
+// the sleeper and that the wakers are retrieved in the order asserted.
+func TestRaceInOrder(t *testing.T) {
+ const wakers = 100
+ const wakeRequests = 10000
+
+ w := make([]Waker, wakers)
+ s := Sleeper{}
+
+ // Associate each waker and start goroutines that will assert them.
+ for i := range w {
+ s.AddWaker(&w[i], i)
+ }
+ go func() {
+ n := 0
+ for n < wakeRequests {
+ wk := w[n%len(w)]
+ wk.Assert()
+ n++
+ }
+ }()
+
+ // Wait for all wake up notifications from all wakers.
+ for i := 0; i < wakeRequests; i++ {
+ v, _ := s.Fetch(true)
+ if got, want := v, i%wakers; got != want {
+ t.Fatalf("got %d want %d", got, want)
+ }
+ }
+}
+
// BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up
// from 4 wakers when at least one is already asserted.
func BenchmarkSleeperMultiSelect(b *testing.B) {
diff --git a/pkg/syncutil/BUILD b/pkg/sync/BUILD
index cb1f41628..e8cd16b8f 100644
--- a/pkg/syncutil/BUILD
+++ b/pkg/sync/BUILD
@@ -29,8 +29,9 @@ go_template(
)
go_library(
- name = "syncutil",
+ name = "sync",
srcs = [
+ "aliases.go",
"downgradable_rwmutex_unsafe.go",
"memmove_unsafe.go",
"norace_unsafe.go",
@@ -38,15 +39,15 @@ go_library(
"seqcount.go",
"syncutil.go",
],
- importpath = "gvisor.dev/gvisor/pkg/syncutil",
+ importpath = "gvisor.dev/gvisor/pkg/sync",
)
go_test(
- name = "syncutil_test",
+ name = "sync_test",
size = "small",
srcs = [
"downgradable_rwmutex_test.go",
"seqcount_test.go",
],
- embed = [":syncutil"],
+ embed = [":sync"],
)
diff --git a/pkg/syncutil/LICENSE b/pkg/sync/LICENSE
index 6a66aea5e..6a66aea5e 100644
--- a/pkg/syncutil/LICENSE
+++ b/pkg/sync/LICENSE
diff --git a/pkg/syncutil/README.md b/pkg/sync/README.md
index 2183c4e20..2183c4e20 100644
--- a/pkg/syncutil/README.md
+++ b/pkg/sync/README.md
diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go
new file mode 100644
index 000000000..20c7ca041
--- /dev/null
+++ b/pkg/sync/aliases.go
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sync
+
+import (
+ "sync"
+)
+
+// Aliases of standard library types.
+type (
+ // Mutex is an alias of sync.Mutex.
+ Mutex = sync.Mutex
+
+ // RWMutex is an alias of sync.RWMutex.
+ RWMutex = sync.RWMutex
+
+ // Cond is an alias of sync.Cond.
+ Cond = sync.Cond
+
+ // Locker is an alias of sync.Locker.
+ Locker = sync.Locker
+
+ // Once is an alias of sync.Once.
+ Once = sync.Once
+
+ // Pool is an alias of sync.Pool.
+ Pool = sync.Pool
+
+ // WaitGroup is an alias of sync.WaitGroup.
+ WaitGroup = sync.WaitGroup
+
+ // Map is an alias of sync.Map.
+ Map = sync.Map
+)
diff --git a/pkg/syncutil/atomicptr_unsafe.go b/pkg/sync/atomicptr_unsafe.go
index 525c4beed..525c4beed 100644
--- a/pkg/syncutil/atomicptr_unsafe.go
+++ b/pkg/sync/atomicptr_unsafe.go
diff --git a/pkg/syncutil/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD
index 63f411a90..418eda29c 100644
--- a/pkg/syncutil/atomicptrtest/BUILD
+++ b/pkg/sync/atomicptrtest/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "atomicptr_int_unsafe.go",
package = "atomicptr",
suffix = "Int",
- template = "//pkg/syncutil:generic_atomicptr",
+ template = "//pkg/sync:generic_atomicptr",
types = {
"Value": "int",
},
@@ -18,7 +18,7 @@ go_template_instance(
go_library(
name = "atomicptr",
srcs = ["atomicptr_int_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr",
+ importpath = "gvisor.dev/gvisor/pkg/sync/atomicptr",
)
go_test(
diff --git a/pkg/syncutil/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptrtest/atomicptr_test.go
index 8fdc5112e..8fdc5112e 100644
--- a/pkg/syncutil/atomicptrtest/atomicptr_test.go
+++ b/pkg/sync/atomicptrtest/atomicptr_test.go
diff --git a/pkg/syncutil/downgradable_rwmutex_test.go b/pkg/sync/downgradable_rwmutex_test.go
index ffaf7ecc7..f04496bc5 100644
--- a/pkg/syncutil/downgradable_rwmutex_test.go
+++ b/pkg/sync/downgradable_rwmutex_test.go
@@ -9,7 +9,7 @@
// addition of downgradingWriter and the renaming of num_iterations to
// numIterations to shut up Golint.
-package syncutil
+package sync
import (
"fmt"
diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/sync/downgradable_rwmutex_unsafe.go
index 51e11555d..9bb55cd3a 100644
--- a/pkg/syncutil/downgradable_rwmutex_unsafe.go
+++ b/pkg/sync/downgradable_rwmutex_unsafe.go
@@ -16,7 +16,7 @@
// - RUnlock -> Lock (via writerSem)
// - DowngradeLock -> RLock (via readerSem)
-package syncutil
+package sync
import (
"sync"
diff --git a/pkg/syncutil/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go
index 348675baa..ad4a3a37e 100644
--- a/pkg/syncutil/memmove_unsafe.go
+++ b/pkg/sync/memmove_unsafe.go
@@ -8,7 +8,7 @@
// Check go:linkname function signatures when updating Go version.
-package syncutil
+package sync
import (
"unsafe"
diff --git a/pkg/syncutil/norace_unsafe.go b/pkg/sync/norace_unsafe.go
index 0a0a9deda..006055dd6 100644
--- a/pkg/syncutil/norace_unsafe.go
+++ b/pkg/sync/norace_unsafe.go
@@ -5,7 +5,7 @@
// +build !race
-package syncutil
+package sync
import (
"unsafe"
diff --git a/pkg/syncutil/race_unsafe.go b/pkg/sync/race_unsafe.go
index 206067ec1..31d8fa9a6 100644
--- a/pkg/syncutil/race_unsafe.go
+++ b/pkg/sync/race_unsafe.go
@@ -5,7 +5,7 @@
// +build race
-package syncutil
+package sync
import (
"runtime"
diff --git a/pkg/syncutil/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go
index cb6d2eb22..eda6fb131 100644
--- a/pkg/syncutil/seqatomic_unsafe.go
+++ b/pkg/sync/seqatomic_unsafe.go
@@ -13,7 +13,7 @@ import (
"strings"
"unsafe"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Value is a required type parameter.
@@ -26,17 +26,17 @@ type Value struct{}
// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
// with any writer critical sections in sc.
-func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value {
+func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value {
// This function doesn't use SeqAtomicTryLoad because doing so is
// measurably, significantly (~20%) slower; Go is awful at inlining.
var val Value
for {
epoch := sc.BeginRead()
- if syncutil.RaceEnabled {
+ if sync.RaceEnabled {
// runtime.RaceDisable() doesn't actually stop the race detector,
// so it can't help us here. Instead, call runtime.memmove
// directly, which is not instrumented by the race detector.
- syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
} else {
// This is ~40% faster for short reads than going through memmove.
val = *ptr
@@ -52,10 +52,10 @@ func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value {
// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read
// would race with a writer critical section, SeqAtomicTryLoad returns
// (unspecified, false).
-func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) {
+func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) {
var val Value
- if syncutil.RaceEnabled {
- syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ if sync.RaceEnabled {
+ sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
} else {
val = *ptr
}
@@ -66,7 +66,7 @@ func init() {
var val Value
typ := reflect.TypeOf(val)
name := typ.Name()
- if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 {
+ if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 {
panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
}
}
diff --git a/pkg/syncutil/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD
index ba18f3238..eba21518d 100644
--- a/pkg/syncutil/seqatomictest/BUILD
+++ b/pkg/sync/seqatomictest/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "seqatomic_int_unsafe.go",
package = "seqatomic",
suffix = "Int",
- template = "//pkg/syncutil:generic_seqatomic",
+ template = "//pkg/sync:generic_seqatomic",
types = {
"Value": "int",
},
@@ -18,9 +18,9 @@ go_template_instance(
go_library(
name = "seqatomic",
srcs = ["seqatomic_int_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic",
+ importpath = "gvisor.dev/gvisor/pkg/sync/seqatomic",
deps = [
- "//pkg/syncutil",
+ "//pkg/sync",
],
)
@@ -29,7 +29,5 @@ go_test(
size = "small",
srcs = ["seqatomic_test.go"],
embed = [":seqatomic"],
- deps = [
- "//pkg/syncutil",
- ],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/syncutil/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomictest/seqatomic_test.go
index b0db44999..2c4568b07 100644
--- a/pkg/syncutil/seqatomictest/seqatomic_test.go
+++ b/pkg/sync/seqatomictest/seqatomic_test.go
@@ -19,11 +19,11 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestSeqAtomicLoadUncontended(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
const want = 1
data := want
if got := SeqAtomicLoadInt(&seq, &data); got != want {
@@ -32,7 +32,7 @@ func TestSeqAtomicLoadUncontended(t *testing.T) {
}
func TestSeqAtomicLoadAfterWrite(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
var data int
const want = 1
seq.BeginWrite()
@@ -44,7 +44,7 @@ func TestSeqAtomicLoadAfterWrite(t *testing.T) {
}
func TestSeqAtomicLoadDuringWrite(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
var data int
const want = 1
seq.BeginWrite()
@@ -59,7 +59,7 @@ func TestSeqAtomicLoadDuringWrite(t *testing.T) {
}
func TestSeqAtomicTryLoadUncontended(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
const want = 1
data := want
epoch := seq.BeginRead()
@@ -69,7 +69,7 @@ func TestSeqAtomicTryLoadUncontended(t *testing.T) {
}
func TestSeqAtomicTryLoadDuringWrite(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
var data int
epoch := seq.BeginRead()
seq.BeginWrite()
@@ -80,7 +80,7 @@ func TestSeqAtomicTryLoadDuringWrite(t *testing.T) {
}
func TestSeqAtomicTryLoadAfterWrite(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
var data int
epoch := seq.BeginRead()
seq.BeginWrite()
@@ -91,7 +91,7 @@ func TestSeqAtomicTryLoadAfterWrite(t *testing.T) {
}
func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
const want = 42
data := want
b.RunParallel(func(pb *testing.PB) {
@@ -104,7 +104,7 @@ func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) {
}
func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
const want = 42
data := want
b.RunParallel(func(pb *testing.PB) {
diff --git a/pkg/syncutil/seqcount.go b/pkg/sync/seqcount.go
index 11d8dbfaa..a1e895352 100644
--- a/pkg/syncutil/seqcount.go
+++ b/pkg/sync/seqcount.go
@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package syncutil
+package sync
import (
"fmt"
diff --git a/pkg/syncutil/seqcount_test.go b/pkg/sync/seqcount_test.go
index 14d6aedea..6eb7b4b59 100644
--- a/pkg/syncutil/seqcount_test.go
+++ b/pkg/sync/seqcount_test.go
@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package syncutil
+package sync
import (
"reflect"
diff --git a/pkg/syncutil/syncutil.go b/pkg/sync/syncutil.go
index 66e750d06..b16cf5333 100644
--- a/pkg/syncutil/syncutil.go
+++ b/pkg/sync/syncutil.go
@@ -3,5 +3,5 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package syncutil provides synchronization primitives.
-package syncutil
+// Package sync provides synchronization primitives.
+package sync
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 65d4d0cd8..ebc8d0209 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -10,10 +10,12 @@ go_library(
"packet_buffer_state.go",
"tcpip.go",
"time_unsafe.go",
+ "timer.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip/buffer",
"//pkg/tcpip/iptables",
"//pkg/waiter",
@@ -26,3 +28,10 @@ go_test(
srcs = ["tcpip_test.go"],
embed = [":tcpip"],
)
+
+go_test(
+ name = "tcpip_x_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ deps = [":tcpip"],
+)
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 78df5a0b1..3df7d18d3 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -9,6 +9,7 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index cd6ce930a..a2f44b496 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -20,9 +20,9 @@ import (
"errors"
"io"
"net"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 2f15bf1f1..885d773b0 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -33,6 +33,9 @@ type NetworkChecker func(*testing.T, []header.Network)
// TransportChecker is a function to check a property of a transport packet.
type TransportChecker func(*testing.T, header.Transport)
+// ControlMessagesChecker is a function to check a property of ancillary data.
+type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
+
// IPv4 checks the validity and properties of the given IPv4 packet. It is
// expected to be used in conjunction with other network checkers for specific
// properties. For example, to check the source and destination address, one
@@ -158,6 +161,19 @@ func FragmentFlags(flags uint8) NetworkChecker {
}
}
+// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
+func ReceiveTOS(want uint8) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTOS {
+ t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want)
+ }
+ if got := cm.TOS; got != want {
+ t.Fatalf("got cm.TOS = %d, want %d", got, want)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -754,3 +770,9 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
}
}
}
+
+// NDPRS creates a checker that checks that the packet contains a valid NDP
+// Router Solicitation message (as per the raw wire format).
+func NDPRS() NetworkChecker {
+ return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize)
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index f1d837196..cd747d100 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -20,6 +20,7 @@ go_library(
"ndp_neighbor_solicit.go",
"ndp_options.go",
"ndp_router_advert.go",
+ "ndp_router_solicit.go",
"tcp.go",
"udp.go",
],
@@ -44,6 +45,7 @@ go_test(
],
deps = [
":header",
+ "//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"@com_github_google_go-cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index fc671e439..70e6ce095 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -15,6 +15,7 @@
package header
import (
+ "crypto/sha256"
"encoding/binary"
"strings"
@@ -83,6 +84,13 @@ const (
// The address is ff02::1.
IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ // IPv6AllRoutersMulticastAddress is a link-local multicast group that
+ // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all routers on a link.
+ //
+ // The address is ff02::2.
+ IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
@@ -102,6 +110,11 @@ const (
// bytes including and after the IIDOffsetInIPv6Address-th byte are
// for the IID.
IIDOffsetInIPv6Address = 8
+
+ // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes
+ // for the secret key used to generate an opaque interface identifier as
+ // outlined by RFC 7217.
+ OpaqueIIDSecretKeyMinBytes = 16
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
@@ -326,3 +339,85 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
}
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+
+// IsV6UniqueLocalAddress determines if the provided address is an IPv6
+// unique-local address (within the prefix FC00::/7).
+func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ // According to RFC 4193 section 3.1, a unique local address has the prefix
+ // FC00::/7.
+ return (addr[0] & 0xfe) == 0xfc
+}
+
+// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
+// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
+//
+// The opaque IID is generated from the cryptographic hash of the concatenation
+// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret
+// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and
+// MUST be generated to a pseudo-random number. See RFC 4086 for randomness
+// requirements for security.
+//
+// If buf has enough capacity for the IID (IIDSize bytes), a new underlying
+// array for the buffer will not be allocated.
+func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte {
+ // As per RFC 7217 section 5, the opaque identifier can be generated as a
+ // cryptographic hash of the concatenation of each of the function parameters.
+ // Note, we omit the optional Network_ID field.
+ h := sha256.New()
+ // h.Write never returns an error.
+ h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address]))
+ h.Write([]byte(nicName))
+ h.Write([]byte{dadCounter})
+ h.Write(secretKey)
+
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ return append(buf, sum[:IIDSize]...)
+}
+
+// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with
+// an opaque IID.
+func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address {
+ lladdrb := [IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey))
+}
+
+// IPv6AddressScope is the scope of an IPv6 address.
+type IPv6AddressScope int
+
+const (
+ // LinkLocalScope indicates a link-local address.
+ LinkLocalScope IPv6AddressScope = iota
+
+ // UniqueLocalScope indicates a unique-local address.
+ UniqueLocalScope
+
+ // GlobalScope indicates a global address.
+ GlobalScope
+)
+
+// ScopeForIPv6Address returns the scope for an IPv6 address.
+func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
+ if len(addr) != IPv6AddressSize {
+ return GlobalScope, tcpip.ErrBadAddress
+ }
+
+ switch {
+ case IsV6LinkLocalAddress(addr):
+ return LinkLocalScope, nil
+
+ case IsV6UniqueLocalAddress(addr):
+ return UniqueLocalScope, nil
+
+ default:
+ return GlobalScope, nil
+ }
+}
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index 42c5c6fc1..29f54bc57 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -15,14 +15,23 @@
package header_test
import (
+ "bytes"
+ "crypto/sha256"
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+)
func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6}
@@ -43,3 +52,251 @@ func TestLinkLocalAddr(t *testing.T) {
t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
}
}
+
+func TestAppendOpaqueInterfaceIdentifier(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ tests := []struct {
+ name string
+ prefix tcpip.Subnet
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ prefix: header.IPv6LinkLocalPrefix.Subnet(),
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ h := sha256.New()
+ h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address]))
+ h.Write([]byte(test.nicName))
+ h.Write([]byte{test.dadCounter})
+ if k := test.secretKey; k != nil {
+ h.Write(k)
+ }
+ var hashSum [sha256.Size]byte
+ h.Sum(hashSum[:0])
+ want := hashSum[:header.IIDSize]
+
+ // Passing a nil buffer should result in a new buffer returned with the
+ // IID.
+ if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+
+ // Passing a buffer with sufficient capacity for the IID should populate
+ // the buffer provided.
+ var iidBuf [header.IIDSize]byte
+ if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ if got := iidBuf[:]; !bytes.Equal(got, want) {
+ t.Errorf("got iidBuf = %x, want = %x", got, want)
+ }
+ })
+ }
+}
+
+func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ prefix := header.IPv6LinkLocalPrefix.Subnet()
+
+ tests := []struct {
+ name string
+ prefix tcpip.Subnet
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ addrBytes := [header.IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ test.nicName,
+ test.dadCounter,
+ test.secretKey,
+ ))
+
+ if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want {
+ t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ })
+ }
+}
+
+func TestIsV6UniqueLocalAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Unique 1",
+ addr: uniqueLocalAddr1,
+ expected: true,
+ },
+ {
+ name: "Valid Unique 2",
+ addr: uniqueLocalAddr1,
+ expected: true,
+ },
+ {
+ name: "Link Local",
+ addr: linkLocalAddr,
+ expected: false,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4",
+ addr: "\x01\x02\x03\x04",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestScopeForIPv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ scope header.IPv6AddressScope
+ err *tcpip.Error
+ }{
+ {
+ name: "Unique Local",
+ addr: uniqueLocalAddr1,
+ scope: header.UniqueLocalScope,
+ err: nil,
+ },
+ {
+ name: "Link Local",
+ addr: linkLocalAddr,
+ scope: header.LinkLocalScope,
+ err: nil,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ scope: header.GlobalScope,
+ err: nil,
+ },
+ {
+ name: "IPv4",
+ addr: "\x01\x02\x03\x04",
+ scope: header.GlobalScope,
+ err: tcpip.ErrBadAddress,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := header.ScopeForIPv6Address(test.addr)
+ if err != test.err {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err)
+ }
+ if got != test.scope {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go
new file mode 100644
index 000000000..9e67ba95d
--- /dev/null
+++ b/pkg/tcpip/header/ndp_router_solicit.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain
+// the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.1 for more details.
+type NDPRouterSolicit []byte
+
+const (
+ // NDPRSMinimumSize is the minimum size of a valid NDP Router
+ // Solicitation message (body of an ICMPv6 packet).
+ NDPRSMinimumSize = 4
+
+ // ndpRSOptionsOffset is the start of the NDP options in an
+ // NDPRouterSolicit.
+ ndpRSOptionsOffset = 4
+)
+
+// Options returns an NDPOptions of the the options body.
+func (b NDPRouterSolicit) Options() NDPOptions {
+ return NDPOptions(b[ndpRSOptionsOffset:])
+}
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
index cc5f531e2..64769c333 100644
--- a/pkg/tcpip/iptables/BUILD
+++ b/pkg/tcpip/iptables/BUILD
@@ -11,5 +11,8 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables",
visibility = ["//visibility:public"],
- deps = ["//pkg/tcpip/buffer"],
+ deps = [
+ "//pkg/log",
+ "//pkg/tcpip/buffer",
+ ],
)
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index 68c68d4aa..647970133 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -16,66 +16,114 @@
// tool.
package iptables
+// Table names.
const (
- tablenameNat = "nat"
- tablenameMangle = "mangle"
+ TablenameNat = "nat"
+ TablenameMangle = "mangle"
+ TablenameFilter = "filter"
)
// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
const (
- chainNamePrerouting = "PREROUTING"
- chainNameInput = "INPUT"
- chainNameForward = "FORWARD"
- chainNameOutput = "OUTPUT"
- chainNamePostrouting = "POSTROUTING"
+ ChainNamePrerouting = "PREROUTING"
+ ChainNameInput = "INPUT"
+ ChainNameForward = "FORWARD"
+ ChainNameOutput = "OUTPUT"
+ ChainNamePostrouting = "POSTROUTING"
)
+// HookUnset indicates that there is no hook set for an entrypoint or
+// underflow.
+const HookUnset = -1
+
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
func DefaultTables() IPTables {
+ // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
+ // iotas.
return IPTables{
Tables: map[string]Table{
- tablenameNat: Table{
- BuiltinChains: map[Hook]Chain{
- Prerouting: unconditionalAcceptChain(chainNamePrerouting),
- Input: unconditionalAcceptChain(chainNameInput),
- Output: unconditionalAcceptChain(chainNameOutput),
- Postrouting: unconditionalAcceptChain(chainNamePostrouting),
+ TablenameNat: Table{
+ Rules: []Rule{
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Prerouting: 0,
+ Input: 1,
+ Output: 2,
+ Postrouting: 3,
},
- DefaultTargets: map[Hook]Target{
- Prerouting: UnconditionalAcceptTarget{},
- Input: UnconditionalAcceptTarget{},
- Output: UnconditionalAcceptTarget{},
- Postrouting: UnconditionalAcceptTarget{},
+ Underflows: map[Hook]int{
+ Prerouting: 0,
+ Input: 1,
+ Output: 2,
+ Postrouting: 3,
},
- UserChains: map[string]Chain{},
+ UserChains: map[string]int{},
},
- tablenameMangle: Table{
- BuiltinChains: map[Hook]Chain{
- Prerouting: unconditionalAcceptChain(chainNamePrerouting),
- Output: unconditionalAcceptChain(chainNameOutput),
+ TablenameMangle: Table{
+ Rules: []Rule{
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Prerouting: 0,
+ Output: 1,
},
- DefaultTargets: map[Hook]Target{
- Prerouting: UnconditionalAcceptTarget{},
- Output: UnconditionalAcceptTarget{},
+ Underflows: map[Hook]int{
+ Prerouting: 0,
+ Output: 1,
},
- UserChains: map[string]Chain{},
+ UserChains: map[string]int{},
+ },
+ TablenameFilter: Table{
+ Rules: []Rule{
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ },
+ Underflows: map[Hook]int{
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ },
+ UserChains: map[string]int{},
},
},
Priorities: map[Hook][]string{
- Prerouting: []string{tablenameMangle, tablenameNat},
- Output: []string{tablenameMangle, tablenameNat},
+ Input: []string{TablenameNat, TablenameFilter},
+ Prerouting: []string{TablenameMangle, TablenameNat},
+ Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
},
}
}
-func unconditionalAcceptChain(name string) Chain {
- return Chain{
- Name: name,
- Rules: []Rule{
- Rule{
- Target: UnconditionalAcceptTarget{},
- },
+// EmptyFilterTable returns a Table with no rules and the filter table chains
+// mapped to HookUnset.
+func EmptyFilterTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: map[Hook]int{
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: HookUnset,
+ },
+ Underflows: map[Hook]int{
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: HookUnset,
},
+ UserChains: map[string]int{},
}
}
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
index 19a7f77e3..b94a4c941 100644
--- a/pkg/tcpip/iptables/targets.go
+++ b/pkg/tcpip/iptables/targets.go
@@ -16,7 +16,10 @@
package iptables
-import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+import (
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
// UnconditionalAcceptTarget accepts all packets.
type UnconditionalAcceptTarget struct{}
@@ -33,3 +36,14 @@ type UnconditionalDropTarget struct{}
func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
return Drop, ""
}
+
+// ErrorTarget logs an error and drops the packet. It represents a target that
+// should be unreachable.
+type ErrorTarget struct{}
+
+// Action implements Target.Action.
+func (ErrorTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
+ log.Warningf("ErrorTarget triggered.")
+ return Drop, ""
+
+}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 42a79ef9f..540f8c0b4 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -61,9 +61,12 @@ const (
type Verdict int
const (
+ // Invalid indicates an unkonwn or erroneous verdict.
+ Invalid Verdict = iota
+
// Accept indicates the packet should continue traversing netstack as
// normal.
- Accept Verdict = iota
+ Accept
// Drop inicates the packet should be dropped, stopping traversing
// netstack.
@@ -104,29 +107,22 @@ type IPTables struct {
Priorities map[Hook][]string
}
-// A Table defines a set of chains and hooks into the network stack. The
-// currently supported tables are:
-// * nat
-// * mangle
+// A Table defines a set of chains and hooks into the network stack. It is
+// really just a list of rules with some metadata for entrypoints and such.
type Table struct {
- // BuiltinChains holds the un-deletable chains built into netstack. If
- // a hook isn't present in the map, this table doesn't utilize that
- // hook.
- BuiltinChains map[Hook]Chain
+ // Rules holds the rules that make up the table.
+ Rules []Rule
- // DefaultTargets holds a target for each hook that will be executed if
- // chain traversal doesn't yield a verdict.
- DefaultTargets map[Hook]Target
+ // BuiltinChains maps builtin chains to their entrypoint rule in Rules.
+ BuiltinChains map[Hook]int
+
+ // Underflows maps builtin chains to their underflow rule in Rules
+ // (i.e. the rule to execute if the chain returns without a verdict).
+ Underflows map[Hook]int
// UserChains holds user-defined chains for the keyed by name. Users
// can give their chains arbitrary names.
- UserChains map[string]Chain
-
- // Chains maps names to chains for both builtin and user-defined chains.
- // Its entries point to Chains already either in BuiltinChains or
- // UserChains, and its purpose is to make looking up tables by name
- // fast.
- Chains map[string]*Chain
+ UserChains map[string]int
// Metadata holds information about the Table that is useful to users
// of IPTables, but not to the netstack IPTables code itself.
@@ -152,21 +148,6 @@ func (table *Table) SetMetadata(metadata interface{}) {
table.metadata = metadata
}
-// A Chain defines a list of rules for packet processing. When a packet
-// traverses a chain, it is checked against each rule until either a rule
-// returns a verdict or the chain ends.
-//
-// By convention, builtin chains end with a rule that matches everything and
-// returns either Accept or Drop. User-defined chains end with Return. These
-// aren't strictly necessary here, but the iptables tool writes tables this way.
-type Chain struct {
- // Name is the chain name.
- Name string
-
- // Rules is the list of rules to traverse.
- Rules []Rule
-}
-
// A Rule is a packet processing rule. It consists of two pieces. First it
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to. If there are no matchers in the rule, it
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 897c94821..66cc53ed4 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -16,6 +16,7 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index fa8a703d9..b7f60178e 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,10 +41,10 @@ package fdbased
import (
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index a4f9cdd69..09165dd4c 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -15,6 +15,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -31,6 +32,7 @@ go_test(
],
embed = [":sharedmem"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index 6b5bc542c..a0d4ad0be 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -21,4 +21,5 @@ go_test(
"pipe_test.go",
],
embed = [":pipe"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
index 59ef69a8b..dc239a0d0 100644
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -18,8 +18,9 @@ import (
"math/rand"
"reflect"
"runtime"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestSimpleReadWrite(t *testing.T) {
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 080f9d667..655e537c4 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -23,11 +23,11 @@
package sharedmem
import (
- "sync"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 89603c48f..5c729a439 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -22,11 +22,11 @@ import (
"math/rand"
"os"
"strings"
- "sync"
"syscall"
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index da8482509..42cacb8a6 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -79,16 +79,16 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index acf1e022c..ed16076fd 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -28,6 +28,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
],
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 6da5238ec..92f2aa13a 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -19,9 +19,9 @@ package fragmentation
import (
"fmt"
"log"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 9e002e396..0a83d81f2 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -18,9 +18,9 @@ import (
"container/heap"
"fmt"
"math"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4144a7837..f1bc33adf 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -239,7 +239,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -480,7 +480,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e645cf62c..4ee3d5b45 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -238,11 +238,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
@@ -256,7 +256,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
loopedR.Release()
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
@@ -270,11 +270,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
- if loop&stack.PacketLoop != 0 {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return len(pkts), nil
}
@@ -289,7 +289,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
ip := header.IPv4(pkt.Data.First())
@@ -324,10 +324,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLo
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
e.HandlePacket(r, pkt.Clone())
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index e13f1fabf..58c3c79b9 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -112,11 +112,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
@@ -130,7 +130,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
loopedR.Release()
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
@@ -139,11 +139,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
- if loop&stack.PacketLoop != 0 {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return len(pkts), nil
}
@@ -161,7 +161,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
// TODO(b/146666412): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index e156b01f6..a6ef3bdcc 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -9,6 +9,7 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/tcpip/ports",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 6c5e19e8f..b937cb84b 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -18,9 +18,9 @@ package ports
import (
"math"
"math/rand"
- "sync"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 69077669a..783351a69 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -36,6 +36,7 @@ go_library(
"//pkg/ilist",
"//pkg/rand",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
@@ -50,7 +51,7 @@ go_library(
go_test(
name = "stack_x_test",
- size = "small",
+ size = "medium",
srcs = [
"ndp_test.go",
"stack_test.go",
@@ -59,6 +60,7 @@ go_test(
],
deps = [
":stack",
+ "//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
@@ -82,6 +84,7 @@ go_test(
embed = [":stack"],
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
],
)
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 267df60d1..403557fd7 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -16,10 +16,10 @@ package stack
import (
"fmt"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 9946b8fe8..1baa498d0 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -16,12 +16,12 @@ package stack
import (
"fmt"
- "sync"
"sync/atomic"
"testing"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index d9ab59336..c99d387d5 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"log"
+ "math/rand"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,24 +39,36 @@ const (
// Default = 1s (from RFC 4861 section 10).
defaultRetransmitTimer = time.Second
+ // defaultMaxRtrSolicitations is the default number of Router
+ // Solicitation messages to send when a NIC becomes enabled.
+ //
+ // Default = 3 (from RFC 4861 section 10).
+ defaultMaxRtrSolicitations = 3
+
+ // defaultRtrSolicitationInterval is the default amount of time between
+ // sending Router Solicitation messages.
+ //
+ // Default = 4s (from 4861 section 10).
+ defaultRtrSolicitationInterval = 4 * time.Second
+
+ // defaultMaxRtrSolicitationDelay is the default maximum amount of time
+ // to wait before sending the first Router Solicitation message.
+ //
+ // Default = 1s (from 4861 section 10).
+ defaultMaxRtrSolicitationDelay = time.Second
+
// defaultHandleRAs is the default configuration for whether or not to
// handle incoming Router Advertisements as a host.
- //
- // Default = true.
defaultHandleRAs = true
// defaultDiscoverDefaultRouters is the default configuration for
// whether or not to discover default routers from incoming Router
// Advertisements, as a host.
- //
- // Default = true.
defaultDiscoverDefaultRouters = true
// defaultDiscoverOnLinkPrefixes is the default configuration for
// whether or not to discover on-link prefixes from incoming Router
// Advertisements' Prefix Information option, as a host.
- //
- // Default = true.
defaultDiscoverOnLinkPrefixes = true
// defaultAutoGenGlobalAddresses is the default configuration for
@@ -74,26 +87,31 @@ const (
// value of 0 means unspecified, so the smallest valid value is 1.
// Note, the unit of the RetransmitTimer field in the Router
// Advertisement is milliseconds.
- //
- // Min = 1ms.
minimumRetransmitTimer = time.Millisecond
+ // minimumRtrSolicitationInterval is the minimum amount of time to wait
+ // between sending Router Solicitation messages. This limit is imposed
+ // to make sure that Router Solicitation messages are not sent all at
+ // once, defeating the purpose of sending the initial few messages.
+ minimumRtrSolicitationInterval = 500 * time.Millisecond
+
+ // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait
+ // before sending the first Router Solicitation message. It is 0 because
+ // we cannot have a negative delay.
+ minimumMaxRtrSolicitationDelay = 0
+
// MaxDiscoveredDefaultRouters is the maximum number of discovered
// default routers. The stack should stop discovering new routers after
// discovering MaxDiscoveredDefaultRouters routers.
//
// This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
// SHOULD be more.
- //
- // Max = 10.
MaxDiscoveredDefaultRouters = 10
// MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
// on-link prefixes. The stack should stop discovering new on-link
// prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link
// prefixes.
- //
- // Max = 10.
MaxDiscoveredOnLinkPrefixes = 10
// validPrefixLenForAutoGen is the expected prefix length that an
@@ -115,6 +133,30 @@ var (
MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
)
+// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
+// NDP Router Advertisement informed the Stack about.
+type DHCPv6ConfigurationFromNDPRA int
+
+const (
+ // DHCPv6NoConfiguration indicates that no configurations are available via
+ // DHCPv6.
+ DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota
+
+ // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
+ //
+ // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
+ // will return all available configuration information.
+ DHCPv6ManagedAddress
+
+ // DHCPv6OtherConfigurations indicates that other configuration information is
+ // available via DHCPv6.
+ //
+ // Other configurations are configurations other than addresses. Examples of
+ // other configurations are recursive DNS server list, DNS search lists and
+ // default gateway.
+ DHCPv6OtherConfigurations
+)
+
// NDPDispatcher is the interface integrators of netstack must implement to
// receive and handle NDP related events.
type NDPDispatcher interface {
@@ -169,6 +211,15 @@ type NDPDispatcher interface {
// call functions on the stack itself.
OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+ // OnAutoGenAddressDeprecated will be called when an auto-generated
+ // address (as part of SLAAC) has been deprecated, but is still
+ // considered valid. Note, if an address is invalidated at the same
+ // time it is deprecated, the deprecation event MAY be omitted.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
+
// OnAutoGenAddressInvalidated will be called when an auto-generated
// address (as part of SLAAC) has been invalidated.
//
@@ -185,7 +236,20 @@ type NDPDispatcher interface {
// already known DNS servers. If called with known DNS servers, their
// valid lifetimes must be refreshed to lifetime (it may be increased,
// decreased, or completely invalidated when lifetime = 0).
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+
+ // OnDHCPv6Configuration will be called with an updated configuration that is
+ // available via DHCPv6 for a specified NIC.
+ //
+ // NDPDispatcher assumes that the initial configuration available by DHCPv6 is
+ // DHCPv6NoConfiguration.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
}
// NDPConfigurations is the NDP configurations for the netstack.
@@ -199,9 +263,24 @@ type NDPConfigurations struct {
// The amount of time to wait between sending Neighbor solicitation
// messages.
//
- // Must be greater than 0.5s.
+ // Must be greater than or equal to 1ms.
RetransmitTimer time.Duration
+ // The number of Router Solicitation messages to send when the NIC
+ // becomes enabled.
+ MaxRtrSolicitations uint8
+
+ // The amount of time between transmitting Router Solicitation messages.
+ //
+ // Must be greater than or equal to 0.5s.
+ RtrSolicitationInterval time.Duration
+
+ // The maximum amount of time before transmitting the first Router
+ // Solicitation message.
+ //
+ // Must be greater than or equal to 0s.
+ MaxRtrSolicitationDelay time.Duration
+
// HandleRAs determines whether or not Router Advertisements will be
// processed.
HandleRAs bool
@@ -232,12 +311,15 @@ type NDPConfigurations struct {
// default values.
func DefaultNDPConfigurations() NDPConfigurations {
return NDPConfigurations{
- DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
- RetransmitTimer: defaultRetransmitTimer,
- HandleRAs: defaultHandleRAs,
- DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
- DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
- AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ MaxRtrSolicitations: defaultMaxRtrSolicitations,
+ RtrSolicitationInterval: defaultRtrSolicitationInterval,
+ MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
}
}
@@ -246,10 +328,24 @@ func DefaultNDPConfigurations() NDPConfigurations {
//
// If RetransmitTimer is less than minimumRetransmitTimer, then a value of
// defaultRetransmitTimer will be used.
+//
+// If RtrSolicitationInterval is less than minimumRtrSolicitationInterval, then
+// a value of defaultRtrSolicitationInterval will be used.
+//
+// If MaxRtrSolicitationDelay is less than minimumMaxRtrSolicitationDelay, then
+// a value of defaultMaxRtrSolicitationDelay will be used.
func (c *NDPConfigurations) validate() {
if c.RetransmitTimer < minimumRetransmitTimer {
c.RetransmitTimer = defaultRetransmitTimer
}
+
+ if c.RtrSolicitationInterval < minimumRtrSolicitationInterval {
+ c.RtrSolicitationInterval = defaultRtrSolicitationInterval
+ }
+
+ if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay {
+ c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay
+ }
}
// ndpState is the per-interface NDP state.
@@ -270,8 +366,15 @@ type ndpState struct {
// Information option.
onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
+ // The timer used to send the next router solicitation message.
+ // If routers are being solicited, rtrSolicitTimer MUST NOT be nil.
+ rtrSolicitTimer *time.Timer
+
// The addresses generated by SLAAC.
autoGenAddresses map[tcpip.Address]autoGenAddressState
+
+ // The last learned DHCPv6 configuration from an NDP RA.
+ dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
}
// dadState holds the Duplicate Address Detection timer and channel to signal
@@ -290,71 +393,27 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- invalidationTimer *time.Timer
-
- // Used to inform the timer not to invalidate the default router (R) in
- // a race condition (T1 is a goroutine that handles an RA from R and T2
- // is the goroutine that handles R's invalidation timer firing):
- // T1: Receive a new RA from R
- // T1: Obtain the NIC's lock before processing the RA
- // T2: R's invalidation timer fires, and gets blocked on obtaining the
- // NIC's lock
- // T1: Refreshes/extends R's lifetime & releases NIC's lock
- // T2: Obtains NIC's lock & invalidates R immediately
- //
- // To resolve this, T1 will check to see if the timer already fired, and
- // inform the timer using doNotInvalidate to not invalidate R, so that
- // once T2 obtains the lock, it will see that it is set to true and do
- // nothing further.
- doNotInvalidate *bool
+ invalidationTimer tcpip.CancellableTimer
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- invalidationTimer *time.Timer
-
- // Used to signal the timer not to invalidate the on-link prefix (P) in
- // a race condition (T1 is a goroutine that handles a PI for P and T2
- // is the goroutine that handles P's invalidation timer firing):
- // T1: Receive a new PI for P
- // T1: Obtain the NIC's lock before processing the PI
- // T2: P's invalidation timer fires, and gets blocked on obtaining the
- // NIC's lock
- // T1: Refreshes/extends P's lifetime & releases NIC's lock
- // T2: Obtains NIC's lock & invalidates P immediately
- //
- // To resolve this, T1 will check to see if the timer already fired, and
- // inform the timer using doNotInvalidate to not invalidate P, so that
- // once T2 obtains the lock, it will see that it is set to true and do
- // nothing further.
- doNotInvalidate *bool
+ invalidationTimer tcpip.CancellableTimer
}
// autoGenAddressState holds data associated with an address generated via
// SLAAC.
type autoGenAddressState struct {
- invalidationTimer *time.Timer
-
- // Used to signal the timer not to invalidate the SLAAC address (A) in
- // a race condition (T1 is a goroutine that handles a PI for A and T2
- // is the goroutine that handles A's invalidation timer firing):
- // T1: Receive a new PI for A
- // T1: Obtain the NIC's lock before processing the PI
- // T2: A's invalidation timer fires, and gets blocked on obtaining the
- // NIC's lock
- // T1: Refreshes/extends A's lifetime & releases NIC's lock
- // T2: Obtains NIC's lock & invalidates A immediately
- //
- // To resolve this, T1 will check to see if the timer already fired, and
- // inform the timer using doNotInvalidate to not invalidate A, so that
- // once T2 obtains the lock, it will see that it is set to true and do
- // nothing further.
- doNotInvalidate *bool
-
- // Nonzero only when the address is not valid forever (invalidationTimer
- // is not nil).
+ // A reference to the referencedNetworkEndpoint that this autoGenAddressState
+ // is holding state for.
+ ref *referencedNetworkEndpoint
+
+ deprecationTimer tcpip.CancellableTimer
+ invalidationTimer tcpip.CancellableTimer
+
+ // Nonzero only when the address is not valid forever.
validUntil time.Time
}
@@ -496,10 +555,12 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u
// address.
panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
}
+ snmcRef.incRef()
// Use the unspecified address as the source address when performing
// DAD.
r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
+ defer r.Release()
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
@@ -556,7 +617,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
// handleRA handles a Router Advertisement message that arrived on the NIC
// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Is the NIC configured to handle RAs at all?
//
@@ -568,6 +629,28 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
return
}
+ // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
+ // only inform the dispatcher on configuration changes. We do nothing else
+ // with the information.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ var configuration DHCPv6ConfigurationFromNDPRA
+ switch {
+ case ra.ManagedAddrConfFlag():
+ configuration = DHCPv6ManagedAddress
+
+ case ra.OtherConfFlag():
+ configuration = DHCPv6OtherConfigurations
+
+ default:
+ configuration = DHCPv6NoConfiguration
+ }
+
+ if ndp.dhcpv6Configuration != configuration {
+ ndp.dhcpv6Configuration = configuration
+ ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration)
+ }
+ }
+
// Is the NIC configured to discover default routers?
if ndp.configs.DiscoverDefaultRouters {
rtr, ok := ndp.defaultRouters[ip]
@@ -585,27 +668,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
// the invalidation timer.
- timer := rtr.invalidationTimer
-
- // We should ALWAYS have an invalidation timer for a
- // discovered router.
- if timer == nil {
- panic("ndphandlera: RA invalidation timer should not be nil")
- }
-
- if !timer.Stop() {
- // If we reach this point, then we know the
- // timer fired after we already took the NIC
- // lock. Inform the timer not to invalidate the
- // router when it obtains the lock as we just
- // got a new RA that refreshes its lifetime to a
- // non-zero value. See
- // defaultRouterState.doNotInvalidate for more
- // details.
- *rtr.doNotInvalidate = true
- }
-
- timer.Reset(rl)
+ rtr.invalidationTimer.StopLocked()
+ rtr.invalidationTimer.Reset(rl)
+ ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
// We know about the router but it is no longer to be
@@ -672,10 +737,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.Stop()
- rtr.invalidationTimer = nil
- *rtr.doNotInvalidate = true
- rtr.doNotInvalidate = nil
+ rtr.invalidationTimer.StopLocked()
delete(ndp.defaultRouters, ip)
@@ -704,27 +766,15 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
return
}
- // Used to signal the timer not to invalidate the default router (R) in
- // a race condition. See defaultRouterState.doNotInvalidate for more
- // details.
- var doNotInvalidate bool
-
- ndp.defaultRouters[ip] = defaultRouterState{
- invalidationTimer: time.AfterFunc(rl, func() {
- ndp.nic.stack.mu.Lock()
- defer ndp.nic.stack.mu.Unlock()
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if doNotInvalidate {
- doNotInvalidate = false
- return
- }
-
+ state := defaultRouterState{
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
- doNotInvalidate: &doNotInvalidate,
}
+
+ state.invalidationTimer.Reset(rl)
+
+ ndp.defaultRouters[ip] = state
}
// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
@@ -746,21 +796,17 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
return
}
- // Used to signal the timer not to invalidate the on-link prefix (P) in
- // a race condition. See onLinkPrefixState.doNotInvalidate for more
- // details.
- var doNotInvalidate bool
- var timer *time.Timer
+ state := onLinkPrefixState{
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }),
+ }
- // Only create a timer if the lifetime is not infinite.
if l < header.NDPInfiniteLifetime {
- timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate)
+ state.invalidationTimer.Reset(l)
}
- ndp.onLinkPrefixes[prefix] = onLinkPrefixState{
- invalidationTimer: timer,
- doNotInvalidate: &doNotInvalidate,
- }
+ ndp.onLinkPrefixes[prefix] = state
}
// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
@@ -775,13 +821,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- if s.invalidationTimer != nil {
- s.invalidationTimer.Stop()
- s.invalidationTimer = nil
- *s.doNotInvalidate = true
- }
-
- s.doNotInvalidate = nil
+ s.invalidationTimer.StopLocked()
delete(ndp.onLinkPrefixes, prefix)
@@ -791,28 +831,6 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
}
}
-// prefixInvalidationCallback returns a new on-link prefix invalidation timer
-// for prefix that fires after vl.
-//
-// doNotInvalidate is used to signal the timer when it fires at the same time
-// that a prefix's valid lifetime gets refreshed. See
-// onLinkPrefixState.doNotInvalidate for more details.
-func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer {
- return time.AfterFunc(vl, func() {
- ndp.nic.stack.mu.Lock()
- defer ndp.nic.stack.mu.Unlock()
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if *doNotInvalidate {
- *doNotInvalidate = false
- return
- }
-
- ndp.invalidateOnLinkPrefix(prefix)
- })
-}
-
// handleOnLinkPrefixInformation handles a Prefix Information option with
// its on-link flag set, as per RFC 4861 section 6.3.4.
//
@@ -852,42 +870,17 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
+ //
// Update the invalidation timer.
- timer := prefixState.invalidationTimer
- if timer == nil && vl >= header.NDPInfiniteLifetime {
- // Had infinite valid lifetime before and
- // continues to have an invalid lifetime. Do
- // nothing further.
- return
- }
+ prefixState.invalidationTimer.StopLocked()
- if timer != nil && !timer.Stop() {
- // If we reach this point, then we know the timer alread fired
- // after we took the NIC lock. Inform the timer to not
- // invalidate the prefix once it obtains the lock as we just
- // got a new PI that refreshes its lifetime to a non-zero value.
- // See onLinkPrefixState.doNotInvalidate for more details.
- *prefixState.doNotInvalidate = true
- }
-
- if vl >= header.NDPInfiniteLifetime {
- // Prefix is now valid forever so we don't need
- // an invalidation timer.
- prefixState.invalidationTimer = nil
- ndp.onLinkPrefixes[prefix] = prefixState
- return
- }
-
- if timer != nil {
- // We already have a timer so just reset it to
- // expire after the new valid lifetime.
- timer.Reset(vl)
- return
+ if vl < header.NDPInfiniteLifetime {
+ // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // the new valid lifetime.
+ prefixState.invalidationTimer.Reset(vl)
}
- // We do not have a timer so just create a new one.
- prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate)
ndp.onLinkPrefixes[prefix] = prefixState
}
@@ -897,7 +890,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the autonomous flag is set.
//
-// The NIC that ndp belongs to and its associated stack MUST be locked.
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
vl := pi.ValidLifetime()
pl := pi.PreferredLifetime()
@@ -912,103 +905,30 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
prefix := pi.Subnet()
// Check if we already have an auto-generated address for prefix.
- for _, ref := range ndp.nic.endpoints {
- if ref.protocol != header.IPv6ProtocolNumber {
- continue
- }
-
- if ref.configType != slaac {
- continue
- }
-
- addr := ref.ep.ID().LocalAddress
- refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()}
+ for addr, addrState := range ndp.autoGenAddresses {
+ refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()}
if refAddrWithPrefix.Subnet() != prefix {
continue
}
- //
- // At this point, we know we are refreshing a SLAAC generated
- // IPv6 address with the prefix, prefix. Do the work as outlined
- // by RFC 4862 section 5.5.3.e.
- //
-
- addrState, ok := ndp.autoGenAddresses[addr]
- if !ok {
- panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr))
- }
-
- // TODO(b/143713887): Handle deprecating auto-generated address
- // after the preferred lifetime.
-
- // As per RFC 4862 section 5.5.3.e, the valid lifetime of the
- // address generated by SLAAC is as follows:
- //
- // 1) If the received Valid Lifetime is greater than 2 hours or
- // greater than RemainingLifetime, set the valid lifetime of
- // the address to the advertised Valid Lifetime.
- //
- // 2) If RemainingLifetime is less than or equal to 2 hours,
- // ignore the advertised Valid Lifetime.
- //
- // 3) Otherwise, reset the valid lifetime of the address to 2
- // hours.
-
- // Handle the infinite valid lifetime separately as we do not
- // keep a timer in this case.
- if vl >= header.NDPInfiniteLifetime {
- if addrState.invalidationTimer != nil {
- // Valid lifetime was finite before, but now it
- // is valid forever.
- if !addrState.invalidationTimer.Stop() {
- *addrState.doNotInvalidate = true
- }
- addrState.invalidationTimer = nil
- addrState.validUntil = time.Time{}
- ndp.autoGenAddresses[addr] = addrState
- }
-
- return
- }
-
- var effectiveVl time.Duration
- var rl time.Duration
-
- // If the address was originally set to be valid forever,
- // assume the remaining time to be the maximum possible value.
- if addrState.invalidationTimer == nil {
- rl = header.NDPInfiniteLifetime
- } else {
- rl = time.Until(addrState.validUntil)
- }
-
- if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
- effectiveVl = vl
- } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
- ndp.autoGenAddresses[addr] = addrState
- return
- } else {
- effectiveVl = MinPrefixInformationValidLifetimeForUpdate
- }
-
- if addrState.invalidationTimer == nil {
- addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate)
- } else {
- if !addrState.invalidationTimer.Stop() {
- *addrState.doNotInvalidate = true
- }
- addrState.invalidationTimer.Reset(effectiveVl)
- }
-
- addrState.validUntil = time.Now().Add(effectiveVl)
- ndp.autoGenAddresses[addr] = addrState
+ // At this point, we know we are refreshing a SLAAC generated IPv6 address
+ // with the prefix prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.e.
+ ndp.refreshAutoGenAddressLifetimes(addr, pl, vl)
return
}
// We do not already have an address within the prefix, prefix. Do the
// work as outlined by RFC 4862 section 5.5.3.d if n is configured
// to auto-generated global addresses by SLAAC.
+ ndp.newAutoGenAddress(prefix, pl, vl)
+}
+// newAutoGenAddress generates a new SLAAC address with the provided lifetimes
+// for prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) {
// Are we configured to auto-generate new global addresses?
if !ndp.configs.AutoGenGlobalAddresses {
return
@@ -1028,22 +948,24 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
return
}
- // Only attempt to generate an interface-specific IID if we have a valid
- // link address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address
- // (provided by LinkEndpoint.LinkAddress) before reaching this
- // point.
- linkAddr := ndp.nic.linkEP.LinkAddress()
- if !header.IsValidUnicastEthernetAddress(linkAddr) {
- return
- }
+ addrBytes := []byte(prefix.ID())
+ if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey)
+ } else {
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ linkAddr := ndp.nic.linkEP.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return
+ }
- // Generate an address within prefix from the modified EUI-64 of ndp's
- // NIC's Ethernet MAC address.
- addrBytes := make([]byte, header.IPv6AddressSize)
- copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address])
- header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ // Generate an address within prefix from the modified EUI-64 of ndp's NIC's
+ // Ethernet MAC address.
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ }
addr := tcpip.Address(addrBytes)
addrWithPrefix := tcpip.AddressWithPrefix{
Address: addr,
@@ -1065,29 +987,132 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
return
}
- if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{
+ protocolAddr := tcpip.ProtocolAddress{
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: addrWithPrefix,
- }, FirstPrimaryEndpoint, permanent, slaac); err != nil {
- panic(err)
+ }
+ // If the preferred lifetime is zero, then the address should be considered
+ // deprecated.
+ deprecated := pl == 0
+ ref, err := ndp.nic.addPermanentAddressLocked(protocolAddr, FirstPrimaryEndpoint, slaac, deprecated)
+ if err != nil {
+ log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err)
}
- // Setup the timers to deprecate and invalidate this newly generated
+ state := autoGenAddressState{
+ ref: ref,
+ deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ addrState, ok := ndp.autoGenAddresses[addr]
+ if !ok {
+ log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)
+ }
+ addrState.ref.deprecated = true
+ ndp.notifyAutoGenAddressDeprecated(addr)
+ }),
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateAutoGenAddress(addr)
+ }),
+ }
+
+ // Setup the initial timers to deprecate and invalidate this newly generated
// address.
- // TODO(b/143713887): Handle deprecating auto-generated addresses
- // after the preferred lifetime.
+ if !deprecated && pl < header.NDPInfiniteLifetime {
+ state.deprecationTimer.Reset(pl)
+ }
- var doNotInvalidate bool
- var vTimer *time.Timer
if vl < header.NDPInfiniteLifetime {
- vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate)
+ state.invalidationTimer.Reset(vl)
+ state.validUntil = time.Now().Add(vl)
}
- ndp.autoGenAddresses[addr] = autoGenAddressState{
- invalidationTimer: vTimer,
- doNotInvalidate: &doNotInvalidate,
- validUntil: time.Now().Add(vl),
+ ndp.autoGenAddresses[addr] = state
+}
+
+// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated
+// address addr.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) {
+ addrState, ok := ndp.autoGenAddresses[addr]
+ if !ok {
+ log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr)
+ }
+ defer func() { ndp.autoGenAddresses[addr] = addrState }()
+
+ // If the preferred lifetime is zero, then the address should be considered
+ // deprecated.
+ deprecated := pl == 0
+ wasDeprecated := addrState.ref.deprecated
+ addrState.ref.deprecated = deprecated
+
+ // Only send the deprecation event if the deprecated status for addr just
+ // changed from non-deprecated to deprecated.
+ if !wasDeprecated && deprecated {
+ ndp.notifyAutoGenAddressDeprecated(addr)
+ }
+
+ // If addr was preferred for some finite lifetime before, stop the deprecation
+ // timer so it can be reset.
+ addrState.deprecationTimer.StopLocked()
+
+ // Reset the deprecation timer if addr has a finite preferred lifetime.
+ if !deprecated && pl < header.NDPInfiniteLifetime {
+ addrState.deprecationTimer.Reset(pl)
+ }
+
+ // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address
+ //
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or greater than
+ // RemainingLifetime, set the valid lifetime of the address to the
+ // advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
+ // advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the address to 2 hours.
+
+ // Handle the infinite valid lifetime separately as we do not keep a timer in
+ // this case.
+ if vl >= header.NDPInfiniteLifetime {
+ addrState.invalidationTimer.StopLocked()
+ addrState.validUntil = time.Time{}
+ return
+ }
+
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the address was originally set to be valid forever, assume the remaining
+ // time to be the maximum possible value.
+ if addrState.validUntil == (time.Time{}) {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(addrState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
+ return
+ } else {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
+
+ addrState.invalidationTimer.StopLocked()
+ addrState.invalidationTimer.Reset(effectiveVl)
+ addrState.validUntil = time.Now().Add(effectiveVl)
+}
+
+// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr
+// has been deprecated.
+func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: validPrefixLenForAutoGen,
+ })
}
}
@@ -1111,19 +1136,12 @@ func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) {
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool {
state, ok := ndp.autoGenAddresses[addr]
-
if !ok {
return false
}
- if state.invalidationTimer != nil {
- state.invalidationTimer.Stop()
- state.invalidationTimer = nil
- *state.doNotInvalidate = true
- }
-
- state.doNotInvalidate = nil
-
+ state.deprecationTimer.StopLocked()
+ state.invalidationTimer.StopLocked()
delete(ndp.autoGenAddresses, addr)
if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
@@ -1136,26 +1154,6 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
return true
}
-// autoGenAddrInvalidationTimer returns a new invalidation timer for an
-// auto-generated address that fires after vl.
-//
-// doNotInvalidate is used to inform the timer when it fires at the same time
-// that an auto-generated address's valid lifetime gets refreshed. See
-// autoGenAddrState.doNotInvalidate for more details.
-func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer {
- return time.AfterFunc(vl, func() {
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if *doNotInvalidate {
- *doNotInvalidate = false
- return
- }
-
- ndp.invalidateAutoGenAddress(addr)
- })
-}
-
// cleanupHostOnlyState cleans up any state that is only useful for hosts.
//
// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a
@@ -1190,3 +1188,84 @@ func (ndp *ndpState) cleanupHostOnlyState() {
log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got)
}
}
+
+// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
+// 6.3.7. If routers are already being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) startSolicitingRouters() {
+ if ndp.rtrSolicitTimer != nil {
+ // We are already soliciting routers.
+ return
+ }
+
+ remaining := ndp.configs.MaxRtrSolicitations
+ if remaining == 0 {
+ return
+ }
+
+ // Calculate the random delay before sending our first RS, as per RFC
+ // 4861 section 6.3.7.
+ var delay time.Duration
+ if ndp.configs.MaxRtrSolicitationDelay > 0 {
+ delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
+ }
+
+ ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
+ // Send an RS message with the unspecified source address.
+ ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, true)
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ defer r.Release()
+
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize)
+ pkt := header.ICMPv6(hdr.Prepend(payloadSize))
+ pkt.SetType(header.ICMPv6RouterSolicit)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ TOS: DefaultTOS,
+ }, tcpip.PacketBuffer{Header: hdr},
+ ); err != nil {
+ sent.Dropped.Increment()
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
+ // Don't send any more messages if we had an error.
+ remaining = 0
+ } else {
+ sent.RouterSolicit.Increment()
+ remaining--
+ }
+
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+ if remaining == 0 {
+ ndp.rtrSolicitTimer = nil
+ } else if ndp.rtrSolicitTimer != nil {
+ // Note, we need to explicitly check to make sure that
+ // the timer field is not nil because if it was nil but
+ // we still reached this point, then we know the NIC
+ // was requested to stop soliciting routers so we don't
+ // need to send the next Router Solicitation message.
+ ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval)
+ }
+ })
+
+}
+
+// stopSolicitingRouters stops soliciting routers. If routers are not currently
+// being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) stopSolicitingRouters() {
+ if ndp.rtrSolicitTimer == nil {
+ // Nothing to do.
+ return
+ }
+
+ ndp.rtrSolicitTimer.Stop()
+ ndp.rtrSolicitTimer = nil
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 64a9a2b20..1a52e0e68 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -21,6 +21,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -29,15 +30,17 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
- linkAddr1 = "\x02\x02\x03\x04\x05\x06"
- linkAddr2 = "\x02\x02\x03\x04\x05\x07"
- linkAddr3 = "\x02\x02\x03\x04\x05\x08"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
defaultTimeout = 100 * time.Millisecond
)
@@ -45,6 +48,10 @@ var (
llAddr1 = header.LinkLocalAddr(linkAddr1)
llAddr2 = header.LinkLocalAddr(linkAddr2)
llAddr3 = header.LinkLocalAddr(linkAddr3)
+ dstAddr = tcpip.FullAddress{
+ Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ Port: 25,
+ }
)
func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix {
@@ -135,6 +142,7 @@ type ndpAutoGenAddrEventType int
const (
newAddr ndpAutoGenAddrEventType = iota
+ deprecatedAddr
invalidatedAddr
)
@@ -154,18 +162,24 @@ type ndpRDNSSEvent struct {
rdnss ndpRDNSS
}
+type ndpDHCPv6Event struct {
+ nicID tcpip.NICID
+ configuration stack.DHCPv6ConfigurationFromNDPRA
+}
+
var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
// related events happen for test purposes.
type ndpDispatcher struct {
- dadC chan ndpDADEvent
- routerC chan ndpRouterEvent
- rememberRouter bool
- prefixC chan ndpPrefixEvent
- rememberPrefix bool
- autoGenAddrC chan ndpAutoGenAddrEvent
- rdnssC chan ndpRDNSSEvent
+ dadC chan ndpDADEvent
+ routerC chan ndpRouterEvent
+ rememberRouter bool
+ prefixC chan ndpPrefixEvent
+ rememberPrefix bool
+ autoGenAddrC chan ndpAutoGenAddrEvent
+ rdnssC chan ndpRDNSSEvent
+ dhcpv6ConfigurationC chan ndpDHCPv6Event
}
// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
@@ -239,6 +253,16 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi
return true
}
+func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ deprecatedAddr,
+ }
+ }
+}
+
func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
if c := n.autoGenAddrC; c != nil {
c <- ndpAutoGenAddrEvent{
@@ -262,6 +286,16 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc
}
}
+// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
+func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ if c := n.dhcpv6ConfigurationC; c != nil {
+ c <- ndpDHCPv6Event{
+ nicID,
+ configuration,
+ }
+ }
+}
+
// TestDADResolve tests that an address successfully resolves after performing
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
@@ -779,21 +813,32 @@ func TestSetNDPConfigurations(t *testing.T) {
}
}
-// raBufWithOpts returns a valid NDP Router Advertisement with options.
-//
-// Note, raBufWithOpts does not populate any of the RA fields other than the
-// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
+// and DHCPv6 configurations specified.
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- ra := header.NDPRouterAdvert(pkt.NDPPayload())
+ raPayload := pkt.NDPPayload()
+ ra := header.NDPRouterAdvert(raPayload)
+ // Populate the Router Lifetime.
+ binary.BigEndian.PutUint16(raPayload[2:], rl)
+ // Populate the Managed Address flag field.
+ if managedAddress {
+ // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
+ // of the RA payload.
+ raPayload[1] |= (1 << 7)
+ }
+ // Populate the Other Configurations flag field.
+ if otherConfigurations {
+ // The Other Configurations flag field is the 6th bit of byte #1
+ // (0-indexing) of the RA payload.
+ raPayload[1] |= (1 << 6)
+ }
opts := ra.Options()
opts.Serialize(optSer)
- // Populate the Router Lifetime.
- binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -808,6 +853,23 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()}
}
+// raBufWithOpts returns a valid NDP Router Advertisement with options.
+//
+// Note, raBufWithOpts does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
+}
+
+// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
+// fields set.
+//
+// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
+// DHCPv6 related ones.
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
+}
+
// raBuf returns a valid NDP Router Advertisement.
//
// Note, raBuf does not populate any of the RA fields other than the
@@ -1011,13 +1073,13 @@ func TestRouterDiscovery(t *testing.T) {
expectRouterEvent(llAddr2, true)
// Rx an RA from another router (lladdr3) with non-zero lifetime.
- l3Lifetime := time.Duration(6)
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime)))
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
expectRouterEvent(llAddr3, true)
// Rx an RA from lladdr2 with lesser lifetime.
- l2Lifetime := time.Duration(2)
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime)))
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
select {
case <-ndpDisp.routerC:
t.Fatal("Should not receive a router event when updating lifetimes for known routers")
@@ -1031,7 +1093,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
@@ -1048,7 +1110,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1637,12 +1699,541 @@ func TestAutoGenAddr(t *testing.T) {
}
}
+// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher,
+// channel.Endpoint and stack.Stack.
+//
+// stack.Stack will have a default route through the router (llAddr3) installed
+// and a static link-address (linkAddr3) added to the link address cache for the
+// router.
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+ t.Helper()
+ ndpDisp := &ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr3,
+ NIC: nicID,
+ }})
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ return ndpDisp, e, s
+}
+
+// addrForNewConnectionTo returns the local address used when creating a new
+// connection to addr.
+func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+ if err := ep.Connect(addr); err != nil {
+ t.Fatalf("ep.Connect(%+v): %s", addr, err)
+ }
+ got, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("ep.GetLocalAddress(): %s", err)
+ }
+ return got.Addr
+}
+
+// addrForNewConnection returns the local address used when creating a new
+// connection.
+func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
+ t.Helper()
+
+ return addrForNewConnectionTo(t, s, dstAddr)
+}
+
+// addrForNewConnectionWithAddr returns the local address used when creating a
+// new connection with a specific local address.
+func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+ if err := ep.Bind(addr); err != nil {
+ t.Fatalf("ep.Bind(%+v): %s", addr, err)
+ }
+ if err := ep.Connect(dstAddr); err != nil {
+ t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
+ }
+ got, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("ep.GetLocalAddress(): %s", err)
+ }
+ return got.Addr
+}
+
+// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
+// receiving a PI with 0 preferred lifetime.
+func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
+ const nicID = 1
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ expectPrimaryAddr(addr1)
+
+ // Deprecate addr for prefix1 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ // addr should still be the primary endpoint as there are no other addresses.
+ expectPrimaryAddr(addr1)
+
+ // Refresh lifetimes of addr generated from prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Deprecate addr for prefix2 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr1 should be the primary endpoint now since addr2 is deprecated but
+ // addr1 is not.
+ expectPrimaryAddr(addr1)
+ // addr2 is deprecated but if explicitly requested, it should be used.
+ fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
+ }
+
+ // Another PI w/ 0 preferred lifetime should not result in a deprecation
+ // event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
+ }
+
+ // Refresh lifetimes of addr generated from prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+}
+
+// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// when its preferred lifetime expires.
+func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+ const nicID = 1
+ const newMinVL = 2
+ newMinVLDuration := newMinVL * time.Second
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Receive a PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr1)
+
+ // Refresh lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since addr1 is deprecated but
+ // addr2 is not.
+ expectPrimaryAddr(addr2)
+ // addr1 is deprecated but if explicitly requested, it should be used.
+ fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ }
+
+ // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
+ // sure we do not get a deprecation event again.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ }
+
+ // Refresh lifetimes for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ // addr1 is the primary endpoint again since it is non-deprecated now.
+ expectPrimaryAddr(addr1)
+
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since it is not deprecated.
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout)
+ if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Refresh both lifetimes for addr of prefix2 to the same value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+
+ // Wait for a deprecation then invalidation events, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation handlers could be handled in
+ // either deprecation then invalidation, or invalidation then deprecation
+ // (which should be cancelled by the invalidation handler).
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we should not get a deprecation
+ // event after.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ case <-time.After(defaultTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event")
+ }
+
+ case <-time.After(newMinVLDuration + defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should not have %s in the list of addresses", addr2)
+ }
+ // Should not have any primary endpoints.
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+
+ if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ }
+}
+
+// Tests transitioning a SLAAC address's valid lifetime between finite and
+// infinite values.
+func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
+ const infiniteVLSeconds = 2
+ const minVLSeconds = 1
+ savedIL := header.NDPInfiniteLifetime
+ savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
+ header.NDPInfiniteLifetime = savedIL
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
+ header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ infiniteVL uint32
+ }{
+ {
+ name: "EqualToInfiniteVL",
+ infiniteVL: infiniteVLSeconds,
+ },
+ // Our implementation supports changing header.NDPInfiniteLifetime for tests
+ // such that a packet can be received where the lifetime field has a value
+ // greater than header.NDPInfiniteLifetime. Because of this, we test to make
+ // sure that receiving a value greater than header.NDPInfiniteLifetime is
+ // handled the same as when receiving a value equal to
+ // header.NDPInfiniteLifetime.
+ {
+ name: "MoreThanInfiniteVL",
+ infiniteVL: infiniteVLSeconds + 1,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with finite prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // Receive an new RA with prefix with infinite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
+
+ // Receive a new RA with prefix with finite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ case <-time.After(minVLSeconds*time.Second + defaultTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
+ })
+}
+
// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
// auto-generated address only gets updated when required to, as specified in
// RFC 4862 section 5.5.3.e.
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const infiniteVL = 4294967295
- const newMinVL = 5
+ const newMinVL = 4
saved := stack.MinPrefixInformationValidLifetimeForUpdate
defer func() {
stack.MinPrefixInformationValidLifetimeForUpdate = saved
@@ -1854,6 +2445,119 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
}
+// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously
+// assigned to the NIC but is in the permanentExpired state.
+func TestAutoGenAddrAfterRemoval(t *testing.T) {
+ t.Parallel()
+
+ const nicID = 1
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive a PI to auto-generate addr1 with a large valid and preferred
+ // lifetime.
+ const largeLifetimeSeconds = 999
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectPrimaryAddr(addr1)
+
+ // Add addr2 as a static address.
+ protoAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr2,
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ }
+ // addr2 should be more preferred now since it is at the front of the primary
+ // list.
+ expectPrimaryAddr(addr2)
+
+ // Get a route using addr2 to increment its reference count then remove it
+ // to leave it in the permanentExpired state.
+ r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
+ }
+ // addr1 should be preferred again since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and preferred.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr2 should be more preferred now that it is closer to the front of the
+ // primary list and not deprecated.
+ expectPrimaryAddr(addr2)
+
+ // Removing the address should result in an invalidation event immediately.
+ // It should still be in the permanentExpired state because r is still held.
+ //
+ // We remove addr2 here to make sure addr2 was marked as a SLAAC address
+ // (it was previously marked as a static address).
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ // addr1 should be more preferred since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr1 should still be more preferred since addr2 is deprecated, even though
+ // it is closer to the front of the primary list.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to refresh addr2's preferred lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto gen addr event")
+ default:
+ }
+ // addr2 should be more preferred now that it is not deprecated.
+ expectPrimaryAddr(addr2)
+
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ expectPrimaryAddr(addr1)
+}
+
// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
// is already assigned to the NIC, the static address remains.
func TestAutoGenAddrStaticConflict(t *testing.T) {
@@ -1911,6 +2615,110 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
}
+// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
+// opaque interface identifiers when configured to do so.
+func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
+ t.Parallel()
+
+ const nicID = 1
+ const nicName = "nic1"
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
+ prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1)
+ // addr1 and addr2 are the addresses that are expected to be generated when
+ // stack.Stack is configured to generate opaque interface identifiers as
+ // defined by RFC 7217.
+ addrBytes := []byte(subnet1.ID())
+ addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+ addrBytes = []byte(subnet2.ID())
+ addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in a PI.
+ const validLifetimeSecondPrefix1 = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix2 in a PI with a large valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+}
+
// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
// to the integrator when an RA is received with the NDP Recursive DNS Server
// option with at least one valid address.
@@ -2312,3 +3120,318 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
default:
}
}
+
+// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly
+// informed when new information about what configurations are available via
+// DHCPv6 is learned.
+func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ t.Helper()
+ select {
+ case e := <-ndpDisp.dhcpv6ConfigurationC:
+ if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" {
+ t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DHCPv6 configuration event")
+ }
+ }
+
+ expectNoDHCPv6Event := func() {
+ t.Helper()
+ select {
+ case <-ndpDisp.dhcpv6ConfigurationC:
+ t.Fatal("unexpected DHCPv6 configuration event")
+ default:
+ }
+ }
+
+ // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ // Receiving the same update again should not result in an event to the
+ // NDPDispatcher.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to none.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ //
+ // Note, when the M flag is set, the O flag is redundant.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+ // Even though the DHCPv6 flags are different, the effective configuration is
+ // the same so we should not receive a new event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+}
+
+// TestRouterSolicitation tests the initial Router Solicitations that are sent
+// when a NIC newly becomes enabled.
+func TestRouterSolicitation(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ maxRtrSolicit uint8
+ rtrSolicitInt time.Duration
+ effectiveRtrSolicitInt time.Duration
+ maxRtrSolicitDelay time.Duration
+ effectiveMaxRtrSolicitDelay time.Duration
+ }{
+ {
+ name: "Single RS with delay",
+ maxRtrSolicit: 1,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: time.Second,
+ effectiveMaxRtrSolicitDelay: time.Second,
+ },
+ {
+ name: "Two RS with delay",
+ maxRtrSolicit: 2,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: 500 * time.Millisecond,
+ effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
+ },
+ {
+ name: "Single RS without delay",
+ maxRtrSolicit: 1,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS without delay and invalid zero interval",
+ maxRtrSolicit: 2,
+ rtrSolicitInt: 0,
+ effectiveRtrSolicitInt: 4 * time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Three RS without delay",
+ maxRtrSolicit: 3,
+ rtrSolicitInt: 500 * time.Millisecond,
+ effectiveRtrSolicitInt: 500 * time.Millisecond,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS with invalid negative delay",
+ maxRtrSolicit: 2,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: -3 * time.Second,
+ effectiveMaxRtrSolicitDelay: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+ e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+ select {
+ case p := <-e.C:
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t,
+ p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(),
+ )
+
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for packet")
+ }
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
+ select {
+ case <-e.C:
+ t.Fatal("unexpectedly got a packet")
+ case <-time.After(timeout):
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Make sure each RS got sent at the right
+ // times.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
+ remaining--
+ }
+ for ; remaining > 0; remaining-- {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
+ waitForPkt(2 * defaultTimeout)
+ }
+
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
+ }
+
+ // Make sure the counter got properly
+ // incremented.
+ if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
+ })
+}
+
+// TestStopStartSolicitingRouters tests that when forwarding is enabled or
+// disabled, router solicitations are stopped or started, respecitively.
+func TestStopStartSolicitingRouters(t *testing.T) {
+ t.Parallel()
+
+ const interval = 500 * time.Millisecond
+ const delay = time.Second
+ const maxRtrSolicitations = 3
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+ select {
+ case p := <-e.C:
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for packet")
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Enable forwarding which should stop router solicitations.
+ s.SetForwarding(true)
+ select {
+ case <-e.C:
+ // A single RS may have been sent before forwarding was enabled.
+ select {
+ case <-e.C:
+ t.Fatal("Should not have sent more than one RS message")
+ case <-time.After(interval + defaultTimeout):
+ }
+ case <-time.After(delay + defaultTimeout):
+ }
+
+ // Enabling forwarding again should do nothing.
+ s.SetForwarding(true)
+ select {
+ case <-e.C:
+ t.Fatal("unexpectedly got a packet after becoming a router")
+ case <-time.After(delay + defaultTimeout):
+ }
+
+ // Disable forwarding which should start router solicitations.
+ s.SetForwarding(false)
+ waitForPkt(delay + defaultTimeout)
+ waitForPkt(interval + defaultTimeout)
+ waitForPkt(interval + defaultTimeout)
+ select {
+ case <-e.C:
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ case <-time.After(interval + defaultTimeout):
+ }
+
+ // Disabling forwarding again should do nothing.
+ s.SetForwarding(false)
+ select {
+ case <-e.C:
+ t.Fatal("unexpectedly got a packet after becoming a router")
+ case <-time.After(delay + defaultTimeout):
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index ddd014658..4452a1302 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -15,10 +15,12 @@
package stack
import (
+ "log"
+ "sort"
"strings"
- "sync"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -27,11 +29,11 @@ import (
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
- stack *Stack
- id tcpip.NICID
- name string
- linkEP LinkEndpoint
- loopback bool
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ context NICContext
mu sync.RWMutex
spoofing bool
@@ -85,7 +87,7 @@ const (
)
// newNIC returns a new NIC using the default NDP configurations from stack.
-func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
// example, make sure that the link address it provides is a valid
// unicast ethernet address.
@@ -99,7 +101,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
id: id,
name: name,
linkEP: ep,
- loopback: loopback,
+ context: ctx,
primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -174,45 +176,73 @@ func (n *NIC) enable() *tcpip.Error {
return err
}
- if !n.stack.autoGenIPv6LinkLocal {
- return nil
- }
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
+ var addr tcpip.Address
+ if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey)
+ } else {
+ l2addr := n.linkEP.LinkAddress()
+
+ // Only attempt to generate the link-local address if we have a valid MAC
+ // address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ if !header.IsValidUnicastEthernetAddress(l2addr) {
+ return nil
+ }
- l2addr := n.linkEP.LinkAddress()
+ addr = header.LinkLocalAddr(l2addr)
+ }
- // Only attempt to generate the link-local address if we have a
- // valid MAC address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address
- // (provided by LinkEndpoint.LinkAddress) before reaching this
- // point.
- if !header.IsValidUnicastEthernetAddress(l2addr) {
- return nil
+ if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ },
+ }, CanBePrimaryEndpoint, static, false /* deprecated */); err != nil {
+ return err
+ }
}
- addr := header.LinkLocalAddr(l2addr)
-
- _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
- },
- }, CanBePrimaryEndpoint)
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyways.
+ //
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !n.stack.forwarding {
+ n.ndp.startSolicitingRouters()
+ }
- return err
+ return nil
}
// becomeIPv6Router transitions n into an IPv6 router.
//
// When transitioning into an IPv6 router, host-only state (NDP discovered
// routers, discovered on-link prefixes, and auto-generated addresses) will
-// be cleaned up/invalidated.
+// be cleaned up/invalidated and NDP router solicitations will be stopped.
func (n *NIC) becomeIPv6Router() {
n.mu.Lock()
defer n.mu.Unlock()
n.ndp.cleanupHostOnlyState()
+ n.ndp.stopSolicitingRouters()
+}
+
+// becomeIPv6Host transitions n into an IPv6 host.
+//
+// When transitioning into an IPv6 host, NDP router solicitations will be
+// started.
+func (n *NIC) becomeIPv6Host() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.ndp.startSolicitingRouters()
}
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
@@ -235,6 +265,10 @@ func (n *NIC) isPromiscuousMode() bool {
return rv
}
+func (n *NIC) isLoopback() bool {
+ return n.linkEP.Capabilities()&CapabilityLoopback != 0
+}
+
// setSpoofing enables or disables address spoofing.
func (n *NIC) setSpoofing(enable bool) {
n.mu.Lock()
@@ -242,14 +276,145 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-// primaryEndpoint returns the primary endpoint of n for the given network
-// protocol.
-func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
+// primaryEndpoint will return the first non-deprecated endpoint if such an
+// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
+// endpoint exists, the first deprecated endpoint will be returned.
+//
+// If an IPv6 primary endpoint is requested, Source Address Selection (as
+// defined by RFC 6724 section 5) will be performed.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ if protocol == header.IPv6ProtocolNumber && remoteAddr != "" {
+ return n.primaryIPv6Endpoint(remoteAddr)
+ }
+
n.mu.RLock()
defer n.mu.RUnlock()
+ var deprecatedEndpoint *referencedNetworkEndpoint
for _, r := range n.primary[protocol] {
- if r.isValidForOutgoing() && r.tryIncRef() {
+ if !r.isValidForOutgoing() {
+ continue
+ }
+
+ if !r.deprecated {
+ if r.tryIncRef() {
+ // r is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ deprecatedEndpoint.decRefLocked()
+ deprecatedEndpoint = nil
+ }
+
+ return r
+ }
+ } else if deprecatedEndpoint == nil && r.tryIncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of r in
+ // case n doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, r's reference count
+ // will be decremented when such an endpoint is found.
+ deprecatedEndpoint = r
+ }
+ }
+
+ // n doesn't have any valid non-deprecated endpoints, so return
+ // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated
+ // endpoints either).
+ return deprecatedEndpoint
+}
+
+// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC
+// 6724 section 5).
+type ipv6AddrCandidate struct {
+ ref *referencedNetworkEndpoint
+ scope header.IPv6AddressScope
+}
+
+// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ primaryAddrs := n.primary[header.IPv6ProtocolNumber]
+
+ if len(primaryAddrs) == 0 {
+ return nil
+ }
+
+ // Create a candidate set of available addresses we can potentially use as a
+ // source address.
+ cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
+ for _, r := range primaryAddrs {
+ // If r is not valid for outgoing connections, it is not a valid endpoint.
+ if !r.isValidForOutgoing() {
+ continue
+ }
+
+ addr := r.ep.ID().LocalAddress
+ scope, err := header.ScopeForIPv6Address(addr)
+ if err != nil {
+ // Should never happen as we got r from the primary IPv6 endpoint list and
+ // ScopeForIPv6Address only returns an error if addr is not an IPv6
+ // address.
+ log.Fatalf("header.ScopeForIPv6Address(%s): %s", addr, err)
+ }
+
+ cs = append(cs, ipv6AddrCandidate{
+ ref: r,
+ scope: scope,
+ })
+ }
+
+ remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
+ if err != nil {
+ // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
+ log.Fatalf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)
+ }
+
+ // Sort the addresses as per RFC 6724 section 5 rules 1-3.
+ //
+ // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ sort.Slice(cs, func(i, j int) bool {
+ sa := cs[i]
+ sb := cs[j]
+
+ // Prefer same address as per RFC 6724 section 5 rule 1.
+ if sa.ref.ep.ID().LocalAddress == remoteAddr {
+ return true
+ }
+ if sb.ref.ep.ID().LocalAddress == remoteAddr {
+ return false
+ }
+
+ // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
+ if sa.scope < sb.scope {
+ return sa.scope >= remoteScope
+ } else if sb.scope < sa.scope {
+ return sb.scope < remoteScope
+ }
+
+ // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
+ if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep {
+ // If sa is not deprecated, it is preferred over sb.
+ return sbDep
+ }
+
+ // sa and sb are equal, return the endpoint that is closest to the front of
+ // the primary endpoint list.
+ return i < j
+ })
+
+ // Return the most preferred address that can have its reference count
+ // incremented.
+ for _, c := range cs {
+ if r := c.ref; r.tryIncRef() {
return r
}
}
@@ -362,13 +527,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, peb, temporary, static)
+ }, peb, temporary, static, false)
n.mu.Unlock()
return ref
}
-func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
+// addPermanentAddressLocked adds a permanent address to n.
+//
+// If n already has the address in a non-permanent state,
+// addPermanentAddressLocked will promote it to permanent and update the
+// endpoint with the properties provided.
+func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
if ref, ok := n.endpoints[id]; ok {
switch ref.getKind() {
@@ -376,10 +546,14 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
// The NIC already have a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
case permanentExpired, temporary:
- // Promote the endpoint to become permanent and respect
- // the new peb.
+ // Promote the endpoint to become permanent and respect the new peb,
+ // configType and deprecated status.
if ref.tryIncRef() {
+ // TODO(b/147748385): Perform Duplicate Address Detection when promoting
+ // an IPv6 endpoint to permanent.
ref.setKind(permanent)
+ ref.deprecated = deprecated
+ ref.configType = configType
refs := n.primary[ref.protocol]
for i, r := range refs {
@@ -411,10 +585,14 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
}
}
- return n.addAddressLocked(protocolAddress, peb, permanent, static)
+ return n.addAddressLocked(protocolAddress, peb, permanent, configType, deprecated)
}
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) {
+// addAddressLocked adds a new protocolAddress to n.
+//
+// If the address is already known by n (irrespective of the state it is in),
+// addAddressLocked does nothing and returns tcpip.ErrDuplicateAddress.
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
// TODO(b/141022673): Validate IP address before adding them.
// Sanity check.
@@ -450,6 +628,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
protocol: protocolAddress.Protocol,
kind: kind,
configType: configType,
+ deprecated: deprecated,
}
// Set up cache if link address resolution exists for this protocol.
@@ -487,7 +666,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addPermanentAddressLocked(protocolAddress, peb)
+ _, err := n.addPermanentAddressLocked(protocolAddress, peb, static, false /* deprecated */)
n.mu.Unlock()
return err
@@ -548,6 +727,51 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
return addrs
}
+// primaryAddress returns the primary address associated with this NIC.
+//
+// primaryAddress will return the first non-deprecated address if such an
+// address exists. If no non-deprecated address exists, the first deprecated
+// address will be returned.
+func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ list, ok := n.primary[proto]
+ if !ok {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ var deprecatedEndpoint *referencedNetworkEndpoint
+ for _, ref := range list {
+ // Don't include tentative, expired or tempory endpoints to avoid confusion
+ // and prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentTentative, permanentExpired, temporary:
+ continue
+ }
+
+ if !ref.deprecated {
+ return tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ }
+ }
+
+ if deprecatedEndpoint == nil {
+ deprecatedEndpoint = ref
+ }
+ }
+
+ if deprecatedEndpoint != nil {
+ return tcpip.AddressWithPrefix{
+ Address: deprecatedEndpoint.ep.ID().LocalAddress,
+ PrefixLen: deprecatedEndpoint.ep.PrefixLen(),
+ }
+ }
+
+ return tcpip.AddressWithPrefix{}
+}
+
// AddAddressRange adds a range of addresses to n, so that it starts accepting
// packets targeted at the given addresses and network protocol. The range is
// given by a subnet address, and all addresses contained in the subnet are
@@ -575,7 +799,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
n.mu.Unlock()
}
-// Subnets returns the Subnets associated with this NIC.
+// AddressRanges returns the Subnets associated with this NIC.
func (n *NIC) AddressRanges() []tcpip.Subnet {
n.mu.RLock()
defer n.mu.RUnlock()
@@ -724,7 +948,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A
Address: addr,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, NeverPrimaryEndpoint); err != nil {
+ }, NeverPrimaryEndpoint, static, false /* deprecated */); err != nil {
return err
}
}
@@ -1102,8 +1326,14 @@ type referencedNetworkEndpoint struct {
kind networkEndpointKind
// configType is the method that was used to configure this endpoint.
- // This must never change after the endpoint is added to a NIC.
+ // This must never change except during endpoint creation and promotion to
+ // permanent.
configType networkEndpointConfigType
+
+ // deprecated indicates whether or not the endpoint should be considered
+ // deprecated. That is, when deprecated is true, other endpoints that are not
+ // deprecated should be preferred.
+ deprecated bool
}
func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 61fd46d66..2b8751d49 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -234,15 +234,15 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length.
- WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 34307ae07..517f4b941 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -158,7 +158,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack
return tcpip.ErrInvalidEndpointState
}
- err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt)
+ err := r.ref.ep.WritePacket(r, gso, params, pkt)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
@@ -174,7 +174,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network
return 0, tcpip.ErrInvalidEndpointState
}
- n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop)
+ n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n))
}
@@ -195,7 +195,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil {
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 7a9600679..fc56a6d79 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -21,13 +21,13 @@ package stack
import (
"encoding/binary"
- "sync"
"sync/atomic"
"time"
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -352,6 +352,38 @@ func (u *uniqueIDGenerator) UniqueID() uint64 {
return atomic.AddUint64((*uint64)(u), 1)
}
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it will be passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on differnt NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -412,8 +444,8 @@ type Stack struct {
ndpConfigs NDPConfigurations
// autoGenIPv6LinkLocal determines whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled NICs.
- // See the AutoGenIPv6LinkLocal field of Options for more details.
+ // to auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
autoGenIPv6LinkLocal bool
// ndpDisp is the NDP event dispatcher that is used to send the netstack
@@ -422,6 +454,10 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
+
+ // opaqueIIDOpts hold the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
}
// UniqueID is an abstract generator of unique identifiers.
@@ -460,13 +496,15 @@ type Options struct {
// before assigning an address to a NIC.
NDPConfigs NDPConfigurations
- // AutoGenIPv6LinkLocal determins whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
// Note, setting this to true does not mean that a link-local address
- // will be assigned right away, or at all. If Duplicate Address
- // Detection is enabled, an address will only be assigned if it
- // successfully resolves. If it fails, no further attempt will be made
- // to auto-generate an IPv6 link-local address.
+ // will be assigned right away, or at all. If Duplicate Address Detection
+ // is enabled, an address will only be assigned if it successfully resolves.
+ // If it fails, no further attempt will be made to auto-generate an IPv6
+ // link-local address.
//
// The generated link-local address will follow RFC 4291 Appendix A
// guidelines.
@@ -479,6 +517,10 @@ type Options struct {
// RawFactory produces raw endpoints. Raw endpoints are enabled only if
// this is non-nil.
RawFactory RawFactory
+
+ // OpaqueIIDOpts hold the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -505,6 +547,49 @@ type TransportEndpointInfo struct {
RegisterNICID tcpip.NICID
}
+// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address
+// and returns the network protocol number to be used to communicate with the
+// specified address. It returns an error if the passed address is incompatible
+// with the receiver.
+func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.NetProto
+ switch len(addr.Addr) {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ if header.IsV4MappedAddress(addr.Addr) {
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == header.IPv4Any {
+ addr.Addr = ""
+ }
+ }
+ }
+
+ switch len(e.ID.LocalAddress) {
+ case header.IPv4AddressSize:
+ if len(addr.Addr) == header.IPv6AddressSize {
+ return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ }
+ case header.IPv6AddressSize:
+ if len(addr.Addr) == header.IPv4AddressSize {
+ return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable
+ }
+ }
+
+ switch {
+ case netProto == e.NetProto:
+ case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
+ if v6only {
+ return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
+ }
+ default:
+ return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return addr, netProto, nil
+}
+
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
// marker interface.
func (*TransportEndpointInfo) IsEndpointInfo() {}
@@ -549,6 +634,7 @@ func New(opts Options) *Stack {
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
}
// Add specified network protocols.
@@ -664,7 +750,9 @@ func (s *Stack) Stats() tcpip.Stats {
// SetForwarding enables or disables the packet forwarding between NICs.
//
// When forwarding becomes enabled, any host-only state on all NICs will be
-// cleaned up.
+// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started.
+// When forwarding becomes disabled and if IPv6 is enabled, NDP Router
+// Solicitations will be stopped.
func (s *Stack) SetForwarding(enable bool) {
// TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
s.mu.Lock()
@@ -686,6 +774,10 @@ func (s *Stack) SetForwarding(enable bool) {
for _, nic := range s.nics {
nic.becomeIPv6Router()
}
+ } else {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Host()
+ }
}
}
@@ -753,9 +845,30 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum
return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
}
-// createNIC creates a NIC with the provided id and link-layer endpoint, and
-// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
+// NICContext is an opaque pointer used to store client-supplied NIC metadata.
+type NICContext interface{}
+
+// NICOptions specifies the configuration of a NIC as it is being created.
+// The zero value creates an enabled, unnamed NIC.
+type NICOptions struct {
+ // Name specifies the name of the NIC.
+ Name string
+
+ // Disabled specifies whether to avoid calling Attach on the passed
+ // LinkEndpoint.
+ Disabled bool
+
+ // Context specifies user-defined data that will be returned in stack.NICInfo
+ // for the NIC. Clients of this library can use it to add metadata that
+ // should be tracked alongside a NIC, to avoid having to keep a
+ // map[tcpip.NICID]metadata mirroring stack.Stack's nic map.
+ Context NICContext
+}
+
+// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
+// NICOptions. See the documentation on type NICOptions for details on how
+// NICs can be configured.
+func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -764,44 +877,20 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled,
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, name, ep, loopback)
+ n := newNIC(s, id, opts.Name, ep, opts.Context)
s.nics[id] = n
- if enabled {
+ if !opts.Disabled {
return n.enable()
}
return nil
}
-// CreateNIC creates a NIC with the provided id and link-layer endpoint.
+// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
+// `LinkEndpoint.Attach` to start delivering packets to it.
func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, "", ep, true, false)
-}
-
-// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
-// and a human-readable name.
-func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, true, false)
-}
-
-// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
-// endpoint, and a human-readable name.
-func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, true, true)
-}
-
-// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
-// but leave it disable. Stack.EnableNIC must be called before the link-layer
-// endpoint starts delivering packets to it.
-func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, "", ep, false, false)
-}
-
-// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
-// CreateDisabledNIC.
-func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, false, false)
+ return s.CreateNICWithOptions(id, ep, NICOptions{})
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -829,7 +918,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
return false
}
-// NICSubnets returns a map of NICIDs to their associated subnets.
+// NICAddressRanges returns a map of NICIDs to their associated subnets.
func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -855,6 +944,18 @@ type NICInfo struct {
MTU uint32
Stats NICStats
+
+ // Context is user-supplied data optionally supplied in CreateNICWithOptions.
+ // See type NICOptions for more details.
+ Context NICContext
+}
+
+// HasNIC returns true if the NICID is defined in the stack.
+func (s *Stack) HasNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ _, ok := s.nics[id]
+ s.mu.RUnlock()
+ return ok
}
// NICInfo returns a map of NICIDs to their associated information.
@@ -868,7 +969,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Up: true, // Netstack interfaces are always up.
Running: nic.linkEP.IsAttached(),
Promiscuous: nic.isPromiscuousMode(),
- Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0,
+ Loopback: nic.isLoopback(),
}
nics[id] = NICInfo{
Name: nic.name,
@@ -877,6 +978,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
+ Context: nic.context,
}
}
return nics
@@ -993,9 +1095,11 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
return nics
}
-// GetMainNICAddress returns the first primary address and prefix for the given
-// NIC and protocol. Returns an error if the NIC doesn't exist and an empty
-// value if the NIC doesn't have a primary address for the given protocol.
+// GetMainNICAddress returns the first non-deprecated primary address and prefix
+// for the given NIC and protocol. If no non-deprecated primary address exists,
+// a deprecated primary address and prefix will be returned. Returns an error if
+// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
+// address for the given protocol.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1005,17 +1109,12 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
- for _, a := range nic.PrimaryAddresses() {
- if a.Protocol == protocol {
- return a.AddressWithPrefix, nil
- }
- }
- return tcpip.AddressWithPrefix{}, nil
+ return nic.primaryAddress(protocol), nil
}
-func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
+func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
if len(localAddr) == 0 {
- return nic.primaryEndpoint(netProto)
+ return nic.primaryEndpoint(netProto, remoteAddr)
}
return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
}
@@ -1031,8 +1130,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok {
- if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil
+ if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
+ return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
}
} else {
@@ -1041,14 +1140,14 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
continue
}
if nic, ok := s.nics[route.NIC]; ok {
- if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
+ if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
// provided will refer to the link local address.
remoteAddr = ref.ep.ID().LocalAddress
}
- r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback)
+ r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
if needRoute {
r.NextHop = route.Gateway
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 8fc034ca1..4b3d18f1b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,12 +27,15 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
@@ -122,7 +125,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -133,7 +136,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
b[1] = f.id.LocalAddress[0]
b[2] = byte(params.Protocol)
- if loop&stack.PacketLoop != 0 {
+ if r.Loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
@@ -141,7 +144,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
@@ -149,11 +152,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -1894,55 +1897,67 @@ func TestNICForwarding(t *testing.T) {
}
// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses
-// (or lack there-of if disabled (default)). Note, DAD will be disabled in
-// these tests.
+// using the modified EUI-64 of the NIC's MAC address (or lack there-of if
+// disabled (default)). Note, DAD will be disabled in these tests.
func TestNICAutoGenAddr(t *testing.T) {
tests := []struct {
name string
autoGen bool
linkAddr tcpip.LinkAddress
+ iidOpts stack.OpaqueInterfaceIdentifierOptions
shouldGen bool
}{
{
"Disabled",
false,
linkAddr1,
+ stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(nicID tcpip.NICID, _ string) string {
+ return fmt.Sprintf("nic%d", nicID)
+ },
+ },
false,
},
{
"Enabled",
true,
linkAddr1,
+ stack.OpaqueInterfaceIdentifierOptions{},
true,
},
{
"Nil MAC",
true,
tcpip.LinkAddress([]byte(nil)),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Empty MAC",
true,
tcpip.LinkAddress(""),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Invalid MAC",
true,
tcpip.LinkAddress("\x01\x02\x03"),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Multicast MAC",
true,
tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
{
"Unspecified MAC",
true,
tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ stack.OpaqueInterfaceIdentifierOptions{},
false,
},
}
@@ -1951,13 +1966,12 @@ func TestNICAutoGenAddr(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ OpaqueIIDOpts: test.iidOpts,
}
if test.autoGen {
- // Only set opts.AutoGenIPv6LinkLocal when
- // test.autoGen is true because
- // opts.AutoGenIPv6LinkLocal should be false by
- // default.
+ // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by default.
opts.AutoGenIPv6LinkLocal = true
}
@@ -1973,9 +1987,171 @@ func TestNICAutoGenAddr(t *testing.T) {
}
if test.shouldGen {
+ // Should have auto-generated an address and resolved immediately (DAD
+ // is disabled).
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+ } else {
+ // Should not have auto-generated an address.
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ }
+ })
+ }
+}
+
+// TestNICContextPreservation tests that you can read out via stack.NICInfo the
+// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
+func TestNICContextPreservation(t *testing.T) {
+ var ctx *int
+ tests := []struct {
+ name string
+ opts stack.NICOptions
+ want stack.NICContext
+ }{
+ {
+ "context_set",
+ stack.NICOptions{Context: ctx},
+ ctx,
+ },
+ {
+ "context_not_set",
+ stack.NICOptions{},
+ nil,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ id := tcpip.NICID(1)
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
+ t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
+ }
+ nicinfos := s.NICInfo()
+ nicinfo, ok := nicinfos[id]
+ if !ok {
+ t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
+ }
+ if got, want := nicinfo.Context == test.want, true; got != want {
+ t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
+ }
+ })
+ }
+}
+
+// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local
+// addresses with opaque interface identifiers. Link Local addresses should
+// always be generated with opaque IIDs if configured to use them, even if the
+// NIC has an invalid MAC address.
+func TestNICAutoGenAddrWithOpaque(t *testing.T) {
+ const nicID = 1
+
+ var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte
+ n, err := rand.Read(secretKey[:])
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n)
+ }
+
+ tests := []struct {
+ name string
+ nicName string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ secretKey []byte
+ }{
+ {
+ name: "Disabled",
+ nicName: "nic1",
+ autoGen: false,
+ linkAddr: linkAddr1,
+ secretKey: secretKey[:],
+ },
+ {
+ name: "Enabled",
+ nicName: "nic1",
+ autoGen: true,
+ linkAddr: linkAddr1,
+ secretKey: secretKey[:],
+ },
+ // These are all cases where we would not have generated a
+ // link-local address if opaque IIDs were disabled.
+ {
+ name: "Nil MAC and empty nicName",
+ nicName: "",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress([]byte(nil)),
+ secretKey: secretKey[:1],
+ },
+ {
+ name: "Empty MAC and empty nicName",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress(""),
+ secretKey: secretKey[:2],
+ },
+ {
+ name: "Invalid MAC",
+ nicName: "test",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress("\x01\x02\x03"),
+ secretKey: secretKey[:3],
+ },
+ {
+ name: "Multicast MAC",
+ nicName: "test2",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ secretKey: secretKey[:4],
+ },
+ {
+ name: "Unspecified MAC and nil SecretKey",
+ nicName: "test3",
+ autoGen: true,
+ linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: test.secretKey,
+ },
+ }
+
+ if test.autoGen {
+ // Only set opts.AutoGenIPv6LinkLocal when
+ // test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by
+ // default.
+ opts.AutoGenIPv6LinkLocal = true
+ }
+
+ e := channel.New(10, 1280, test.linkAddr)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: test.nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+
+ if test.autoGen {
// Should have auto-generated an address and
// resolved immediately (DAD is disabled).
- if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
}
} else {
@@ -1988,6 +2164,56 @@ func TestNICAutoGenAddr(t *testing.T) {
}
}
+// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are
+// not auto-generated for loopback NICs.
+func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
+ const nicID = 1
+ const nicName = "nicName"
+
+ tests := []struct {
+ name string
+ opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions
+ }{
+ {
+ name: "IID From MAC",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{},
+ },
+ {
+ name: "Opaque IID",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
+ }
+
+ e := loopback.New()
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want)
+ }
+ })
+ }
+}
+
// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
// link-local addresses will only be assigned after the DAD process resolves.
func TestNICAutoGenAddrDoesDAD(t *testing.T) {
@@ -2186,3 +2412,154 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
}
}
}
+
+func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
+ const (
+ linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ nicID = 1
+ )
+
+ // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test.
+ tests := []struct {
+ name string
+ nicAddrs []tcpip.Address
+ connectAddr tcpip.Address
+ expectedLocalAddr tcpip.Address
+ }{
+ // Test Rule 1 of RFC 6724 section 5.
+ {
+ name: "Same Global most preferred (last address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr1,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Same Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: globalAddr1,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Same Link Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalAddr1,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Same Link Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalAddr1,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Same Unique Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
+ connectAddr: uniqueLocalAddr1,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Same Unique Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: uniqueLocalAddr1,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+
+ // Test Rule 2 of RFC 6724 section 5.
+ {
+ name: "Global most preferred (last address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Link Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Unique Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Unique Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+
+ // Test returning the endpoint that is closest to the front when
+ // candidate addresses are "equal" from the perspective of RFC 6724
+ // section 5.
+ {
+ name: "Unique Local for Global",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Link Local for Global",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local for Unique Local",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr3,
+ NIC: nicID,
+ }})
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+
+ for _, a := range test.nicAddrs {
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
+ t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ }
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr {
+ t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 67c21be42..d686e6eb8 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -18,8 +18,8 @@ import (
"fmt"
"math/rand"
"sort"
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
return
}
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt)
+ transEP := selectEndpoint(id, mpep, epsByNic.seed)
+ if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
+ queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ epsByNic.mu.RUnlock()
+ return
+ }
+
+ transEP.HandlePacket(r, id, pkt)
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
epsByNic.mu.Lock()
defer epsByNic.mu.Unlock()
@@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort
}
// This is a new binding.
- multiPortEp := &multiPortEndpoint{}
+ multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto}
multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
multiPortEp.reuse = reusePort
epsByNic.endpoints[bindToDevice] = multiPortEp
@@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// newTransportDemuxer.
type transportDemuxer struct {
// protocol is immutable.
- protocol map[protocolIDs]*transportEndpoints
+ protocol map[protocolIDs]*transportEndpoints
+ queuedProtocols map[protocolIDs]queuedTransportProtocol
+}
+
+// queuedTransportProtocol if supported by a protocol implementation will cause
+// the dispatcher to delivery packets to the QueuePacket method instead of
+// calling HandlePacket directly on the endpoint.
+type queuedTransportProtocol interface {
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
- d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
+ d := &transportDemuxer{
+ protocol: make(map[protocolIDs]*transportEndpoints),
+ queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
+ }
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
- d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ protoIDs := protocolIDs{netProto, proto}
+ d.protocol[protoIDs] = &transportEndpoints{
endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
+ qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
+ if isQueued {
+ d.queuedProtocols[protoIDs] = qTransProto
+ }
}
}
@@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
//
// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ demux *transportDemuxer
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
// reuse indicates if more than one endpoint is allowed.
@@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
ep.mu.RLock()
+ queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
for i, endpoint := range ep.endpointsArr {
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
if i == len(ep.endpointsArr)-1 {
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ break
+ }
endpoint.HandlePacket(r, id, pkt)
break
}
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ continue
+ }
endpoint.HandlePacket(r, id, pkt.Clone())
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
@@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if epsByNic, ok := eps.endpoints[id]; ok {
// There was already a binding.
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// This is a new binding.
@@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
}
eps.endpoints[id] = epsByNic
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 3b28b06d0..5e9237de9 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -41,7 +41,7 @@ const (
type testContext struct {
t *testing.T
- linkEPs map[string]*channel.Endpoint
+ linkEps map[tcpip.NICID]*channel.Endpoint
s *stack.Stack
ep tcpip.Endpoint
@@ -61,35 +61,29 @@ func (c *testContext) createV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.ep.SetSockOpt(v); err != nil {
+ if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
}
-// newDualTestContextMultiNic creates the testing context and also linkEpNames
-// named NICs.
-func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
+func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- linkEPs := make(map[string]*channel.Endpoint)
- for i, linkEpName := range linkEpNames {
- channelEP := channel.New(256, mtu, "")
- nicID := tcpip.NICID(i + 1)
- if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil {
+ linkEps := make(map[tcpip.NICID]*channel.Endpoint)
+ for _, linkEpID := range linkEpIDs {
+ channelEp := channel.New(256, mtu, "")
+ if err := s.CreateNIC(linkEpID, channelEp); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- linkEPs[linkEpName] = channelEP
+ linkEps[linkEpID] = channelEp
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress IPv4 failed: %v", err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
t.Fatalf("AddAddress IPv6 failed: %v", err)
}
}
@@ -108,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string)
return &testContext{
t: t,
s: s,
- linkEPs: linkEPs,
+ linkEps: linkEps,
}
}
@@ -125,7 +119,7 @@ func newPayload() []byte {
return b
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -156,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -186,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
func TestDistribution(t *testing.T) {
type endpointSockopts struct {
reuse int
- bindToDevice string
+ bindToDevice tcpip.NICID
}
for _, test := range []struct {
name string
@@ -194,71 +188,71 @@ func TestDistribution(t *testing.T) {
endpoints []endpointSockopts
// wantedDistribution is the wanted ratio of packets received on each
// endpoint for each NIC on which packets are injected.
- wantedDistributions map[string][]float64
+ wantedDistributions map[tcpip.NICID][]float64
}{
{
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
- "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ 1: {0.2, 0.2, 0.2, 0.2, 0.2},
},
},
{
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
- endpointSockopts{0, "dev0"},
- endpointSockopts{0, "dev1"},
- endpointSockopts{0, "dev2"},
+ {0, 1},
+ {0, 2},
+ {0, 3},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
- "dev0": []float64{1, 0, 0},
+ 1: {1, 0, 0},
// Injected packets on dev1 go only to the endpoint bound to dev1.
- "dev1": []float64{0, 1, 0},
+ 2: {0, 1, 0},
// Injected packets on dev2 go only to the endpoint bound to dev2.
- "dev2": []float64{0, 0, 1},
+ 3: {0, 0, 1},
},
},
{
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
- endpointSockopts{1, "dev0"},
- endpointSockopts{1, "dev0"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, ""},
+ {1, 1},
+ {1, 1},
+ {1, 2},
+ {1, 2},
+ {1, 2},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
// dev0.
- "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ 1: {0.5, 0.5, 0, 0, 0, 0},
// Injected packets on dev1 get distributed among endpoints bound to
// dev1 or unbound.
- "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
// Injected packets on dev999 go only to the unbound.
- "dev999": []float64{0, 0, 0, 0, 0, 1},
+ 1000: {0, 0, 0, 0, 0, 1},
},
},
} {
t.Run(test.name, func(t *testing.T) {
for device, wantedDistribution := range test.wantedDistributions {
- t.Run(device, func(t *testing.T) {
- var devices []string
+ t.Run(string(device), func(t *testing.T) {
+ var devices []tcpip.NICID
for d := range test.wantedDistributions {
devices = append(devices, d)
}
- c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ c := newDualTestContextMultiNIC(t, defaultMTU, devices)
defer c.cleanup()
c.createV6Endpoint(false)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 748ce4ea5..f50604a8a 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -102,13 +102,23 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptBool sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// SetSockOptInt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f62fd729f..b7813cbc0 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -35,10 +35,10 @@ import (
"reflect"
"strconv"
"strings"
- "sync"
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/waiter"
@@ -322,7 +322,7 @@ type ControlMessages struct {
HasTOS bool
// TOS is the IPv4 type of service of the associated packet.
- TOS int8
+ TOS uint8
// HasTClass indicates whether Tclass is valid/set.
HasTClass bool
@@ -423,17 +423,25 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptBool sets a socket option, for simple cases where a value
+ // has the bool type.
+ SetSockOptBool(opt SockOptBool, v bool) *Error
+
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
- SetSockOptInt(opt SockOpt, v int) *Error
+ SetSockOptInt(opt SockOptInt, v int) *Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
+ // GetSockOptBool gets a socket option for simple cases where a return
+ // value has the bool type.
+ GetSockOptBool(SockOptBool) (bool, *Error)
+
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
- GetSockOptInt(SockOpt) (int, *Error)
+ GetSockOptInt(SockOptInt) (int, *Error)
// State returns a socket's lifecycle state. The returned value is
// protocol-specific and is primarily used for diagnostics.
@@ -488,13 +496,26 @@ type WriteOptions struct {
Atomic bool
}
-// SockOpt represents socket options which values have the int type.
-type SockOpt int
+// SockOptBool represents socket options which values have the bool type.
+type SockOptBool int
+
+const (
+ // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
+ // ancillary message is passed with incoming packets.
+ ReceiveTOSOption SockOptBool = iota
+
+ // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
+ // socket is to be restricted to sending and receiving IPv6 packets only.
+ V6OnlyOption
+)
+
+// SockOptInt represents socket options which values have the int type.
+type SockOptInt int
const (
// ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
// number of unread bytes in the input buffer should be returned.
- ReceiveQueueSizeOption SockOpt = iota
+ ReceiveQueueSizeOption SockOptInt = iota
// SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
// specify the send buffer size option.
@@ -521,10 +542,6 @@ const (
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
-// socket is to be restricted to sending and receiving IPv6 packets only.
-type V6OnlyOption int
-
// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
// held until segments are full by the TCP transport protocol.
type CorkOption int
@@ -539,7 +556,7 @@ type ReusePortOption int
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
-type BindToDeviceOption string
+type BindToDeviceOption NICID
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
new file mode 100644
index 000000000..f5f01f32f
--- /dev/null
+++ b/pkg/tcpip/timer.go
@@ -0,0 +1,161 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "sync"
+ "time"
+)
+
+// cancellableTimerInstance is a specific instance of CancellableTimer.
+//
+// Different instances are created each time CancellableTimer is Reset so each
+// timer has its own earlyReturn signal. This is to address a bug when a
+// CancellableTimer is stopped and reset in quick succession resulting in a
+// timer instance's earlyReturn signal being affected or seen by another timer
+// instance.
+//
+// Consider the following sceneario where timer instances share a common
+// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a
+// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
+// (B), third (C), and fourth (D) instance of the timer firing, respectively):
+// T1: Obtain L
+// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T2: instance A fires, blocked trying to obtain L.
+// T1: Attempt to stop instance A (set earlyReturn = true)
+// T1: Reset timer (create instance B)
+// T3: instance B fires, blocked trying to obtain L.
+// T1: Attempt to stop instance B (set earlyReturn = true)
+// T1: Reset timer (create instance C)
+// T4: instance C fires, blocked trying to obtain L.
+// T1: Attempt to stop instance C (set earlyReturn = true)
+// T1: Reset timer (create instance D)
+// T5: instance D fires, blocked trying to obtain L.
+// T1: Release L
+//
+// Now that T1 has released L, any of the 4 timer instances can take L and check
+// earlyReturn. If the timers simply check earlyReturn and then do nothing
+// further, then instance D will never early return even though it was not
+// requested to stop. If the timers reset earlyReturn before early returning,
+// then all but one of the timers will do work when only one was expected to.
+// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// will fire (again, when only one was expected to).
+//
+// To address the above concerns the simplest solution was to give each timer
+// its own earlyReturn signal.
+type cancellableTimerInstance struct {
+ timer *time.Timer
+
+ // Used to inform the timer to early return when it gets stopped while the
+ // lock the timer tries to obtain when fired is held (T1 is a goroutine that
+ // tries to cancel the timer and T2 is the goroutine that handles the timer
+ // firing):
+ // T1: Obtain the lock, then call StopLocked()
+ // T2: timer fires, and gets blocked on obtaining the lock
+ // T1: Releases lock
+ // T2: Obtains lock does unintended work
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // inform the timer using earlyReturn to return early so that once T2 obtains
+ // the lock, it will see that it is set to true and do nothing further.
+ earlyReturn *bool
+}
+
+// stop stops the timer instance t from firing if it hasn't fired already. If it
+// has fired and is blocked at obtaining the lock, earlyReturn will be set to
+// true so that it will early return when it obtains the lock.
+func (t *cancellableTimerInstance) stop() {
+ if t.timer != nil {
+ t.timer.Stop()
+ *t.earlyReturn = true
+ }
+}
+
+// CancellableTimer is a timer that does some work and can be safely cancelled
+// when it fires at the same time some "related work" is being done.
+//
+// The term "related work" is defined as some work that needs to be done while
+// holding some lock that the timer must also hold while doing some work.
+type CancellableTimer struct {
+ // The active instance of a cancellable timer.
+ instance cancellableTimerInstance
+
+ // locker is the lock taken by the timer immediately after it fires and must
+ // be held when attempting to stop the timer.
+ //
+ // Must never change after being assigned.
+ locker sync.Locker
+
+ // fn is the function that will be called when a timer fires and has not been
+ // signaled to early return.
+ //
+ // fn MUST NOT attempt to lock locker.
+ //
+ // Must never change after being assigned.
+ fn func()
+}
+
+// StopLocked prevents the Timer from firing if it has not fired already.
+//
+// If the timer is blocked on obtaining the t.locker lock when StopLocked is
+// called, it will early return instead of calling t.fn.
+//
+// Note, t will be modified.
+//
+// t.locker MUST be locked.
+func (t *CancellableTimer) StopLocked() {
+ t.instance.stop()
+
+ // Nothing to do with the stopped instance anymore.
+ t.instance = cancellableTimerInstance{}
+}
+
+// Reset changes the timer to expire after duration d.
+//
+// Note, t will be modified.
+//
+// Reset should only be called on stopped or expired timers. To be safe, callers
+// should always call StopLocked before calling Reset.
+func (t *CancellableTimer) Reset(d time.Duration) {
+ // Create a new instance.
+ earlyReturn := false
+ t.instance = cancellableTimerInstance{
+ timer: time.AfterFunc(d, func() {
+ t.locker.Lock()
+ defer t.locker.Unlock()
+
+ if earlyReturn {
+ // If we reach this point, it means that the timer fired while another
+ // goroutine called StopLocked while it had the lock. Simply return
+ // here and do nothing further.
+ earlyReturn = false
+ return
+ }
+
+ t.fn()
+ }),
+ earlyReturn: &earlyReturn,
+ }
+}
+
+// MakeCancellableTimer returns an unscheduled CancellableTimer with the given
+// locker and fn.
+//
+// fn MUST NOT attempt to lock locker.
+//
+// Callers must call Reset to schedule the timer to fire.
+func MakeCancellableTimer(locker sync.Locker, fn func()) CancellableTimer {
+ return CancellableTimer{locker: locker, fn: fn}
+}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
new file mode 100644
index 000000000..2d20f7ef3
--- /dev/null
+++ b/pkg/tcpip/timer_test.go
@@ -0,0 +1,236 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip_test
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ shortDuration = 1 * time.Nanosecond
+ middleDuration = 100 * time.Millisecond
+ longDuration = 1 * time.Second
+)
+
+func TestCancellableTimerFire(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ timer := tcpip.MakeCancellableTimer(&lock, func() {
+ ch <- struct{}{}
+ })
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerResetFromLongDuration(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(middleDuration)
+
+ lock.Lock()
+ timer.StopLocked()
+ lock.Unlock()
+
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+
+ timer.Reset(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerImmediatelyStop(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ for i := 0; i < 1000; i++ {
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+ }
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ timer.StopLocked()
+ lock.Unlock()
+
+ for i := 0; i < 10; i++ {
+ timer.Reset(middleDuration)
+
+ lock.Lock()
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration * 2)
+ timer.StopLocked()
+ lock.Unlock()
+ }
+
+ // Wait for double the duration so timers that weren't correctly stopped can
+ // fire.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration * 2):
+ }
+}
+
+func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ for i := 0; i < 10; i++ {
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration)
+ timer.StopLocked()
+ timer.Reset(shortDuration)
+ }
+ lock.Unlock()
+
+ // Wait for double the duration for the last timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan struct{})
+ var lock sync.Mutex
+
+ lock.Lock()
+ timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
+ timer.Reset(shortDuration)
+ for i := 0; i < 10; i++ {
+ timer.StopLocked()
+ timer.Reset(shortDuration)
+ }
+ lock.Unlock()
+
+ // Wait for double the duration for the last timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index d8c5b5058..3aa23d529 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -28,6 +28,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 9c40931b5..42afb3f5b 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,8 +15,7 @@
package icmp
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -289,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
toCopy := *to
to = &toCopy
- netProto, err := e.checkV4Mapped(to, true)
+ netProto, err := e.checkV4Mapped(to)
if err != nil {
return 0, nil, err
}
@@ -350,13 +349,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// SetSockOptBool sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ return nil
+}
+
// SetSockOptInt sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -466,18 +475,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
})
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if header.IsV4MappedAddress(addr.Addr) {
- return 0, tcpip.ErrNoRoute
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */)
+ if err != nil {
+ return 0, err
}
-
+ *addr = unwrapped
return netProto, nil
}
@@ -509,7 +512,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, false)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
@@ -622,7 +625,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, false)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
index 44b58ff6b..4858d150c 100644
--- a/pkg/tcpip/transport/packet/BUILD
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -28,6 +28,7 @@ go_library(
deps = [
"//pkg/log",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 0010b5e5f..fc5bc69fa 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,8 +25,7 @@
package packet
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -247,17 +246,17 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrNotSupported
+ return tcpip.ErrUnknownProtocolOption
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -265,6 +264,16 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrNotSupported
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
ep.rcvMu.Lock()
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 00991ac8e..2f2131ff7 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -29,6 +29,7 @@ go_library(
deps = [
"//pkg/log",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 5aafe2615..ee9c4c58b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,8 +26,7 @@
package raw
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -509,13 +508,38 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -544,21 +568,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) {
e.rcvMu.Lock()
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 3b353d56c..0e3ab05ad 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -16,6 +16,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "tcp_endpoint_list",
+ out = "tcp_endpoint_list.go",
+ package = "tcp",
+ prefix = "endpoint",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*endpoint",
+ "Linker": "*endpoint",
+ },
+)
+
go_library(
name = "tcp",
srcs = [
@@ -23,6 +35,7 @@ go_library(
"connect.go",
"cubic.go",
"cubic_state.go",
+ "dispatcher.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
@@ -38,6 +51,7 @@ go_library(
"segment_state.go",
"snd.go",
"snd_state.go",
+ "tcp_endpoint_list.go",
"tcp_segment_list.go",
"timer.go",
],
@@ -45,9 +59,9 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
- "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 5422ae80c..1a2e3efa9 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -19,11 +19,11 @@ import (
"encoding/binary"
"hash"
"io"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -285,7 +285,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// listenEP is nil when listenContext is used by tcp.Forwarder.
if l.listenEP != nil {
l.listenEP.mu.Lock()
- if l.listenEP.state != StateListen {
+ if l.listenEP.EndpointState() != StateListen {
l.listenEP.mu.Unlock()
return nil, tcpip.ErrConnectionAborted
}
@@ -344,11 +344,12 @@ func (l *listenContext) closeAllPendingEndpoints() {
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.Lock()
- state := e.state
+ state := e.EndpointState()
e.pendingAccepted.Add(1)
defer e.pendingAccepted.Done()
acceptedChan := e.acceptedChan
e.mu.Unlock()
+
if state == StateListen {
acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
@@ -562,8 +563,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// We do not use transitionToStateEstablishedLocked here as there is
// no handshake state available when doing a SYN cookie based accept.
n.stack.Stats().TCP.CurrentEstablished.Increment()
- n.state = StateEstablished
n.isConnectNotified = true
+ n.setEndpointState(StateEstablished)
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -596,7 +597,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
e.mu.Lock()
- e.state = StateClose
+ e.setEndpointState(StateClose)
// close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index cdd69f360..a2f384384 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -16,11 +16,11 @@ package tcp
import (
"encoding/binary"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -190,7 +190,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.mss = opts.MSS
h.sndWndScale = opts.WS
h.ep.mu.Lock()
- h.ep.state = StateSynRecv
+ h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
}
@@ -274,14 +274,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// SYN-RCVD state.
h.state = handshakeSynRcvd
h.ep.mu.Lock()
- h.ep.state = StateSynRecv
ttl := h.ep.ttl
+ h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
// We only send SACKPermitted if the other side indicated it
// permits SACK. This is not explicitly defined in the RFC but
@@ -341,7 +341,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
WS: h.rcvWndScale,
TS: h.ep.sendTSOk,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
@@ -501,7 +501,7 @@ func (h *handshake) execute() *tcpip.Error {
WS: h.rcvWndScale,
TS: true,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
@@ -792,7 +792,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
}
if e.sackPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
@@ -811,7 +811,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -848,6 +848,9 @@ func (e *endpoint) handleWrite() *tcpip.Error {
}
func (e *endpoint) handleClose() *tcpip.Error {
+ if !e.EndpointState().connected() {
+ return nil
+ }
// Drain the send queue.
e.handleWrite()
@@ -864,11 +867,7 @@ func (e *endpoint) handleClose() *tcpip.Error {
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
- if e.state == StateEstablished || e.state == StateCloseWait {
- e.stack.Stats().TCP.EstablishedResets.Increment()
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
- }
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
@@ -888,9 +887,12 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
}
// completeWorkerLocked is called by the worker goroutine when it's about to
-// exit. It marks the worker as completed and performs cleanup work if requested
-// by Close().
+// exit.
func (e *endpoint) completeWorkerLocked() {
+ // Worker is terminating(either due to moving to
+ // CLOSED or ERROR state, ensure we release all
+ // registrations port reservations even if the socket
+ // itself is not yet closed by the application.
e.workerRunning = false
if e.workerCleanup {
e.cleanupLocked()
@@ -917,8 +919,7 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
e.rcvAutoParams.prevCopied = int(h.rcvWnd)
e.rcvListMu.Unlock()
}
- h.ep.stack.Stats().TCP.CurrentEstablished.Increment()
- e.state = StateEstablished
+ e.setEndpointState(StateEstablished)
}
// transitionToStateCloseLocked ensures that the endpoint is
@@ -927,11 +928,12 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// delivered to this endpoint from the demuxer when the endpoint
// is transitioned to StateClose.
func (e *endpoint) transitionToStateCloseLocked() {
- if e.state == StateClose {
+ if e.EndpointState() == StateClose {
return
}
+ // Mark the endpoint as fully closed for reads/writes.
e.cleanupLocked()
- e.state = StateClose
+ e.setEndpointState(StateClose)
e.stack.Stats().TCP.EstablishedClosed.Increment()
}
@@ -946,7 +948,9 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
s.decRef()
return
}
- ep.(*endpoint).enqueueSegment(s)
+ if ep.(*endpoint).enqueueSegment(s) {
+ ep.(*endpoint).newSegmentWaker.Assert()
+ }
}
func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
@@ -955,9 +959,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
- s.decRef()
e.mu.Lock()
- switch e.state {
+ switch e.EndpointState() {
// In case of a RST in CLOSE-WAIT linux moves
// the socket to closed state with an error set
// to indicate EPIPE.
@@ -981,103 +984,53 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
e.transitionToStateCloseLocked()
e.HardError = tcpip.ErrAborted
e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
e.mu.Unlock()
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+
+ // Notify protocol goroutine. This is required when
+ // handleSegment is invoked from the processor goroutine
+ // rather than the worker goroutine.
+ e.notifyProtocolGoroutine(notifyResetByPeer)
return false, tcpip.ErrConnectionReset
}
}
return true, nil
}
-// handleSegments pulls segments from the queue and processes them. It returns
-// no error if the protocol loop should continue, an error otherwise.
-func (e *endpoint) handleSegments() *tcpip.Error {
+// handleSegments processes all inbound segments.
+func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
+ if e.EndpointState() == StateClose || e.EndpointState() == StateError {
+ return nil
+ }
s := e.segmentQueue.dequeue()
if s == nil {
checkRequeue = false
break
}
- // Invoke the tcp probe if installed.
- if e.probe != nil {
- e.probe(e.completeState())
+ cont, err := e.handleSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
}
-
- if s.flagIsSet(header.TCPFlagRst) {
- if ok, err := e.handleReset(s); !ok {
- return err
- }
- } else if s.flagIsSet(header.TCPFlagSyn) {
- // See: https://tools.ietf.org/html/rfc5961#section-4.1
- // 1) If the SYN bit is set, irrespective of the sequence number, TCP
- // MUST send an ACK (also referred to as challenge ACK) to the remote
- // peer:
- //
- // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
- //
- // After sending the acknowledgment, TCP MUST drop the unacceptable
- // segment and stop processing further.
- //
- // By sending an ACK, the remote peer is challenged to confirm the loss
- // of the previous connection and the request to start a new connection.
- // A legitimate peer, after restart, would not have a TCB in the
- // synchronized state. Thus, when the ACK arrives, the peer should send
- // a RST segment back with the sequence number derived from the ACK
- // field that caused the RST.
-
- // This RST will confirm that the remote peer has indeed closed the
- // previous connection. Upon receipt of a valid RST, the local TCP
- // endpoint MUST terminate its connection. The local TCP endpoint
- // should then rely on SYN retransmission from the remote end to
- // re-establish the connection.
-
- e.snd.sendAck()
- } else if s.flagIsSet(header.TCPFlagAck) {
- // Patch the window size in the segment according to the
- // send window scale.
- s.window <<= e.snd.sndWndScale
-
- // RFC 793, page 41 states that "once in the ESTABLISHED
- // state all segments must carry current acknowledgment
- // information."
- drop, err := e.rcv.handleRcvdSegment(s)
- if err != nil {
- s.decRef()
- return err
- }
- if drop {
- s.decRef()
- continue
- }
-
- // Now check if the received segment has caused us to transition
- // to a CLOSED state, if yes then terminate processing and do
- // not invoke the sender.
- e.mu.RLock()
- state := e.state
- e.mu.RUnlock()
- if state == StateClose {
- // When we get into StateClose while processing from the queue,
- // return immediately and let the protocolMainloop handle it.
- //
- // We can reach StateClose only while processing a previous segment
- // or a notification from the protocolMainLoop (caller goroutine).
- // This means that with this return, the segment dequeue below can
- // never occur on a closed endpoint.
- s.decRef()
- return nil
- }
- e.snd.handleRcvdSegment(s)
+ if !cont {
+ s.decRef()
+ return nil
}
- s.decRef()
}
- // If the queue is not empty, make sure we'll wake up in the next
- // iteration.
- if checkRequeue && !e.segmentQueue.empty() {
+ // When fastPath is true we don't want to wake up the worker
+ // goroutine. If the endpoint has more segments to process the
+ // dispatcher will call handleSegments again anyway.
+ if !fastPath && checkRequeue && !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
@@ -1086,11 +1039,88 @@ func (e *endpoint) handleSegments() *tcpip.Error {
e.snd.sendAck()
}
- e.resetKeepaliveTimer(true)
+ e.resetKeepaliveTimer(true /* receivedData */)
return nil
}
+// handleSegment handles a given segment and notifies the worker goroutine if
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if ok, err := e.handleReset(s); !ok {
+ return false, err
+ }
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
+ } else if s.flagIsSet(header.TCPFlagAck) {
+ // Patch the window size in the segment according to the
+ // send window scale.
+ s.window <<= e.snd.sndWndScale
+
+ // RFC 793, page 41 states that "once in the ESTABLISHED
+ // state all segments must carry current acknowledgment
+ // information."
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ return false, err
+ }
+ if drop {
+ return true, nil
+ }
+
+ // Now check if the received segment has caused us to transition
+ // to a CLOSED state, if yes then terminate processing and do
+ // not invoke the sender.
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainloop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ s.decRef()
+ return false, nil
+ }
+
+ e.snd.handleRcvdSegment(s)
+ }
+
+ return true, nil
+}
+
// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
@@ -1160,7 +1190,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -1182,6 +1212,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.mu.Unlock()
+ e.workMu.Unlock()
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1193,7 +1224,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
initialRcvWnd := e.initialReceiveWindow()
h := newHandshake(e, seqnum.Size(initialRcvWnd))
e.mu.Lock()
- h.ep.state = StateSynSent
+ h.ep.setEndpointState(StateSynSent)
e.mu.Unlock()
if err := h.execute(); err != nil {
@@ -1202,12 +1233,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.mu.Lock()
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
// Lock released below.
epilogue()
-
return err
}
}
@@ -1215,7 +1245,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.keepalive.timer.init(&e.keepalive.waker)
defer e.keepalive.timer.cleanup()
- // Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
drained := e.drainDone != nil
e.mu.Unlock()
@@ -1224,8 +1253,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
<-e.undrain
}
- e.waiterQueue.Notify(waiter.EventOut)
-
// Set up the functions that will be called when the main protocol loop
// wakes up.
funcs := []struct {
@@ -1241,17 +1268,14 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
f: e.handleClose,
},
{
- w: &e.newSegmentWaker,
- f: e.handleSegments,
- },
- {
w: &closeWaker,
f: func() *tcpip.Error {
// This means the socket is being closed due
- // to the TCP_FIN_WAIT2 timeout was hit. Just
+ // to the TCP-FIN-WAIT2 timeout was hit. Just
// mark the socket as closed.
e.mu.Lock()
e.transitionToStateCloseLocked()
+ e.workerCleanup = true
e.mu.Unlock()
return nil
},
@@ -1267,6 +1291,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
},
},
{
+ w: &e.newSegmentWaker,
+ f: func() *tcpip.Error {
+ return e.handleSegments(false /* fastPath */)
+ },
+ },
+ {
w: &e.keepalive.waker,
f: e.keepaliveTimerExpired,
},
@@ -1293,14 +1323,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
if n&notifyReset != 0 {
- e.mu.Lock()
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.mu.Unlock()
+ return tcpip.ErrConnectionAborted
+ }
+
+ if n&notifyResetByPeer != 0 {
+ return tcpip.ErrConnectionReset
}
if n&notifyClose != 0 && closeTimer == nil {
e.mu.Lock()
- if e.state == StateFinWait2 && e.closed {
+ if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
@@ -1320,11 +1352,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
- if err := e.handleSegments(); err != nil {
+ if err := e.handleSegments(false /* fastPath */); err != nil {
return err
}
}
- if e.state != StateClose && e.state != StateError {
+ if e.EndpointState() != StateClose && e.EndpointState() != StateError {
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
@@ -1349,14 +1381,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
s.AddWaker(funcs[i].w, i)
}
+ // Notify the caller that the waker initialization is complete and the
+ // endpoint is ready.
+ if wakerInitDone != nil {
+ close(wakerInitDone)
+ }
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.waiterQueue.Notify(waiter.EventOut)
+
// The following assertions and notifications are needed for restored
// endpoints. Fresh newly created endpoints have empty states and should
// not invoke any.
- e.segmentQueue.mu.Lock()
- if !e.segmentQueue.list.Empty() {
+ if !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
- e.segmentQueue.mu.Unlock()
e.rcvListMu.Lock()
if !e.rcvList.Empty() {
@@ -1371,28 +1410,53 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
+ cleanupOnError := func(err *tcpip.Error) {
+ e.mu.Lock()
+ e.workerCleanup = true
+ if err != nil {
+ e.resetConnectionLocked(err)
+ }
+ // Lock released below.
+ epilogue()
+ }
- for e.state != StateTimeWait && e.state != StateClose && e.state != StateError {
+loop:
+ for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
e.mu.Unlock()
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
- if err := funcs[v].f(); err != nil {
- e.mu.Lock()
- // Ensure we release all endpoint registration and route
- // references as the connection is now in an error
- // state.
- e.workerCleanup = true
- e.resetConnectionLocked(err)
- // Lock released below.
- epilogue()
+ // We need to double check here because the notification maybe
+ // stale by the time we got around to processing it.
+ //
+ // NOTE: since we now hold the workMu the processors cannot
+ // change the state of the endpoint so it's safe to proceed
+ // after this check.
+ switch e.EndpointState() {
+ case StateError:
+ // If the endpoint has already transitioned to an ERROR
+ // state just pass nil here as any reset that may need
+ // to be sent etc should already have been done and we
+ // just want to terminate the loop and cleanup the
+ // endpoint.
+ cleanupOnError(nil)
return nil
+ case StateTimeWait:
+ fallthrough
+ case StateClose:
+ e.mu.Lock()
+ break loop
+ default:
+ if err := funcs[v].f(); err != nil {
+ cleanupOnError(err)
+ return nil
+ }
+ e.mu.Lock()
}
- e.mu.Lock()
}
- state := e.state
+ state := e.EndpointState()
e.mu.Unlock()
var reuseTW func()
if state == StateTimeWait {
@@ -1405,13 +1469,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
s.Done()
// Wake up any waiters before we enter TIME_WAIT.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.mu.Lock()
+ e.workerCleanup = true
+ e.mu.Unlock()
reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
e.mu.Lock()
- if e.state != StateError {
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ if e.EndpointState() != StateError {
e.transitionToStateCloseLocked()
}
@@ -1468,7 +1534,11 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
tcpEP := listenEP.(*endpoint)
if EndpointState(tcpEP.State()) == StateListen {
reuseTW = func() {
- tcpEP.enqueueSegment(s)
+ if !tcpEP.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+ tcpEP.newSegmentWaker.Assert()
}
// We explicitly do not decRef
// the segment as it's still
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
new file mode 100644
index 000000000..e18012ac0
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -0,0 +1,224 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// epQueue is a queue of endpoints.
+type epQueue struct {
+ mu sync.Mutex
+ list endpointList
+}
+
+// enqueue adds e to the queue if the endpoint is not already on the queue.
+func (q *epQueue) enqueue(e *endpoint) {
+ q.mu.Lock()
+ if e.pendingProcessing {
+ q.mu.Unlock()
+ return
+ }
+ q.list.PushBack(e)
+ e.pendingProcessing = true
+ q.mu.Unlock()
+}
+
+// dequeue removes and returns the first element from the queue if available,
+// returns nil otherwise.
+func (q *epQueue) dequeue() *endpoint {
+ q.mu.Lock()
+ if e := q.list.Front(); e != nil {
+ q.list.Remove(e)
+ e.pendingProcessing = false
+ q.mu.Unlock()
+ return e
+ }
+ q.mu.Unlock()
+ return nil
+}
+
+// empty returns true if the queue is empty, false otherwise.
+func (q *epQueue) empty() bool {
+ q.mu.Lock()
+ v := q.list.Empty()
+ q.mu.Unlock()
+ return v
+}
+
+// processor is responsible for processing packets queued to a tcp endpoint.
+type processor struct {
+ epQ epQueue
+ newEndpointWaker sleep.Waker
+ id int
+}
+
+func newProcessor(id int) *processor {
+ p := &processor{
+ id: id,
+ }
+ go p.handleSegments()
+ return p
+}
+
+func (p *processor) queueEndpoint(ep *endpoint) {
+ // Queue an endpoint for processing by the processor goroutine.
+ p.epQ.enqueue(ep)
+ p.newEndpointWaker.Assert()
+}
+
+func (p *processor) handleSegments() {
+ const newEndpointWaker = 1
+ s := sleep.Sleeper{}
+ s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ defer s.Done()
+ for {
+ s.Fetch(true)
+ for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
+ if ep.segmentQueue.empty() {
+ continue
+ }
+
+ // If socket has transitioned out of connected state
+ // then just let the worker handle the packet.
+ //
+ // NOTE: We read this outside of e.mu lock which means
+ // that by the time we get to handleSegments the
+ // endpoint may not be in ESTABLISHED. But this should
+ // be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a
+ // CLOSED/ERROR state then handleSegments is a noop.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ continue
+ }
+
+ if !ep.workMu.TryLock() {
+ ep.newSegmentWaker.Assert()
+ continue
+ }
+ // If the endpoint is in a connected state then we do
+ // direct delivery to ensure low latency and avoid
+ // scheduler interactions.
+ if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
+ // Send any active resets if required.
+ if err != nil {
+ ep.mu.Lock()
+ ep.resetConnectionLocked(err)
+ ep.mu.Unlock()
+ }
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ ep.workMu.Unlock()
+ continue
+ }
+
+ if !ep.segmentQueue.empty() {
+ p.epQ.enqueue(ep)
+ }
+
+ ep.workMu.Unlock()
+ }
+ }
+}
+
+// dispatcher manages a pool of TCP endpoint processors which are responsible
+// for the processing of inbound segments. This fixed pool of processor
+// goroutines do full tcp processing. The processor is selected based on the
+// hash of the endpoint id to ensure that delivery for the same endpoint happens
+// in-order.
+type dispatcher struct {
+ processors []*processor
+ seed uint32
+}
+
+func newDispatcher(nProcessors int) *dispatcher {
+ processors := []*processor{}
+ for i := 0; i < nProcessors; i++ {
+ processors = append(processors, newProcessor(i))
+ }
+ return &dispatcher{
+ processors: processors,
+ seed: generateRandUint32(),
+ }
+}
+
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ ep := stackEP.(*endpoint)
+ s := newSegment(r, id, pkt)
+ if !s.parse() {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.ChecksumErrors.Increment()
+ ep.stats.ReceiveErrors.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ ep.stats.SegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ ep.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ if !ep.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+
+ // For sockets not in established state let the worker goroutine
+ // handle the packets.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ return
+ }
+
+ d.selectProcessor(id).queueEndpoint(ep)
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8)}
+
+ h := jenkins.Sum32(d.seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+
+ return d.processors[h.Sum32()%uint32(len(d.processors))]
+}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index dfaa4a559..4f361b226 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -391,9 +391,8 @@ func testV4Accept(t *testing.T, c *context.Context) {
// Make sure we get the same error when calling the original ep and the
// new one. This validates that v4-mapped endpoints are still able to
// query the V6Only flag, whereas pure v4 endpoints are not.
- var v tcpip.V6OnlyOption
- expected := c.EP.GetSockOpt(&v)
- if err := nep.GetSockOpt(&v); err != expected {
+ _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
}
@@ -531,8 +530,7 @@ func TestV6AcceptOnV6(t *testing.T) {
// Make sure we can still query the v6 only status of the new endpoint,
// that is, that it is in fact a v6 socket.
- var v tcpip.V6OnlyOption
- if err := nep.GetSockOpt(&v); err != nil {
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
t.Fatalf("GetSockOpt failed failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index fe629aa40..4797f11d1 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -19,12 +19,12 @@ import (
"fmt"
"math"
"strings"
- "sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -120,6 +120,7 @@ const (
notifyMTUChanged
notifyDrain
notifyReset
+ notifyResetByPeer
notifyKeepaliveChanged
notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a
@@ -127,6 +128,7 @@ const (
// ensures the loop terminates if the final state of the endpoint is
// say TIME_WAIT.
notifyTickleWorker
+ notifyError
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -283,6 +285,18 @@ func (*EndpointInfo) IsEndpointInfo() {}
type endpoint struct {
EndpointInfo
+ // endpointEntry is used to queue endpoints for processing to the
+ // a given tcp processor goroutine.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field..
+ endpointEntry `state:"nosave"`
+
+ // pendingProcessing is true if this endpoint is queued for processing
+ // to a TCP processor.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field..
+ pendingProcessing bool `state:"nosave"`
+
// workMu is used to arbitrate which goroutine may perform protocol
// work. Only the main protocol goroutine is expected to call Lock() on
// it, but other goroutines (e.g., send) may call TryLock() to eagerly
@@ -324,6 +338,7 @@ type endpoint struct {
// The following fields are protected by the mutex.
mu sync.RWMutex `state:"nosave"`
+ // state must be read/set using the EndpointState()/setEndpointState() methods.
state EndpointState `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
@@ -359,7 +374,7 @@ type endpoint struct {
workerRunning bool
// workerCleanup specifies if the worker goroutine must perform cleanup
- // before exitting. This can only be set to true when workerRunning is
+ // before exiting. This can only be set to true when workerRunning is
// also true, and they're both protected by the mutex.
workerCleanup bool
@@ -371,6 +386,8 @@ type endpoint struct {
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
+ //
+ // recentTS must be read/written atomically.
recentTS uint32
// tsOffset is a randomized offset added to the value of the
@@ -567,6 +584,47 @@ func (e *endpoint) ResumeWork() {
e.workMu.Unlock()
}
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ switch state {
+ case StateEstablished:
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
+ case StateError:
+ fallthrough
+ case StateClose:
+ if oldstate == StateCloseWait || oldstate == StateEstablished {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ }
+ fallthrough
+ default:
+ if oldstate == StateEstablished {
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
+ }
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
+// setRecentTimestamp atomically sets the recentTS field to the
+// provided value.
+func (e *endpoint) setRecentTimestamp(recentTS uint32) {
+ atomic.StoreUint32(&e.recentTS, recentTS)
+}
+
+// recentTimestamp atomically reads and returns the value of the recentTS field.
+func (e *endpoint) recentTimestamp() uint32 {
+ return atomic.LoadUint32(&e.recentTS)
+}
+
// keepalive is a synchronization wrapper used to appease stateify. See the
// comment in endpoint, where it is used.
//
@@ -656,7 +714,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
e.mu.RLock()
defer e.mu.RUnlock()
- switch e.state {
+ switch e.EndpointState() {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
@@ -672,7 +730,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
}
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
@@ -733,14 +791,20 @@ func (e *endpoint) Close() {
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ e.closeNoShutdown()
+}
+// closeNoShutdown closes the endpoint without doing a full shutdown. This is
+// used when a connection needs to be aborted with a RST and we want to skip
+// a full 4 way TCP shutdown.
+func (e *endpoint) closeNoShutdown() {
e.mu.Lock()
// For listening sockets, we always release ports inline so that they
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
// in Listen() when trying to register.
- if e.state == StateListen && e.isPortReserved {
+ if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
@@ -780,6 +844,8 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
defer close(done)
for n := range e.acceptedChan {
n.notifyProtocolGoroutine(notifyReset)
+ // close all connections that have completed but
+ // not accepted by the application.
n.Close()
}
}()
@@ -797,11 +863,13 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
// after Close() is called and the worker goroutine (if any) is done with its
// work.
func (e *endpoint) cleanupLocked() {
+
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
e.closePendingAcceptableConnectionsLocked()
}
+
e.workerCleanup = false
if e.isRegistered {
@@ -885,8 +953,14 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// reject valid data that might already be in flight as the
// acceptable window will shrink.
if rcvWnd > e.rcvBufSize {
+ availBefore := e.receiveBufferAvailableLocked()
e.rcvBufSize = rcvWnd
- e.notifyProtocolGoroutine(notifyReceiveWindowChanged)
+ availAfter := e.receiveBufferAvailableLocked()
+ mask := uint32(notifyReceiveWindowChanged)
+ if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ mask |= notifyNonZeroReceiveWindow
+ }
+ e.notifyProtocolGoroutine(mask)
}
// We only update prevCopied when we grow the buffer because in cases
@@ -914,7 +988,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
bufUsed := e.rcvBufUsed
- if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
+ if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
e.mu.RUnlock()
@@ -938,7 +1012,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
return buffer.View{}, tcpip.ErrClosedForReceive
}
return buffer.View{}, tcpip.ErrWouldBlock
@@ -955,11 +1029,12 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
}
e.rcvBufUsed -= len(v)
- // If the window was zero before this read and if the read freed up
- // enough buffer space for the scaled window to be non-zero then notify
- // the protocol goroutine to send a window update.
- if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) {
- e.zeroWindow = false
+
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -973,8 +1048,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
- if !e.state.connected() {
- switch e.state {
+ if !e.EndpointState().connected() {
+ switch e.EndpointState() {
case StateError:
return 0, e.HardError
default:
@@ -1032,42 +1107,86 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, perr
}
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
+ if opts.Atomic {
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
+ }
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
+ if e.workMu.TryLock() {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
}
- }
-
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
-
- if e.workMu.TryLock() {
// Do the work inline.
e.handleWrite()
e.workMu.Unlock()
} else {
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
+
+ }
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
@@ -1084,7 +1203,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
- if s := e.state; !s.connected() && s != StateClose {
+ if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
return 0, tcpip.ControlMessages{}, e.HardError
}
@@ -1096,7 +1215,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
defer e.rcvListMu.Unlock()
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
e.stats.ReadErrors.ReadClosed.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
@@ -1133,20 +1252,65 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
-// zeroReceiveWindow checks if the receive window to be announced now would be
-// zero, based on the amount of available buffer and the receive window scaling.
+// windowCrossedACKThreshold checks if the receive window to be announced now
+// would be under aMSS or under half receive buffer, whichever smaller. This is
+// useful as a receive side silly window syndrome prevention mechanism. If
+// window grows to reasonable value, we should send ACK to the sender to inform
+// the rx space is now large. We also want ensure a series of small read()'s
+// won't trigger a flood of spurious tiny ACK's.
//
-// It must be called with rcvListMu held.
-func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
- if e.rcvBufUsed >= e.rcvBufSize {
- return true
+// For large receive buffers, the threshold is aMSS - once reader reads more
+// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of
+// receive buffer size. This is chosen arbitrairly.
+// crossed will be true if the window size crossed the ACK threshold.
+// above will be true if the new window is >= ACK threshold and false
+// otherwise.
+func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
+ newAvail := e.receiveBufferAvailableLocked()
+ oldAvail := newAvail - deltaBefore
+ if oldAvail < 0 {
+ oldAvail = 0
+ }
+
+ threshold := int(e.amss)
+ if threshold > e.rcvBufSize/2 {
+ threshold = e.rcvBufSize / 2
+ }
+
+ switch {
+ case oldAvail < threshold && newAvail >= threshold:
+ return true, true
+ case oldAvail >= threshold && newAvail < threshold:
+ return true, false
}
+ return false, false
+}
+
+// SetSockOptBool sets a socket option.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
- return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
+ // We only allow this to be set when we're in the initial state.
+ if e.EndpointState() != StateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v
+ }
+
+ return nil
}
// SetSockOptInt sets a socket option.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
@@ -1181,10 +1345,16 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
size = math.MaxInt32 / 2
}
+ availBefore := e.receiveBufferAvailableLocked()
e.rcvBufSize = size
+ availAfter := e.receiveBufferAvailableLocked()
+
e.rcvAutoParams.disabled = true
- if e.zeroWindow && !e.zeroReceiveWindow(scale) {
- e.zeroWindow = false
+
+ // Immediately send an ACK to uncork the sender silly window
+ // syndrome prevetion, when our available space grows above aMSS
+ // or half receive buffer, whichever smaller.
+ if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
@@ -1256,19 +1426,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.QuickAckOption:
if v == 0 {
@@ -1289,23 +1454,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.notifyProtocolGoroutine(notifyMSSChanged)
return nil
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.v6only = v != 0
- return nil
-
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
@@ -1366,14 +1514,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// Acquire the work mutex as we may need to
// reinitialize the congestion control state.
e.mu.Lock()
- state := e.state
+ state := e.EndpointState()
e.cc = v
e.mu.Unlock()
switch state {
case StateEstablished:
e.workMu.Lock()
e.mu.Lock()
- if e.state == state {
+ if e.EndpointState() == state {
e.snd.cc = e.snd.initCongestionControl(e.cc)
}
e.mu.Unlock()
@@ -1436,7 +1584,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
defer e.mu.RUnlock()
// The endpoint cannot be in listen state.
- if e.state == StateListen {
+ if e.EndpointState() == StateListen {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -1446,8 +1594,27 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ return v, nil
+ }
+
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1525,12 +1692,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = ""
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.QuickAckOption:
@@ -1540,22 +1703,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
- case *tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.Lock()
- v := e.v6only
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
case *tcpip.TTLOption:
e.mu.Lock()
*o = tcpip.TTLOption(e.ttl)
@@ -1656,26 +1803,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if header.IsV4MappedAddress(addr.Addr) {
- // Fail if using a v4 mapped address on a v6only endpoint.
- if e.v6only {
- return 0, tcpip.ErrNoRoute
- }
-
- netProto = header.IPv4ProtocolNumber
- addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == header.IPv4Any {
- addr.Addr = ""
- }
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+ if err != nil {
+ return 0, err
}
-
+ *addr = unwrapped
return netProto, nil
}
@@ -1711,7 +1843,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return err
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// The endpoint is already connected. If caller hasn't been
// notified yet, return success.
if !e.isConnectNotified {
@@ -1723,7 +1855,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
nicID := addr.NIC
- switch e.state {
+ switch e.EndpointState() {
case StateBound:
// If we're already bound to a NIC but the caller is requesting
// that we use a different one now, we cannot proceed.
@@ -1830,7 +1962,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
e.isRegistered = true
- e.state = StateConnecting
+ e.setEndpointState(StateConnecting)
e.route = r.Clone()
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
@@ -1851,14 +1983,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
- e.state = StateEstablished
- e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.setEndpointState(StateEstablished)
}
if run {
e.workerRunning = true
e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save.
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
}
return tcpip.ErrConnectStarted
@@ -1876,7 +2007,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.shutdownFlags |= flags
finQueued := false
switch {
- case e.state.connected():
+ case e.EndpointState().connected():
// Close for read.
if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
// Mark read side as closed.
@@ -1888,8 +2019,18 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
- e.notifyProtocolGoroutine(notifyReset)
e.mu.Unlock()
+ // Try to send an active reset immediately if the
+ // work mutex is available.
+ if e.workMu.TryLock() {
+ e.mu.Lock()
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ e.mu.Unlock()
+ e.workMu.Unlock()
+ } else {
+ e.notifyProtocolGoroutine(notifyReset)
+ }
return nil
}
}
@@ -1911,11 +2052,10 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
-
e.sndBufMu.Unlock()
}
- case e.state == StateListen:
+ case e.EndpointState() == StateListen:
// Tell protocolListenLoop to stop.
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
@@ -1956,7 +2096,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// When the endpoint shuts down, it sets workerCleanup to true, and from
// that point onward, acceptedChan is the responsibility of the cleanup()
// method (and should not be touched anywhere else, including here).
- if e.state == StateListen && !e.workerCleanup {
+ if e.EndpointState() == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
if len(e.acceptedChan) > backlog {
@@ -1974,7 +2114,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
return nil
}
- if e.state == StateInitial {
+ if e.EndpointState() == StateInitial {
// The listen is called on an unbound socket, the socket is
// automatically bound to a random free port with the local
// address set to INADDR_ANY.
@@ -1984,7 +2124,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Endpoint must be bound before it can transition to listen mode.
- if e.state != StateBound {
+ if e.EndpointState() != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
return tcpip.ErrInvalidEndpointState
}
@@ -1995,24 +2135,27 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
e.isRegistered = true
- e.state = StateListen
+ e.setEndpointState(StateListen)
+
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
e.workerRunning = true
-
go e.protocolListenLoop( // S/R-SAFE: drained on save.
seqnum.Size(e.receiveBufferAvailable()))
-
return nil
}
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+ e.mu.Lock()
e.waiterQueue = waiterQueue
e.workerRunning = true
- go e.protocolMainLoop(false) // S/R-SAFE: drained on save.
+ e.mu.Unlock()
+ wakerInitDone := make(chan struct{})
+ go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save.
+ <-wakerInitDone
}
// Accept returns a new endpoint if a peer has established a connection
@@ -2022,7 +2165,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
defer e.mu.RUnlock()
// Endpoint must be in listen state before it can accept connections.
- if e.state != StateListen {
+ if e.EndpointState() != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
@@ -2049,7 +2192,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrAlreadyBound
}
@@ -2111,7 +2254,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
return nil
}
@@ -2133,7 +2276,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if !e.state.connected() {
+ if !e.EndpointState().connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -2144,45 +2287,22 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
- s := newSegment(r, id, pkt)
- if !s.parse() {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
- e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
- s.decRef()
- return
- }
-
- if !s.csumValid {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.ChecksumErrors.Increment()
- e.stats.ReceiveErrors.ChecksumErrors.Increment()
- s.decRef()
- return
- }
-
- e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
- e.stats.SegmentsReceived.Increment()
- if (s.flags & header.TCPFlagRst) != 0 {
- e.stack.Stats().TCP.ResetsReceived.Increment()
- }
-
- e.enqueueSegment(s)
+ // TCP HandlePacket is not required anymore as inbound packets first
+ // land at the Dispatcher which then can either delivery using the
+ // worker go routine or directly do the invoke the tcp processing inline
+ // based on the state of the endpoint.
}
-func (e *endpoint) enqueueSegment(s *segment) {
+func (e *endpoint) enqueueSegment(s *segment) bool {
// Send packet to worker goroutine.
- if e.segmentQueue.enqueue(s) {
- e.newSegmentWaker.Assert()
- } else {
+ if !e.segmentQueue.enqueue(s) {
// The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
- s.decRef()
+ return false
}
+ return true
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
@@ -2225,13 +2345,10 @@ func (e *endpoint) readyToRead(s *segment) {
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
- // Check if the receive window is now closed. If so make sure
- // we set the zero window before we deliver the segment to ensure
- // that a subsequent read of the segment will correctly trigger
- // a non-zero notification.
- if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 {
+ // Increase counter if the receive window falls down below MSS
+ // or half receive buffer size, whichever smaller.
+ if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- e.zeroWindow = true
}
e.rcvList.PushBack(s)
} else {
@@ -2302,8 +2419,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int {
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
- if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
- e.recentTS = tsVal
+ if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ e.setRecentTimestamp(tsVal)
}
}
@@ -2313,7 +2430,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
if synOpts.TS {
e.sendTSOk = true
- e.recentTS = synOpts.TSVal
+ e.setRecentTimestamp(synOpts.TSVal)
}
}
@@ -2402,7 +2519,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
// Endpoint TCP Option state.
s.SendTSOk = e.sendTSOk
- s.RecentTS = e.recentTS
+ s.RecentTS = e.recentTimestamp()
s.TSOffset = e.tsOffset
s.SACKPermitted = e.sackPermitted
s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
@@ -2509,9 +2626,7 @@ func (e *endpoint) initGSO() {
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 7aa4c3f0e..4a46f0ec5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -16,9 +16,10 @@ package tcp
import (
"fmt"
- "sync"
+ "sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -48,7 +49,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
defer e.mu.Unlock()
- switch e.state {
+ switch e.EndpointState() {
case StateInitial, StateBound:
// TODO(b/138137272): this enumeration duplicates
// EndpointState.connected. remove it.
@@ -70,31 +71,30 @@ func (e *endpoint) beforeSave() {
fallthrough
case StateListen, StateConnecting:
e.drainSegmentLocked()
- if e.state != StateClose && e.state != StateError {
+ if e.EndpointState() != StateClose && e.EndpointState() != StateError {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
break
}
- fallthrough
case StateError, StateClose:
- for (e.state == StateError || e.state == StateClose) && e.workerRunning {
+ for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
}
if e.workerRunning {
- panic("endpoint still has worker running in closed or error state")
+ panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID))
}
default:
- panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
+ panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
}
if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
panic("endpoint still has waiters upon save")
}
- if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) {
+ if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) {
panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
}
}
@@ -135,7 +135,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
// saveState is invoked by stateify.
func (e *endpoint) saveState() EndpointState {
- return e.state
+ return e.EndpointState()
}
// Endpoint loading must be done in the following ordering by their state, to
@@ -151,7 +151,8 @@ var connectingLoading sync.WaitGroup
func (e *endpoint) loadState(state EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
- if state.connected() {
+ // For restore purposes we treat TimeWait like a connected endpoint.
+ if state.connected() || state == StateTimeWait {
connectedLoading.Add(1)
}
switch state {
@@ -160,13 +161,14 @@ func (e *endpoint) loadState(state EndpointState) {
case StateConnecting, StateSynSent, StateSynRecv:
connectingLoading.Add(1)
}
- e.state = state
+ // Directly update the state here rather than using e.setEndpointState
+ // as the endpoint is still being loaded and the stack reference to increment
+ // metrics is not yet initialized.
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
}
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- // Freeze segment queue before registering to prevent any segments
- // from being delivered while it is being restored.
e.origEndpointState = e.state
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Resume.
@@ -180,7 +182,6 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
state := e.origEndpointState
-
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
@@ -276,7 +277,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
listenLoading.Wait()
connectingLoading.Wait()
bind()
- e.state = StateClose
+ e.setEndpointState(StateClose)
tcpip.AsyncLoading.Done()
}()
}
@@ -288,6 +289,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
+
}
// saveLastError is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 4983bca81..7eb613be5 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -15,8 +15,7 @@
package tcp
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index bc718064c..958c06fa7 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -21,10 +21,11 @@
package tcp
import (
+ "runtime"
"strings"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -104,6 +105,7 @@ type protocol struct {
moderateReceiveBuffer bool
tcpLingerTimeout time.Duration
tcpTimeWaitTimeout time.Duration
+ dispatcher *dispatcher
}
// Number returns the tcp protocol number.
@@ -134,6 +136,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
return h.SourcePort(), h.DestinationPort(), nil
}
+// QueuePacket queues packets targeted at an endpoint after hashing the packet
+// to a specific processing queue. Each queue is serviced by its own processor
+// goroutine which is responsible for dequeuing and doing full TCP dispatch of
+// the packet.
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ p.dispatcher.queuePacket(r, ep, id, pkt)
+}
+
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
//
@@ -330,5 +340,6 @@ func NewProtocol() stack.TransportProtocol {
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 0a5534959..958f03ac1 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -98,12 +98,6 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// in such cases we may need to send an ack to indicate to our peer that it can
// resume sending data.
func (r *receiver) nonZeroWindow() {
- if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
- // We never got around to announcing a zero window size, so we
- // don't need to immediately announce a nonzero one.
- return
- }
-
// Immediately send an ack.
r.ep.snd.sendAck()
}
@@ -175,19 +169,19 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// We just received a FIN, our next state depends on whether we sent a
// FIN already or not.
r.ep.mu.Lock()
- switch r.ep.state {
+ switch r.ep.EndpointState() {
case StateEstablished:
- r.ep.state = StateCloseWait
+ r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
if s.flagIsSet(header.TCPFlagAck) {
// FIN-ACK, transition to TIME-WAIT.
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
} else {
// Simultaneous close, expecting a final ACK.
- r.ep.state = StateClosing
+ r.ep.setEndpointState(StateClosing)
}
case StateFinWait2:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
}
r.ep.mu.Unlock()
@@ -211,16 +205,16 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// shutdown states.
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
r.ep.mu.Lock()
- switch r.ep.state {
+ switch r.ep.EndpointState() {
case StateFinWait1:
- r.ep.state = StateFinWait2
+ r.ep.setEndpointState(StateFinWait2)
// Notify protocol goroutine that we have received an
// ACK to our FIN so that it can start the FIN_WAIT2
// timer to abort connection if the other side does
// not close within 2MSL.
r.ep.notifyProtocolGoroutine(notifyClose)
case StateClosing:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
case StateLastAck:
r.ep.transitionToStateCloseLocked()
}
@@ -273,7 +267,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
switch state {
case StateCloseWait, StateClosing, StateLastAck:
if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
- s.decRef()
// Just drop the segment as we have
// already received a FIN and this
// segment is after the sequence number
@@ -290,7 +283,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// trigger a RST.
endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
if rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
- s.decRef()
return true, tcpip.ErrConnectionAborted
}
if state == StateFinWait1 {
@@ -320,7 +312,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// the last actual data octet in a segment in
// which it occurs.
if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
- s.decRef()
return true, tcpip.ErrConnectionAborted
}
}
@@ -342,7 +333,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
r.ep.mu.RLock()
- state := r.ep.state
+ state := r.ep.EndpointState()
closed := r.ep.closed
r.ep.mu.RUnlock()
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index e0759225e..bd20a7ee9 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -15,7 +15,7 @@
package tcp
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// segmentQueue is a bounded, thread-safe queue of TCP segments.
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 8a947dc66..b74b61e7d 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -16,11 +16,11 @@ package tcp
import (
"math"
- "sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -442,6 +442,13 @@ func (s *sender) retransmitTimerExpired() bool {
return true
}
+ // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases
+ // when writeList is empty. Remove this once we have a proper fix for this
+ // issue.
+ if s.writeList.Front() == nil {
+ return true
+ }
+
s.ep.stack.Stats().TCP.Timeouts.Increment()
s.ep.stats.SendErrors.Timeouts.Increment()
@@ -698,17 +705,15 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
seg.flags = header.TCPFlagAck | header.TCPFlagFin
segEnd = seg.sequenceNumber.Add(1)
- // Transition to FIN-WAIT1 state since we're initiating an active close.
- s.ep.mu.Lock()
- switch s.ep.state {
+ // Update the state to reflect that we have now
+ // queued a FIN.
+ switch s.ep.EndpointState() {
case StateCloseWait:
- // We've already received a FIN and are now sending our own. The
- // sender is now awaiting a final ACK for this FIN.
- s.ep.state = StateLastAck
+ s.ep.setEndpointState(StateLastAck)
default:
- s.ep.state = StateFinWait1
+ s.ep.setEndpointState(StateFinWait1)
}
- s.ep.mu.Unlock()
+
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index e8fe4dab5..a9dfbe857 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -293,7 +293,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.SeqNum(uint32(c.IRS+1)),
checker.AckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
finHeaders := &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -459,6 +458,9 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence number
+ // of the FIN itself.
checker.SeqNum(uint32(c.IRS)+2),
checker.AckNum(0),
checker.TCPFlags(header.TCPFlagRst),
@@ -1083,12 +1085,12 @@ func TestTrafficClassV6(t *testing.T) {
func TestConnectBindToDevice(t *testing.T) {
for _, test := range []struct {
name string
- device string
+ device tcpip.NICID
want tcp.EndpointState
}{
- {"RightDevice", "nic1", tcp.StateEstablished},
- {"WrongDevice", "nic2", tcp.StateSynSent},
- {"AnyDevice", "", tcp.StateEstablished},
+ {"RightDevice", 1, tcp.StateEstablished},
+ {"WrongDevice", 2, tcp.StateSynSent},
+ {"AnyDevice", 0, tcp.StateEstablished},
} {
t.Run(test.name, func(t *testing.T) {
c := context.New(t, defaultMTU)
@@ -1500,6 +1502,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence
+ // number of the FIN itself.
checker.SeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
@@ -2091,10 +2096,14 @@ func TestZeroScaledWindowReceive(t *testing.T) {
)
}
- // Read some data. An ack should be sent in response to that.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
+ // Read at least 1MSS of data. An ack should be sent in response to that.
+ sz := 0
+ for sz < defaultMTU {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ sz += len(v)
}
checker.IPv4(t, c.GetPacket(),
@@ -2103,7 +2112,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS)+1),
checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(len(v)>>ws)),
+ checker.Window(uint16(sz>>ws)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3794,46 +3803,41 @@ func TestBindToDeviceOption(t *testing.T) {
}
defer ep.Close()
- if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
- t.Errorf("CreateNamedNIC failed: %v", err)
- }
-
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ if err := s.CreateNIC(321, loopback.New()); err != nil {
t.Errorf("CreateNIC failed: %v", err)
}
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
@@ -4027,12 +4031,12 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
switch network {
case "ipv4":
case "ipv6":
- if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err)
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err)
}
case "dual":
- if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err)
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err)
}
default:
t.Fatalf("unknown network: '%s'", network)
@@ -5442,6 +5446,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
packetsSent++
}
+
// Resume the worker so that it only sees the packets once all of them
// are waiting to be read.
worker.ResumeWork()
@@ -5509,7 +5514,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
stk := c.Stack()
// Set lower limits for auto-tuning tests. This is required because the
// test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 500.
+ // the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
@@ -5564,6 +5569,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
totalSent += mss
packetsSent++
}
+
// Resume it so that it only sees the packets once all of them
// are waiting to be read.
worker.ResumeWork()
@@ -6561,3 +6567,140 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
}
}
+
+func TestIncreaseWindowOnReceive(t *testing.T) {
+ // This test ensures that the endpoint sends an ack,
+ // after recv() when the window grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two
+ // payloads make it equal or longer than MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // We now have < 1 MSS in the buffer space. Read the data! An
+ // ack should be sent in response to that. The window was not
+ // zero, but it grew to larger than MSS.
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ // After reading two packets, we surely crossed MSS. See the ack:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestIncreaseWindowOnBufferResize(t *testing.T) {
+ // This test ensures that the endpoint sends an ack,
+ // after available recv buffer grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two
+ // payloads make it equal or longer than MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // Increasing the buffer from should generate an ACK,
+ // since window grew from small value to larger equal MSS
+ c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
+
+ // After reading two packets, we surely crossed MSS. See the ack:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index b0a376eba..822907998 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -158,15 +158,17 @@ func New(t *testing.T, mtu uint32) *Context {
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ opts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
if testing.Verbose() {
wep2 = sniffer.New(channel.New(1000, mtu, ""))
}
- if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ opts2 := stack.NICOptions{Name: "nic2"}
+ if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
@@ -473,11 +475,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.EP.SetSockOpt(v); err != nil {
+ if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 97e4d5825..57ff123e3 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -30,6 +30,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1ac4705af..c9cbed8f4 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,8 +15,7 @@
package udp
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -32,6 +31,7 @@ type udpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -114,6 +114,10 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
+ // receiveTOS determines if the incoming IPv4 TOS header field is passed
+ // as ancillary data to ControlMessages on Read.
+ receiveTOS bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -244,7 +248,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
*addr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ cm := tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ }
+ e.mu.RLock()
+ receiveTOS := e.receiveTOS
+ e.mu.RUnlock()
+ if receiveTOS {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+ return p.data.ToView(), cm, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -403,7 +418,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- netProto, err := e.checkV4Mapped(to, false)
+ netProto, err := e.checkV4Mapped(to)
if err != nil {
return 0, nil, err
}
@@ -456,14 +471,15 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
- return nil
-}
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.ReceiveTOSOption:
+ e.mu.Lock()
+ e.receiveTOS = v
+ e.mu.Unlock()
+ return nil
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -478,8 +494,20 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- e.v6only = v != 0
+ e.v6only = v
+ }
+ return nil
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
@@ -495,7 +523,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- netProto, err := e.checkV4Mapped(&fa, false)
+ netProto, err := e.checkV4Mapped(&fa)
if err != nil {
return err
}
@@ -624,19 +652,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.BroadcastOption:
e.mu.Lock()
@@ -660,8 +683,33 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveTOSOption:
+ e.mu.RLock()
+ v := e.receiveTOS
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.RLock()
+ v := e.v6only
+ e.mu.RUnlock()
+
+ return v, nil
+ }
+
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -695,22 +743,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.Lock()
- v := e.v6only
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
case *tcpip.TTLOption:
e.mu.Lock()
*o = tcpip.TTLOption(e.ttl)
@@ -757,12 +789,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = tcpip.BindToDeviceOption("")
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.KeepaliveEnabledOption:
@@ -839,35 +867,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if len(addr.Addr) == 0 {
- return netProto, nil
- }
- if header.IsV4MappedAddress(addr.Addr) {
- // Fail if using a v4 mapped address on a v6only endpoint.
- if e.v6only {
- return 0, tcpip.ErrNoRoute
- }
-
- netProto = header.IPv4ProtocolNumber
- addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == header.IPv4Any {
- addr.Addr = ""
- }
-
- // Fail if we are bound to an IPv6 address.
- if !allowMismatch && len(e.ID.LocalAddress) == 16 {
- return 0, tcpip.ErrNetworkUnreachable
- }
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+ if err != nil {
+ return 0, err
}
-
+ *addr = unwrapped
return netProto, nil
}
@@ -916,7 +921,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- netProto, err := e.checkV4Mapped(&addr, false)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
@@ -1074,7 +1079,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, true)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
@@ -1238,6 +1243,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvList.PushBack(packet)
e.rcvBufSize += pkt.Data.Size()
+ // Save any useful information from the network header to the packet.
+ switch r.NetProto {
+ case header.IPv4ProtocolNumber:
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ }
+
packet.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 7051a7a9c..ee9d10555 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -56,6 +56,7 @@ const (
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
+ testTOS = 0x80
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -335,7 +336,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
} else if flow.isBroadcast() {
@@ -453,6 +454,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
+ TOS: testTOS,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
@@ -508,46 +510,42 @@ func TestBindToDeviceOption(t *testing.T) {
}
defer ep.Close()
- if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
- t.Errorf("CreateNamedNIC failed: %v", err)
+ opts := stack.NICOptions{Name: "my_device"}
+ if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
+ t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
- }
-
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
@@ -556,8 +554,8 @@ func TestBindToDeviceOption(t *testing.T) {
// testReadInternal sends a packet of the given test flow into the stack by
// injecting it into the link endpoint. It then attempts to read it from the
// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
+// correctness including any additional checker functions provided.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
payload := newPayload()
@@ -572,12 +570,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
+ v, cm, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, _, err = c.ep.Read(&addr)
+ v, cm, err = c.ep.Read(&addr)
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -610,15 +608,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
if !bytes.Equal(payload, v) {
c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+
+ // Run any checkers against the ControlMessages.
+ for _, f := range checkers {
+ f(c.t, cm)
+ }
+
c.checkEndpointReadStats(1, epstats, err)
}
// testRead sends a packet of the given test flow into the stack by injecting it
// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness.
-func testRead(c *testContext, flow testFlow) {
+// its correctness including any additional checker functions provided.
+func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
}
// testFailingRead sends a packet of the given test flow into the stack by
@@ -1286,7 +1290,7 @@ func TestTOSV4(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv4TOSOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1321,7 +1325,7 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv6TrafficClassOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1348,6 +1352,47 @@ func TestTOSV6(t *testing.T) {
}
}
+func TestReceiveTOSV4(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, broadcast} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false)
+ }
+
+ want := true
+ if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err)
+ }
+
+ got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
+ }
+ if got != want {
+ c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want)
+ }
+
+ // Verify that the correct received TOS is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ })
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
index 6afdb29b7..07778e4f7 100644
--- a/pkg/tmutex/BUILD
+++ b/pkg/tmutex/BUILD
@@ -15,4 +15,5 @@ go_test(
size = "medium",
srcs = ["tmutex_test.go"],
embed = [":tmutex"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go
index ce34c7962..05540696a 100644
--- a/pkg/tmutex/tmutex_test.go
+++ b/pkg/tmutex/tmutex_test.go
@@ -17,10 +17,11 @@ package tmutex
import (
"fmt"
"runtime"
- "sync"
"sync/atomic"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestBasicLock(t *testing.T) {
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index 8f6f180e5..d1885ae66 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -24,4 +24,5 @@ go_test(
"unet_test.go",
],
embed = [":unet"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go
index a3cc6f5d3..5c4b9e8e9 100644
--- a/pkg/unet/unet_test.go
+++ b/pkg/unet/unet_test.go
@@ -19,10 +19,11 @@ import (
"os"
"path/filepath"
"reflect"
- "sync"
"syscall"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func randomFilename() (string, error) {
diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD
index b6bbb0ea2..b8fdc3125 100644
--- a/pkg/urpc/BUILD
+++ b/pkg/urpc/BUILD
@@ -11,6 +11,7 @@ go_library(
deps = [
"//pkg/fd",
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
],
)
diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go
index df59ffab1..13b2ea314 100644
--- a/pkg/urpc/urpc.go
+++ b/pkg/urpc/urpc.go
@@ -27,10 +27,10 @@ import (
"os"
"reflect"
"runtime"
- "sync"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
index 0427bc41f..1c6890e52 100644
--- a/pkg/waiter/BUILD
+++ b/pkg/waiter/BUILD
@@ -24,6 +24,7 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/waiter",
visibility = ["//visibility:public"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 8a65ed164..f708e95fa 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -58,7 +58,7 @@
package waiter
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// EventMask represents io events as used in the poll() syscall.
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 6226b63f8..3e20f8f2f 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -74,6 +74,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/sentry/watchdog",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/link/fdbased",
@@ -114,6 +115,7 @@ go_test(
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
+ "//pkg/sync",
"//pkg/unet",
"//runsc/fsgofer",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
index 352e710d2..9c23b9553 100644
--- a/runsc/boot/compat.go
+++ b/runsc/boot/compat.go
@@ -17,7 +17,6 @@ package boot
import (
"fmt"
"os"
- "sync"
"syscall"
"github.com/golang/protobuf/proto"
@@ -27,6 +26,7 @@ import (
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
"gvisor.dev/gvisor/pkg/sentry/strace"
spb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
)
func initCompatLogs(fd int) error {
diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go
index d1c0bb9b5..ce62236e5 100644
--- a/runsc/boot/limits.go
+++ b/runsc/boot/limits.go
@@ -16,12 +16,12 @@ package boot
import (
"fmt"
- "sync"
"syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Mapping from linux resource names to limits.LimitType.
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index bc1d0c1bb..fad72f4ab 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -20,7 +20,6 @@ import (
mrand "math/rand"
"os"
"runtime"
- "sync"
"sync/atomic"
"syscall"
gtime "time"
@@ -46,6 +45,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 147ff7703..bec0dc292 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -19,7 +19,6 @@ import (
"math/rand"
"os"
"reflect"
- "sync"
"syscall"
"testing"
"time"
@@ -30,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/runsc/fsgofer"
)
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index dd4926bb9..6a8765ec8 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -126,7 +126,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
linkEP := loopback.New()
log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
return err
}
@@ -173,7 +173,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
return err
}
@@ -218,15 +218,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
// createNICWithAddrs creates a NIC in the network stack and adds the given
// addresses.
-func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error {
- if loopback {
- if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil {
- return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, ep, err)
- }
- } else {
- if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil {
- return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, ep, err)
- }
+func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP) error {
+ opts := stack.NICOptions{Name: name}
+ if err := n.Stack.CreateNICWithOptions(id, sniffer.New(ep), opts); err != nil {
+ return fmt.Errorf("CreateNICWithOptions(%d, _, %+v) failed: %v", id, opts, err)
}
// Always start with an arp address for the NIC.
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 250845ad7..b94bc4fa0 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -44,6 +44,7 @@ go_library(
"//pkg/sentry/control",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
"//runsc/boot",
diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go
index a4e3071b3..1815c93b9 100644
--- a/runsc/cmd/create.go
+++ b/runsc/cmd/create.go
@@ -16,6 +16,7 @@ package cmd
import (
"context"
+
"flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 4831210c0..7df7995f0 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -21,7 +21,6 @@ import (
"os"
"path/filepath"
"strings"
- "sync"
"syscall"
"flag"
@@ -30,6 +29,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/fsgofer"
diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go
index de2115dff..5e9bc53ab 100644
--- a/runsc/cmd/start.go
+++ b/runsc/cmd/start.go
@@ -16,6 +16,7 @@ package cmd
import (
"context"
+
"flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index 2bd12120d..6dea179e4 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -18,6 +18,7 @@ go_library(
deps = [
"//pkg/log",
"//pkg/sentry/control",
+ "//pkg/sync",
"//runsc/boot",
"//runsc/cgroup",
"//runsc/sandbox",
@@ -53,6 +54,7 @@ go_test(
"//pkg/sentry/control",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
"//runsc/boot",
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index 5ed131a7f..060b63bf3 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -20,7 +20,6 @@ import (
"io"
"os"
"path/filepath"
- "sync"
"syscall"
"testing"
"time"
@@ -29,6 +28,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/urpc"
"gvisor.dev/gvisor/runsc/testutil"
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index c10f85992..b54d8f712 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -26,7 +26,6 @@ import (
"reflect"
"strconv"
"strings"
- "sync"
"syscall"
"testing"
"time"
@@ -39,6 +38,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
"gvisor.dev/gvisor/runsc/specutils"
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 4ad09ceab..2da93ec5b 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -22,7 +22,6 @@ import (
"path"
"path/filepath"
"strings"
- "sync"
"syscall"
"testing"
"time"
@@ -30,6 +29,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
"gvisor.dev/gvisor/runsc/testutil"
diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go
index d95151ea5..17a251530 100644
--- a/runsc/container/state_file.go
+++ b/runsc/container/state_file.go
@@ -20,10 +20,10 @@ import (
"io/ioutil"
"os"
"path/filepath"
- "sync"
"github.com/gofrs/flock"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
)
const stateFileExtension = ".state"
diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD
index afcb41801..a9582d92b 100644
--- a/runsc/fsgofer/BUILD
+++ b/runsc/fsgofer/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/fd",
"//pkg/log",
"//pkg/p9",
+ "//pkg/sync",
"//pkg/syserr",
"//runsc/specutils",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index b59e1a70e..4d84ad999 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -29,7 +29,6 @@ import (
"path/filepath"
"runtime"
"strconv"
- "sync"
"syscall"
"golang.org/x/sys/unix"
@@ -37,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -767,6 +767,16 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
return err
}
+// TODO(b/127675828): support getxattr.
+func (l *localFile) GetXattr(name string, size uint64) (string, error) {
+ return "", syscall.EOPNOTSUPP
+}
+
+// TODO(b/127675828): support setxattr.
+func (l *localFile) SetXattr(name, value string, flags uint32) error {
+ return syscall.EOPNOTSUPP
+}
+
// Allocate implements p9.File.
func (l *localFile) Allocate(mode p9.AllocateMode, offset, length uint64) error {
if !l.isOpen() {
diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD
index 8001949d5..ddbc37456 100644
--- a/runsc/sandbox/BUILD
+++ b/runsc/sandbox/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/platform",
+ "//pkg/sync",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
"//pkg/urpc",
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index be8b72b3e..ff48f5646 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -321,16 +321,21 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (
}
}
- // Use SO_RCVBUFFORCE because on linux the receive buffer for an
- // AF_PACKET socket is capped by "net.core.rmem_max". rmem_max
- // defaults to a unusually low value of 208KB. This is too low
- // for gVisor to be able to receive packets at high throughputs
- // without incurring packet drops.
- const rcvBufSize = 4 << 20 // 4MB.
-
- if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, rcvBufSize); err != nil {
- return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", rcvBufSize, err)
+ // Use SO_RCVBUFFORCE/SO_SNDBUFFORCE because on linux the receive/send buffer
+ // for an AF_PACKET socket is capped by "net.core.rmem_max/wmem_max".
+ // wmem_max/rmem_max default to a unusually low value of 208KB. This is too low
+ // for gVisor to be able to receive packets at high throughputs without
+ // incurring packet drops.
+ const bufSize = 4 << 20 // 4MB.
+
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufSize); err != nil {
+ return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", bufSize, err)
}
+
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, bufSize); err != nil {
+ return nil, fmt.Errorf("failed to increase socket snd buffer to %d: %v", bufSize, err)
+ }
+
return &socketEntry{deviceFile, gsoMaxSize}, nil
}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index ce1452b87..ec72bdbfd 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -22,7 +22,6 @@ import (
"os"
"os/exec"
"strconv"
- "sync"
"syscall"
"time"
@@ -34,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/urpc"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD
index c96ca2eb6..3c3027cb5 100644
--- a/runsc/testutil/BUILD
+++ b/runsc/testutil/BUILD
@@ -10,6 +10,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//runsc/boot",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go
index 9632776d2..fb22eae39 100644
--- a/runsc/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -34,7 +34,6 @@ import (
"path/filepath"
"strconv"
"strings"
- "sync"
"sync/atomic"
"syscall"
"time"
@@ -42,6 +41,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/scripts/common.sh b/scripts/common.sh
index 6dabad141..fdb1aa142 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -73,7 +73,7 @@ function install_runsc() {
sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@"
# Clear old logs files that may exist.
- sudo rm -f "${RUNSC_LOGS_DIR}"/*
+ sudo rm -f "${RUNSC_LOGS_DIR}"/'*'
# Restart docker to pick up the new runtime configuration.
sudo systemctl restart docker
diff --git a/scripts/issue_reviver.sh b/scripts/issue_reviver.sh
new file mode 100755
index 000000000..bac9b9192
--- /dev/null
+++ b/scripts/issue_reviver.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+DIR=$(dirname $0)
+source "${DIR}"/common.sh
+
+# Provide a credential file if available.
+export OAUTH_TOKEN_FILE=""
+if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then
+ OAUTH_TOKEN_FILE="${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}"
+fi
+
+REPO_ROOT=$(cd "$(dirname "${DIR}")"; pwd)
+run //tools/issue_reviver:issue_reviver --path "${REPO_ROOT}" --oauth-token-file="${OAUTH_TOKEN_FILE}"
diff --git a/test/iptables/BUILD b/test/iptables/BUILD
index fa833c3b2..372ba7abf 100644
--- a/test/iptables/BUILD
+++ b/test/iptables/BUILD
@@ -6,8 +6,10 @@ go_library(
name = "iptables",
srcs = [
"filter_input.go",
+ "filter_output.go",
"iptables.go",
"iptables_util.go",
+ "nat.go",
],
importpath = "gvisor.dev/gvisor/test/iptables",
visibility = ["//test/iptables:__subpackages__"],
diff --git a/test/iptables/README.md b/test/iptables/README.md
index b37cb2a96..9f8e34420 100644
--- a/test/iptables/README.md
+++ b/test/iptables/README.md
@@ -1,6 +1,6 @@
# iptables Tests
-iptables tests are run via `scripts/iptables\_test.sh`.
+iptables tests are run via `scripts/iptables_test.sh`.
## Test Structure
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
index 923f44e68..4b8bbb093 100644
--- a/test/iptables/filter_input.go
+++ b/test/iptables/filter_input.go
@@ -31,6 +31,8 @@ func init() {
RegisterTestCase(FilterInputDropUDP{})
RegisterTestCase(FilterInputDropUDPPort{})
RegisterTestCase(FilterInputDropDifferentUDPPort{})
+ RegisterTestCase(FilterInputDropTCPDestPort{})
+ RegisterTestCase(FilterInputDropTCPSrcPort{})
}
// FilterInputDropUDP tests that we can drop UDP traffic.
@@ -122,3 +124,65 @@ func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP) error {
func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
+
+// FilterInputDropTCPDestPort tests that connections are not accepted on specified source ports.
+type FilterInputDropTCPDestPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDropTCPDestPort) Name() string {
+ return "FilterInputDropTCPDestPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on drop port.
+ if err := listenTCP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// FilterInputDropTCPSrcPort tests that connections are not accepted on specified source ports.
+type FilterInputDropTCPSrcPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDropTCPSrcPort) Name() string {
+ return "FilterInputDropTCPSrcPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be acceptedi, but got accepted", dropPort)
+ }
+
+ return nil
+}
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
new file mode 100644
index 000000000..ee2c49f9a
--- /dev/null
+++ b/test/iptables/filter_output.go
@@ -0,0 +1,87 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "fmt"
+ "net"
+)
+
+func init() {
+ RegisterTestCase(FilterOutputDropTCPDestPort{})
+ RegisterTestCase(FilterOutputDropTCPSrcPort{})
+}
+
+// FilterOutputDropTCPDestPort tests that connections are not accepted on specified source ports.
+type FilterOutputDropTCPDestPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPDestPort) Name() string {
+ return "FilterOutputDropTCPDestPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// FilterOutputDropTCPSrcPort tests that connections are not accepted on specified source ports.
+type FilterOutputDropTCPSrcPort struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPSrcPort) Name() string {
+ return "FilterOutputDropTCPSrcPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on drop port.
+ if err := listenTCP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index bfbf1bb87..d268ea9b4 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -28,7 +28,7 @@ import (
"gvisor.dev/gvisor/runsc/testutil"
)
-const timeout time.Duration = 10 * time.Second
+const timeout = 18 * time.Second
var image = flag.String("image", "bazel/test/iptables/runner:runner", "image to run tests in")
@@ -177,3 +177,39 @@ func TestFilterInputDropDifferentUDPPort(t *testing.T) {
t.Fatal(err)
}
}
+
+func TestNATRedirectUDPPort(t *testing.T) {
+ if err := singleTest(NATRedirectUDPPort{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATDropUDP(t *testing.T) {
+ if err := singleTest(NATDropUDP{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterInputDropTCPDestPort(t *testing.T) {
+ if err := singleTest(FilterInputDropTCPDestPort{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterInputDropTCPSrcPort(t *testing.T) {
+ if err := singleTest(FilterInputDropTCPSrcPort{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputDropTCPDestPort(t *testing.T) {
+ if err := singleTest(FilterOutputDropTCPDestPort{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputDropTCPSrcPort(t *testing.T) {
+ if err := singleTest(FilterOutputDropTCPSrcPort{}); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index 3a4d11f1a..1c4f4f665 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -80,3 +80,56 @@ func sendUDPLoop(ip net.IP, port int, duration time.Duration) error {
return nil
}
+
+// listenTCP listens for connections on a TCP port.
+func listenTCP(port int, timeout time.Duration) error {
+ localAddr := net.TCPAddr{
+ Port: port,
+ }
+
+ // Starts listening on port.
+ lConn, err := net.ListenTCP("tcp4", &localAddr)
+ if err != nil {
+ return err
+ }
+ defer lConn.Close()
+
+ // Accept connections on port.
+ lConn.SetDeadline(time.Now().Add(timeout))
+ conn, err := lConn.AcceptTCP()
+ if err != nil {
+ return err
+ }
+ conn.Close()
+ return nil
+}
+
+// connectTCP connects the TCP server over specified local port, server IP and remote/server port.
+func connectTCP(ip net.IP, remotePort, localPort int, duration time.Duration) error {
+ remote := net.TCPAddr{
+ IP: ip,
+ Port: remotePort,
+ }
+
+ local := net.TCPAddr{
+ Port: localPort,
+ }
+
+ // Container may not be up. Retry DialTCP over a duration.
+ to := time.After(duration)
+ for {
+ conn, err := net.DialTCP("tcp4", &local, &remote)
+ if err == nil {
+ conn.Close()
+ return nil
+ }
+ select {
+ // Timed out waiting for connection to be accepted.
+ case <-to:
+ return err
+ default:
+ time.Sleep(200 * time.Millisecond)
+ }
+ }
+ return fmt.Errorf("Failed to establish connection on port %d", localPort)
+}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
new file mode 100644
index 000000000..b5c6f927e
--- /dev/null
+++ b/test/iptables/nat.go
@@ -0,0 +1,80 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "fmt"
+ "net"
+)
+
+const (
+ redirectPort = 42
+)
+
+func init() {
+ RegisterTestCase(NATRedirectUDPPort{})
+ RegisterTestCase(NATDropUDP{})
+}
+
+// NATRedirectUDPPort tests that packets are redirected to different port.
+type NATRedirectUDPPort struct{}
+
+// Name implements TestCase.Name.
+func (NATRedirectUDPPort) Name() string {
+ return "NATRedirectUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATRedirectUDPPort) ContainerAction(ip net.IP) error {
+ if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ if err := listenUDP(redirectPort, sendloopDuration); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err)
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATRedirectUDPPort) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// NATDropUDP tests that packets are not received in ports other than redirect port.
+type NATDropUDP struct{}
+
+// Name implements TestCase.Name.
+func (NATDropUDP) Name() string {
+ return "NATDropUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATDropUDP) ContainerAction(ip net.IP) error {
+ if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ if err := listenUDP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets on port %d should have been redirected to port %d", acceptPort, redirectPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATDropUDP) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index a3a85917d..829693e8e 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -717,11 +717,6 @@ syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test")
syscall_test(test = "//test/syscalls/linux:proc_net_udp_test")
-syscall_test(
- add_overlay = True,
- test = "//test/syscalls/linux:xattr_test",
-)
-
go_binary(
name = "syscall_test_runner",
testonly = 1,
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 064ce8429..4c7ec3f06 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -2693,6 +2693,7 @@ cc_binary(
srcs = ["socket_inet_loopback.cc"],
linkstatic = 1,
deps = [
+ ":ip_socket_test_util",
":socket_test_util",
"//test/util:file_descriptor",
"//test/util:posix_error",
@@ -2888,7 +2889,6 @@ cc_library(
":unix_domain_socket_test_util",
"//test/util:test_util",
"//test/util:thread_util",
- "//test/util:timer_util",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
index 7384c27dc..fdef646eb 100644
--- a/test/syscalls/linux/inotify.cc
+++ b/test/syscalls/linux/inotify.cc
@@ -977,7 +977,7 @@ TEST(Inotify, WatchOnRelativePath) {
ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
// Change working directory to root.
- const char* old_working_dir = get_current_dir_name();
+ const FileDescriptor cwd = ASSERT_NO_ERRNO_AND_VALUE(Open(".", O_PATH));
EXPECT_THAT(chdir(root.path().c_str()), SyscallSucceeds());
// Add a watch on file1 with a relative path.
@@ -997,7 +997,7 @@ TEST(Inotify, WatchOnRelativePath) {
// continue to hold a reference, random save/restore tests can fail if a save
// is triggered after "root" is unlinked; we can't save deleted fs objects
// with active references.
- EXPECT_THAT(chdir(old_working_dir), SyscallSucceeds());
+ EXPECT_THAT(fchdir(cwd.get()), SyscallSucceeds());
}
TEST(Inotify, ZeroLengthReadWriteDoesNotGenerateEvent) {
@@ -1591,6 +1591,34 @@ TEST(Inotify, EpollNoDeadlock) {
}
}
+TEST(Inotify, SpliceEvent) {
+ int pipes[2];
+ ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+ const int watcher = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ char buf;
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+
+ EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr,
+ sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK),
+ SyscallSucceedsWithValue(sizeof(struct inotify_event)));
+
+ const FileDescriptor read_fd(pipes[0]);
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)}));
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index 8398fc95f..6b472eb2f 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -187,24 +187,24 @@ PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) {
return InterfaceIndex(name);
}
-std::string GetAddr4Str(in_addr* a) {
+std::string GetAddr4Str(const in_addr* a) {
char str[INET_ADDRSTRLEN];
inet_ntop(AF_INET, a, str, sizeof(str));
return std::string(str);
}
-std::string GetAddr6Str(in6_addr* a) {
+std::string GetAddr6Str(const in6_addr* a) {
char str[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, a, str, sizeof(str));
return std::string(str);
}
-std::string GetAddrStr(sockaddr* a) {
+std::string GetAddrStr(const sockaddr* a) {
if (a->sa_family == AF_INET) {
- auto src = &(reinterpret_cast<sockaddr_in*>(a)->sin_addr);
+ auto src = &(reinterpret_cast<const sockaddr_in*>(a)->sin_addr);
return GetAddr4Str(src);
} else if (a->sa_family == AF_INET6) {
- auto src = &(reinterpret_cast<sockaddr_in6*>(a)->sin6_addr);
+ auto src = &(reinterpret_cast<const sockaddr_in6*>(a)->sin6_addr);
return GetAddr6Str(src);
}
return std::string("<invalid>");
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index 9cb4566db..0f58e0f77 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -105,14 +105,14 @@ class IfAddrHelper {
};
// GetAddr4Str returns the given IPv4 network address structure as a string.
-std::string GetAddr4Str(in_addr* a);
+std::string GetAddr4Str(const in_addr* a);
// GetAddr6Str returns the given IPv6 network address structure as a string.
-std::string GetAddr6Str(in6_addr* a);
+std::string GetAddr6Str(const in6_addr* a);
// GetAddrStr returns the given IPv4 or IPv6 network address structure as a
// string.
-std::string GetAddrStr(sockaddr* a);
+std::string GetAddrStr(const sockaddr* a);
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc
index 33822ee57..df7129acc 100644
--- a/test/syscalls/linux/partial_bad_buffer.cc
+++ b/test/syscalls/linux/partial_bad_buffer.cc
@@ -18,7 +18,9 @@
#include <netinet/tcp.h>
#include <sys/mman.h>
#include <sys/socket.h>
+#include <sys/stat.h>
#include <sys/syscall.h>
+#include <sys/types.h>
#include <sys/uio.h>
#include <unistd.h>
@@ -62,9 +64,9 @@ class PartialBadBufferTest : public ::testing::Test {
// Write some initial data.
size_t size = sizeof(kMessage) - 1;
EXPECT_THAT(WriteFd(fd_, &kMessage, size), SyscallSucceedsWithValue(size));
-
ASSERT_THAT(lseek(fd_, 0, SEEK_SET), SyscallSucceeds());
+ // Map a useable buffer.
addr_ = mmap(0, 2 * kPageSize, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
ASSERT_NE(addr_, MAP_FAILED);
@@ -79,6 +81,15 @@ class PartialBadBufferTest : public ::testing::Test {
bad_buffer_ = buf + kPageSize - 1;
}
+ off_t Size() {
+ struct stat st;
+ int rc = fstat(fd_, &st);
+ if (rc < 0) {
+ return static_cast<off_t>(rc);
+ }
+ return st.st_size;
+ }
+
void TearDown() override {
EXPECT_THAT(munmap(addr_, 2 * kPageSize), SyscallSucceeds()) << addr_;
EXPECT_THAT(close(fd_), SyscallSucceeds());
@@ -165,97 +176,99 @@ TEST_F(PartialBadBufferTest, PreadvSmall) {
}
TEST_F(PartialBadBufferTest, WriteBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, kPageSize),
- SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(write)(fd_, bad_buffer_, kPageSize)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, WriteSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, 10),
- SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(write)(fd_, bad_buffer_, 10)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwriteBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwriteSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwrite)(fd_, bad_buffer_, 10, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwrite)(fd_, bad_buffer_, 10, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, WritevBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = kPageSize;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(writev)(fd_, &vec, 1), SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(writev)(fd_, &vec, 1)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, WritevSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = 10;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(writev)(fd_, &vec, 1), SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(writev)(fd_, &vec, 1)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwritevBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = kPageSize;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwritev)(fd_, &vec, 1, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwritevSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = 10;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwritev)(fd_, &vec, 1, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
// getdents returns EFAULT when the you claim the buffer is large enough, but
@@ -283,29 +296,6 @@ TEST_F(PartialBadBufferTest, GetdentsOneEntry) {
SyscallSucceedsWithValue(Gt(0)));
}
-// Verify that when write returns EFAULT the kernel hasn't silently written
-// the initial valid bytes.
-TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
- bad_buffer_[0] = 'A';
- EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, 10),
- SyscallFailsWithErrno(EFAULT));
-
- size_t size = 255;
- char buf[255];
- memset(buf, 0, size);
-
- EXPECT_THAT(RetryEINTR(pread)(fd_, buf, size, 0),
- SyscallSucceedsWithValue(sizeof(kMessage) - 1));
-
- // 'A' has not been written.
- EXPECT_STREQ(buf, kMessage);
-}
-
PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
struct sockaddr_storage addr;
memset(&addr, 0, sizeof(addr));
diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc
index 9e5aa7fd0..c42472474 100644
--- a/test/syscalls/linux/poll.cc
+++ b/test/syscalls/linux/poll.cc
@@ -275,7 +275,8 @@ TEST_F(PollTest, Nfds) {
// Each entry in the 'fds' array refers to the eventfd and polls for
// "writable" events (events=POLLOUT). This essentially guarantees that the
// poll() is a no-op and allows negative testing of the 'nfds' parameter.
- std::vector<struct pollfd> fds(max_fds, {.fd = efd.get(), .events = POLLOUT});
+ std::vector<struct pollfd> fds(max_fds + 1,
+ {.fd = efd.get(), .events = POLLOUT});
// Verify that 'nfds' up to RLIMIT_NOFILE are allowed.
EXPECT_THAT(RetryEINTR(poll)(fds.data(), 1, 1), SyscallSucceedsWithValue(1));
diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc
index c9246367d..cd936ea90 100644
--- a/test/syscalls/linux/preadv2.cc
+++ b/test/syscalls/linux/preadv2.cc
@@ -202,7 +202,7 @@ TEST(Preadv2Test, TestInvalidOffset) {
iov[0].iov_len = 0;
EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, /*offset=*/-8,
- /*flags=*/RWF_HIPRI),
+ /*flags=*/0),
SyscallFailsWithErrno(EINVAL));
}
diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc
index 491d5f40f..2694dc64f 100644
--- a/test/syscalls/linux/readv_common.cc
+++ b/test/syscalls/linux/readv_common.cc
@@ -154,7 +154,7 @@ void ReadBuffersOverlapping(int fd) {
char* expected_ptr = expected.data();
memcpy(expected_ptr, &kReadvTestData[overlap_bytes], overlap_bytes);
memcpy(&expected_ptr[overlap_bytes], &kReadvTestData[overlap_bytes],
- kReadvTestDataSize);
+ kReadvTestDataSize - overlap_bytes);
struct iovec iovs[2];
iovs[0].iov_base = buffer.data();
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
index 5767181a1..5ed57625c 100644
--- a/test/syscalls/linux/socket_bind_to_device_distribution.cc
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -183,7 +183,14 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
}
// Receive some data from a socket to be sure that the connect()
// system call has been completed on another side.
- int data;
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not timeout under gotsan runs where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
EXPECT_THAT(
RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
SyscallSucceedsWithValue(sizeof(data)));
@@ -198,15 +205,29 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
}
for (int i = 0; i < kConnectAttempts; i++) {
- FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
ASSERT_THAT(
RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
connector.addr_len),
SyscallSucceeds());
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack where read is incorrectly assuming a whole
+ // segment is read when endpoint.Read() is called which is technically
+ // incorrect as the syscall that invoked endpoint.Read() may only
+ // consume it partially. This results in a case where a close() of
+ // such a socket does not trigger a RST in netstack due to the
+ // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
}
// Join threads to be sure that all connections have been counted.
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 619d41901..2f9821555 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -32,6 +32,7 @@
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
@@ -102,6 +103,161 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) {
SyscallFailsWithErrno(EAFNOSUPPORT));
}
+enum class Operation {
+ Bind,
+ Connect,
+ SendTo,
+};
+
+std::string OperationToString(Operation operation) {
+ switch (operation) {
+ case Operation::Bind:
+ return "Bind";
+ case Operation::Connect:
+ return "Connect";
+ case Operation::SendTo:
+ return "SendTo";
+ }
+}
+
+using OperationSequence = std::vector<Operation>;
+
+using DualStackSocketTest =
+ ::testing::TestWithParam<std::tuple<TestAddress, OperationSequence>>;
+
+TEST_P(DualStackSocketTest, AddressOperations) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, 0));
+
+ const TestAddress& addr = std::get<0>(GetParam());
+ const OperationSequence& operations = std::get<1>(GetParam());
+
+ auto addr_in = reinterpret_cast<const sockaddr*>(&addr.addr);
+
+ // sockets may only be bound once. Both `connect` and `sendto` cause a socket
+ // to be bound.
+ bool bound = false;
+ for (const Operation& operation : operations) {
+ bool sockname = false;
+ bool peername = false;
+ switch (operation) {
+ case Operation::Bind: {
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 0));
+
+ int bind_ret = bind(fd.get(), addr_in, addr.addr_len);
+
+ // Dual stack sockets may only be bound to AF_INET6.
+ if (!bound && addr.family() == AF_INET6) {
+ EXPECT_THAT(bind_ret, SyscallSucceeds());
+ bound = true;
+
+ sockname = true;
+ } else {
+ EXPECT_THAT(bind_ret, SyscallFailsWithErrno(EINVAL));
+ }
+ break;
+ }
+ case Operation::Connect: {
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337));
+
+ EXPECT_THAT(connect(fd.get(), addr_in, addr.addr_len),
+ SyscallSucceeds())
+ << GetAddrStr(addr_in);
+ bound = true;
+
+ sockname = true;
+ peername = true;
+
+ break;
+ }
+ case Operation::SendTo: {
+ const char payload[] = "hello";
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337));
+
+ ssize_t sendto_ret = sendto(fd.get(), &payload, sizeof(payload), 0,
+ addr_in, addr.addr_len);
+
+ EXPECT_THAT(sendto_ret, SyscallSucceedsWithValue(sizeof(payload)));
+ sockname = !bound;
+ bound = true;
+ break;
+ }
+ }
+
+ if (sockname) {
+ sockaddr_storage sock_addr;
+ socklen_t addrlen = sizeof(sock_addr);
+ ASSERT_THAT(getsockname(fd.get(), reinterpret_cast<sockaddr*>(&sock_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ auto sock_addr_in6 = reinterpret_cast<const sockaddr_in6*>(&sock_addr);
+
+ if (operation == Operation::SendTo) {
+ EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6);
+ EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getsocknam="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+
+ EXPECT_NE(sock_addr_in6->sin6_port, 0);
+ } else if (IN6_IS_ADDR_V4MAPPED(
+ reinterpret_cast<const sockaddr_in6*>(addr_in)
+ ->sin6_addr.s6_addr32)) {
+ EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getsocknam="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+ }
+ }
+
+ if (peername) {
+ sockaddr_storage peer_addr;
+ socklen_t addrlen = sizeof(peer_addr);
+ ASSERT_THAT(getpeername(fd.get(), reinterpret_cast<sockaddr*>(&peer_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ if (addr.family() == AF_INET ||
+ IN6_IS_ADDR_V4MAPPED(reinterpret_cast<const sockaddr_in6*>(addr_in)
+ ->sin6_addr.s6_addr32)) {
+ EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(
+ reinterpret_cast<const sockaddr_in6*>(&peer_addr)
+ ->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getpeername="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&peer_addr));
+ }
+ }
+ }
+}
+
+// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny.
+INSTANTIATE_TEST_SUITE_P(
+ All, DualStackSocketTest,
+ ::testing::Combine(
+ ::testing::Values(V4Any(), V4Loopback(), /*V4MappedAny(),*/
+ V4MappedLoopback(), V6Any(), V6Loopback()),
+ ::testing::ValuesIn<OperationSequence>(
+ {{Operation::Bind, Operation::Connect, Operation::SendTo},
+ {Operation::Bind, Operation::SendTo, Operation::Connect},
+ {Operation::Connect, Operation::Bind, Operation::SendTo},
+ {Operation::Connect, Operation::SendTo, Operation::Bind},
+ {Operation::SendTo, Operation::Bind, Operation::Connect},
+ {Operation::SendTo, Operation::Connect, Operation::Bind}})),
+ [](::testing::TestParamInfo<
+ std::tuple<TestAddress, OperationSequence>> const& info) {
+ const TestAddress& addr = std::get<0>(info.param);
+ const OperationSequence& operations = std::get<1>(info.param);
+ std::string s = addr.description;
+ for (const Operation& operation : operations) {
+ absl::StrAppend(&s, OperationToString(operation));
+ }
+ return s;
+ });
+
void tcpSimpleConnectTest(TestAddress const& listener,
TestAddress const& connector, bool unbound) {
// Create the listening socket.
@@ -377,7 +533,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
// Sleep for a little over the linger timeout to reduce flakiness in
// save/restore tests.
- absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1));
+ absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2));
ds.reset();
@@ -714,7 +870,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
sockaddr_storage listen_addr = listener.addr;
sockaddr_storage conn_addr = connector.addr;
constexpr int kThreadCount = 3;
- constexpr int kConnectAttempts = 4096;
+ constexpr int kConnectAttempts = 10000;
// Create the listening socket.
FileDescriptor listener_fds[kThreadCount];
@@ -729,7 +885,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
ASSERT_THAT(
bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
SyscallSucceeds());
- ASSERT_THAT(listen(fd, kConnectAttempts / 3), SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
// On the first bind we need to determine which port was bound.
if (i != 0) {
@@ -772,7 +928,14 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
}
// Receive some data from a socket to be sure that the connect()
// system call has been completed on another side.
- int data;
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not timeout under gotsan runs where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
EXPECT_THAT(
RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
SyscallSucceedsWithValue(sizeof(data)));
@@ -795,8 +958,22 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
connector.addr_len),
SyscallSucceeds());
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack where read is incorrectly assuming a whole
+ // segment is read when endpoint.Read() is called which is technically
+ // incorrect as the syscall that invoked endpoint.Read() may only
+ // consume it partially. This results in a case where a close() of
+ // such a socket does not trigger a RST in netstack due to the
+ // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
}
});
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 66eb68857..53290bed7 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -209,6 +209,46 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) {
EXPECT_EQ(get, kSockOptOn);
}
+// Ensure that Receiving TOS is off by default.
+TEST_P(UDPSocketPairTest, RecvTosDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test that setting and getting IP_RECVTOS works as expected.
+TEST_P(UDPSocketPairTest, SetRecvTos) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
TEST_P(UDPSocketPairTest, ReuseAddrDefault) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc
index b6754111f..ca597e267 100644
--- a/test/syscalls/linux/socket_ip_unbound.cc
+++ b/test/syscalls/linux/socket_ip_unbound.cc
@@ -129,6 +129,7 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTtl) {
struct TOSOption {
int level;
int option;
+ int cmsg_level;
};
constexpr int INET_ECN_MASK = 3;
@@ -139,10 +140,12 @@ static TOSOption GetTOSOption(int domain) {
case AF_INET:
opt.level = IPPROTO_IP;
opt.option = IP_TOS;
+ opt.cmsg_level = SOL_IP;
break;
case AF_INET6:
opt.level = IPPROTO_IPV6;
opt.option = IPV6_TCLASS;
+ opt.cmsg_level = SOL_IPV6;
break;
}
return opt;
@@ -386,6 +389,36 @@ TEST_P(IPUnboundSocketTest, NullTOS) {
SyscallFailsWithErrno(EFAULT));
}
+TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) {
+ SKIP_IF(GetParam().protocol == IPPROTO_TCP);
+
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ TOSOption t = GetTOSOption(GetParam().domain);
+
+ in_addr addr4;
+ in6_addr addr6;
+ ASSERT_THAT(inet_pton(AF_INET, "127.0.0.1", &addr4), ::testing::Eq(1));
+ ASSERT_THAT(inet_pton(AF_INET6, "fe80::", &addr6), ::testing::Eq(1));
+
+ cmsghdr cmsg = {};
+ cmsg.cmsg_len = sizeof(cmsg);
+ cmsg.cmsg_level = t.cmsg_level;
+ cmsg.cmsg_type = t.option;
+
+ msghdr msg = {};
+ msg.msg_control = &cmsg;
+ msg.msg_controllen = sizeof(cmsg);
+ if (GetParam().domain == AF_INET) {
+ msg.msg_name = &addr4;
+ msg.msg_namelen = sizeof(addr4);
+ } else {
+ msg.msg_name = &addr6;
+ msg.msg_namelen = sizeof(addr6);
+ }
+
+ EXPECT_THAT(sendmsg(socket->get(), &msg, 0), SyscallFailsWithErrno(EINVAL));
+}
+
INSTANTIATE_TEST_SUITE_P(
IPUnboundSockets, IPUnboundSocketTest,
::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>(
diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc
index d91c5ed39..c61817f14 100644
--- a/test/syscalls/linux/socket_non_stream.cc
+++ b/test/syscalls/linux/socket_non_stream.cc
@@ -113,7 +113,7 @@ TEST_P(NonStreamSocketPairTest, RecvmsgMsghdrFlagMsgTrunc) {
EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
// Check that msghdr flags were updated.
- EXPECT_EQ(msg.msg_flags, MSG_TRUNC);
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
}
// Stream sockets allow data sent with multiple sends to be peeked at in a
@@ -193,7 +193,7 @@ TEST_P(NonStreamSocketPairTest, MsgTruncTruncationRecvmsgMsghdrFlagMsgTrunc) {
EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
// Check that msghdr flags were updated.
- EXPECT_EQ(msg.msg_flags, MSG_TRUNC);
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
}
TEST_P(NonStreamSocketPairTest, MsgTruncSameSize) {
@@ -224,5 +224,114 @@ TEST_P(NonStreamSocketPairTest, MsgTruncNotFull) {
EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
}
+// This test tests reading from a socket with MSG_TRUNC and a zero length
+// receive buffer. The user should be able to get the message length.
+TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // The receive buffer is of zero length.
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ // The syscall succeeds returning the full size of the message on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is set on msghdr flags.
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+}
+
+// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero
+// length receive buffer. The user should be able to get the message length
+// without reading data off the socket.
+TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncMsgPeekZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ // The syscall succeeds returning the full size of the message on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is set on msghdr flags because the receive buffer is
+ // smaller than the message size.
+ EXPECT_EQ(peek_msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+
+ char received_data[sizeof(sent_data)] = {};
+
+ struct iovec received_iov;
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = sizeof(received_data);
+ struct msghdr received_msg = {};
+ received_msg.msg_flags = -1;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+
+ // Next we can read the actual data.
+ ASSERT_THAT(
+ RetryEINTR(recvmsg)(sockets->second_fd(), &received_msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is not set on msghdr flags because we read the whole
+ // message.
+ EXPECT_EQ(received_msg.msg_flags & MSG_TRUNC, 0);
+}
+
+// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero
+// length receive buffer and MSG_DONTWAIT. The user should be able to get an
+// EAGAIN or EWOULDBLOCK error response.
+TEST_P(NonStreamSocketPairTest, RecvmsgTruncPeekDontwaitZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // NOTE: We don't send any data on the socket.
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ // recvmsg fails with EAGAIN because no data is available on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK | MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc
index 62d87c1af..b052f6e61 100644
--- a/test/syscalls/linux/socket_non_stream_blocking.cc
+++ b/test/syscalls/linux/socket_non_stream_blocking.cc
@@ -25,6 +25,7 @@
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
@@ -44,5 +45,41 @@ TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) {
SyscallSucceedsWithValue(sizeof(sent_data)));
}
+// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero
+// length receive buffer and MSG_DONTWAIT. The recvmsg call should block on
+// reading the data.
+TEST_P(BlockingNonStreamSocketPairTest,
+ RecvmsgTruncPeekDontwaitZeroLenBlocking) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // NOTE: We don't initially send any data on the socket.
+ const int data_size = 10;
+ char sent_data[data_size];
+ RandomizeBuffer(sent_data, data_size);
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ ScopedThread t([&]() {
+ // The syscall succeeds returning the full size of the message on the
+ // socket. This should block until there is data on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(data_size));
+ });
+
+ absl::SleepFor(absl::Seconds(1));
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), sent_data, data_size, 0),
+ SyscallSucceedsWithValue(data_size));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc
index 346443f96..6522b2e01 100644
--- a/test/syscalls/linux/socket_stream.cc
+++ b/test/syscalls/linux/socket_stream.cc
@@ -104,7 +104,60 @@ TEST_P(StreamSocketPairTest, RecvmsgMsghdrFlagsNoMsgTrunc) {
EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
// Check that msghdr flags were cleared (MSG_TRUNC was not set).
- EXPECT_EQ(msg.msg_flags, 0);
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
+}
+
+TEST_P(StreamSocketPairTest, RecvmsgTruncZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(0));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
+}
+
+TEST_P(StreamSocketPairTest, RecvmsgTruncPeekZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(
+ RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(0));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
}
TEST_P(StreamSocketPairTest, MsgTrunc) {
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 6b99c021d..33a5ac66c 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -814,6 +814,20 @@ TEST_P(TcpSocketTest, FullBuffer) {
t_ = -1;
}
+TEST_P(TcpSocketTest, PollAfterShutdown) {
+ ScopedThread client_thread([this]() {
+ EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0));
+ struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+ });
+
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0));
+ struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+}
+
TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) {
// Initialize address to the loopback one.
sockaddr_storage addr =
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
index dc35c2f50..68e0a8109 100644
--- a/test/syscalls/linux/udp_socket_test_cases.cc
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -1349,8 +1349,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) {
// outgoing packets, and that a receiving socket with IP_RECVTOS or
// IPV6_RECVTCLASS will create the corresponding control message.
TEST_P(UdpSocketTest, SetAndReceiveTOS) {
- // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack.
- SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet());
+ // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack.
+ SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() &&
+ !IsRunningWithHostinet());
ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
@@ -1421,7 +1422,8 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) {
// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or
// IPV6_RECVTCLASS will create the corresponding control message.
TEST_P(UdpSocketTest, SendAndReceiveTOS) {
- // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack.
+ // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack.
+ // TODO(b/146661005): Setting TOS via cmsg not supported for netstack.
SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet());
ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
index 75740238c..b3bc3463e 100644
--- a/test/syscalls/linux/xattr.cc
+++ b/test/syscalls/linux/xattr.cc
@@ -59,7 +59,8 @@ TEST_F(XattrTest, XattrLargeName) {
std::string name = "user.";
name += std::string(XATTR_NAME_MAX - name.length(), 'a');
- // TODO(b/127675828): Support setxattr and getxattr.
+ // An xattr should be whitelisted before it can be accessed--do not allow
+ // arbitrary xattrs to be read/written in gVisor.
if (!IsRunningOnGvisor()) {
EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
SyscallSucceeds());
@@ -83,59 +84,53 @@ TEST_F(XattrTest, XattrInvalidPrefix) {
SyscallFailsWithErrno(EOPNOTSUPP));
}
-TEST_F(XattrTest, XattrReadOnly) {
+// Do not allow save/restore cycles after making the test file read-only, as
+// the restore will fail to open it with r/w permissions.
+TEST_F(XattrTest, XattrReadOnly_NoRandomSave) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
size_t size = sizeof(val);
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0),
- SyscallSucceeds());
- }
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+ DisableSave ds;
ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IRUSR));
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0),
SyscallFailsWithErrno(EACCES));
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- char buf = '-';
- EXPECT_THAT(getxattr(path, name, &buf, size),
- SyscallSucceedsWithValue(size));
- EXPECT_EQ(buf, val);
- }
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
}
-TEST_F(XattrTest, XattrWriteOnly) {
+// Do not allow save/restore cycles after making the test file write-only, as
+// the restore will fail to open it with r/w permissions.
+TEST_F(XattrTest, XattrWriteOnly_NoRandomSave) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+ DisableSave ds;
ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IWUSR));
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
size_t size = sizeof(val);
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0),
- SyscallSucceeds());
- }
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(EACCES));
}
TEST_F(XattrTest, XattrTrustedWithNonadmin) {
- // TODO(b/127675828): Support setxattr and getxattr.
+ // TODO(b/127675828): Support setxattr and getxattr with "trusted" prefix.
SKIP_IF(IsRunningOnGvisor());
SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
@@ -147,11 +142,8 @@ TEST_F(XattrTest, XattrTrustedWithNonadmin) {
}
TEST_F(XattrTest, XattrOnDirectory) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(dir.path().c_str(), name, NULL, 0, /*flags=*/0),
SyscallSucceeds());
EXPECT_THAT(getxattr(dir.path().c_str(), name, NULL, 0),
@@ -159,13 +151,10 @@ TEST_F(XattrTest, XattrOnDirectory) {
}
TEST_F(XattrTest, XattrOnSymlink) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo(dir.path(), test_file_name_));
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(link.path().c_str(), name, NULL, 0, /*flags=*/0),
SyscallSucceeds());
EXPECT_THAT(getxattr(link.path().c_str(), name, NULL, 0),
@@ -173,7 +162,7 @@ TEST_F(XattrTest, XattrOnSymlink) {
}
TEST_F(XattrTest, XattrOnInvalidFileTypes) {
- char name[] = "user.abc";
+ const char name[] = "user.test";
char char_device[] = "/dev/zero";
EXPECT_THAT(setxattr(char_device, name, NULL, 0, /*flags=*/0),
@@ -191,11 +180,8 @@ TEST_F(XattrTest, XattrOnInvalidFileTypes) {
}
TEST_F(XattrTest, SetxattrSizeSmallerThanValue) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
size_t size = 1;
EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
@@ -209,11 +195,8 @@ TEST_F(XattrTest, SetxattrSizeSmallerThanValue) {
}
TEST_F(XattrTest, SetxattrZeroSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds());
@@ -225,7 +208,7 @@ TEST_F(XattrTest, SetxattrZeroSize) {
TEST_F(XattrTest, SetxattrSizeTooLarge) {
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
// Note that each particular fs implementation may stipulate a lower size
// limit, in which case we actually may fail (e.g. error with ENOSPC) for
@@ -235,43 +218,29 @@ TEST_F(XattrTest, SetxattrSizeTooLarge) {
EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
SyscallFailsWithErrno(E2BIG));
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- EXPECT_THAT(getxattr(path, name, nullptr, 0),
- SyscallFailsWithErrno(ENODATA));
- }
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) {
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0),
SyscallFailsWithErrno(EFAULT));
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- EXPECT_THAT(getxattr(path, name, nullptr, 0),
- SyscallFailsWithErrno(ENODATA));
- }
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
TEST_F(XattrTest, SetxattrNullValueAndZeroSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
}
TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val(XATTR_SIZE_MAX + 1);
std::fill(val.begin(), val.end(), 'a');
size_t size = 1;
@@ -286,11 +255,8 @@ TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) {
}
TEST_F(XattrTest, SetxattrReplaceWithSmaller) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0),
SyscallSucceeds());
@@ -304,11 +270,8 @@ TEST_F(XattrTest, SetxattrReplaceWithSmaller) {
}
TEST_F(XattrTest, SetxattrReplaceWithLarger) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0),
SyscallSucceeds());
@@ -321,11 +284,8 @@ TEST_F(XattrTest, SetxattrReplaceWithLarger) {
}
TEST_F(XattrTest, SetxattrCreateFlag) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE),
SyscallSucceeds());
EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE),
@@ -335,11 +295,8 @@ TEST_F(XattrTest, SetxattrCreateFlag) {
}
TEST_F(XattrTest, SetxattrReplaceFlag) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE),
SyscallFailsWithErrno(ENODATA));
EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
@@ -357,11 +314,8 @@ TEST_F(XattrTest, SetxattrInvalidFlags) {
}
TEST_F(XattrTest, Getxattr) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
int val = 1234;
size_t size = sizeof(val);
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
@@ -372,11 +326,8 @@ TEST_F(XattrTest, Getxattr) {
}
TEST_F(XattrTest, GetxattrSizeSmallerThanValue) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
size_t size = val.size();
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
@@ -387,11 +338,8 @@ TEST_F(XattrTest, GetxattrSizeSmallerThanValue) {
}
TEST_F(XattrTest, GetxattrSizeLargerThanValue) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
EXPECT_THAT(setxattr(path, name, &val, 1, /*flags=*/0), SyscallSucceeds());
@@ -405,11 +353,8 @@ TEST_F(XattrTest, GetxattrSizeLargerThanValue) {
}
TEST_F(XattrTest, GetxattrZeroSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0),
SyscallSucceeds());
@@ -421,11 +366,8 @@ TEST_F(XattrTest, GetxattrZeroSize) {
}
TEST_F(XattrTest, GetxattrSizeTooLarge) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0),
SyscallSucceeds());
@@ -440,11 +382,8 @@ TEST_F(XattrTest, GetxattrSizeTooLarge) {
}
TEST_F(XattrTest, GetxattrNullValue) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
size_t size = sizeof(val);
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
@@ -454,11 +393,8 @@ TEST_F(XattrTest, GetxattrNullValue) {
}
TEST_F(XattrTest, GetxattrNullValueAndZeroSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
size_t size = sizeof(val);
// Set value with zero size.
@@ -473,13 +409,9 @@ TEST_F(XattrTest, GetxattrNullValueAndZeroSize) {
}
TEST_F(XattrTest, GetxattrNonexistentName) {
- // TODO(b/127675828): Support getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- std::string name = "user.nonexistent";
- EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
- SyscallFailsWithErrno(ENODATA));
+ const char name[] = "user.test";
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
} // namespace
diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD
new file mode 100644
index 000000000..ee7ea11fd
--- /dev/null
+++ b/tools/issue_reviver/BUILD
@@ -0,0 +1,12 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "issue_reviver",
+ srcs = ["main.go"],
+ deps = [
+ "//tools/issue_reviver/github",
+ "//tools/issue_reviver/reviver",
+ ],
+)
diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD
new file mode 100644
index 000000000..6da22ba1c
--- /dev/null
+++ b/tools/issue_reviver/github/BUILD
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "github",
+ srcs = ["github.go"],
+ importpath = "gvisor.dev/gvisor/tools/issue_reviver/github",
+ visibility = [
+ "//tools/issue_reviver:__subpackages__",
+ ],
+ deps = [
+ "//tools/issue_reviver/reviver",
+ "@com_github_google_go-github//github:go_default_library",
+ "@org_golang_x_oauth2//:go_default_library",
+ ],
+)
diff --git a/tools/issue_reviver/github/github.go b/tools/issue_reviver/github/github.go
new file mode 100644
index 000000000..e07949c8f
--- /dev/null
+++ b/tools/issue_reviver/github/github.go
@@ -0,0 +1,164 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package github implements reviver.Bugger interface on top of Github issues.
+package github
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/google/go-github/github"
+ "golang.org/x/oauth2"
+ "gvisor.dev/gvisor/tools/issue_reviver/reviver"
+)
+
+// Bugger implements reviver.Bugger interface for github issues.
+type Bugger struct {
+ owner string
+ repo string
+ dryRun bool
+
+ client *github.Client
+ issues map[int]*github.Issue
+}
+
+// NewBugger creates a new Bugger.
+func NewBugger(token, owner, repo string, dryRun bool) (*Bugger, error) {
+ b := &Bugger{
+ owner: owner,
+ repo: repo,
+ dryRun: dryRun,
+ issues: map[int]*github.Issue{},
+ }
+ if err := b.load(token); err != nil {
+ return nil, err
+ }
+ return b, nil
+}
+
+func (b *Bugger) load(token string) error {
+ ctx := context.Background()
+ if len(token) == 0 {
+ fmt.Print("No OAUTH token provided, using unauthenticated account.\n")
+ b.client = github.NewClient(nil)
+ } else {
+ ts := oauth2.StaticTokenSource(
+ &oauth2.Token{AccessToken: token},
+ )
+ tc := oauth2.NewClient(ctx, ts)
+ b.client = github.NewClient(tc)
+ }
+
+ err := processAllPages(func(listOpts github.ListOptions) (*github.Response, error) {
+ opts := &github.IssueListByRepoOptions{State: "open", ListOptions: listOpts}
+ tmps, resp, err := b.client.Issues.ListByRepo(ctx, b.owner, b.repo, opts)
+ if err != nil {
+ return resp, err
+ }
+ for _, issue := range tmps {
+ b.issues[issue.GetNumber()] = issue
+ }
+ return resp, nil
+ })
+ if err != nil {
+ return err
+ }
+
+ fmt.Printf("Loaded %d issues from github.com/%s/%s\n", len(b.issues), b.owner, b.repo)
+ return nil
+}
+
+// Activate implements reviver.Bugger.
+func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) {
+ const prefix = "gvisor.dev/issue/"
+
+ // First check if I can handle the TODO.
+ idStr := strings.TrimPrefix(todo.Issue, prefix)
+ if len(todo.Issue) == len(idStr) {
+ return false, nil
+ }
+
+ id, err := strconv.Atoi(idStr)
+ if err != nil {
+ return true, err
+ }
+
+ // Check against active issues cache.
+ if _, ok := b.issues[id]; ok {
+ fmt.Printf("%q is active: OK\n", todo.Issue)
+ return true, nil
+ }
+
+ fmt.Printf("%q is not active: reopening issue %d\n", todo.Issue, id)
+
+ // Format comment with TODO locations and search link.
+ comment := strings.Builder{}
+ fmt.Fprintln(&comment, "There are TODOs still referencing this issue:")
+ for _, l := range todo.Locations {
+ fmt.Fprintf(&comment,
+ "1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#%d): %s\n",
+ l.File, l.Line, b.owner, b.repo, l.File, l.Line, l.Comment)
+ }
+ fmt.Fprintf(&comment,
+ "\n\nSearch [TODO](https://github.com/%s/%s/search?q=%%22%s%d%%22)", b.owner, b.repo, prefix, id)
+
+ if b.dryRun {
+ fmt.Printf("[dry-run: skipping change to issue %d]\n%s\n=======================\n", id, comment.String())
+ return true, nil
+ }
+
+ ctx := context.Background()
+ req := &github.IssueRequest{State: github.String("open")}
+ _, _, err = b.client.Issues.Edit(ctx, b.owner, b.repo, id, req)
+ if err != nil {
+ return true, fmt.Errorf("failed to reactivate issue %d: %v", id, err)
+ }
+
+ cmt := &github.IssueComment{
+ Body: github.String(comment.String()),
+ Reactions: &github.Reactions{Confused: github.Int(1)},
+ }
+ if _, _, err := b.client.Issues.CreateComment(ctx, b.owner, b.repo, id, cmt); err != nil {
+ return true, fmt.Errorf("failed to add comment to issue %d: %v", id, err)
+ }
+
+ return true, nil
+}
+
+func processAllPages(fn func(github.ListOptions) (*github.Response, error)) error {
+ opts := github.ListOptions{PerPage: 1000}
+ for {
+ resp, err := fn(opts)
+ if err != nil {
+ if rateErr, ok := err.(*github.RateLimitError); ok {
+ duration := rateErr.Rate.Reset.Sub(time.Now())
+ if duration > 5*time.Minute {
+ return fmt.Errorf("Rate limited for too long: %v", duration)
+ }
+ fmt.Printf("Rate limited, sleeping for: %v\n", duration)
+ time.Sleep(duration)
+ continue
+ }
+ return err
+ }
+ if resp.NextPage == 0 {
+ return nil
+ }
+ opts.Page = resp.NextPage
+ }
+}
diff --git a/tools/issue_reviver/main.go b/tools/issue_reviver/main.go
new file mode 100644
index 000000000..4256f5a6c
--- /dev/null
+++ b/tools/issue_reviver/main.go
@@ -0,0 +1,89 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package main is the entry point for issue_reviver.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "os"
+
+ "gvisor.dev/gvisor/tools/issue_reviver/github"
+ "gvisor.dev/gvisor/tools/issue_reviver/reviver"
+)
+
+var (
+ owner string
+ repo string
+ tokenFile string
+ path string
+ dryRun bool
+)
+
+// Keep the options simple for now. Supports only a single path and repo.
+func init() {
+ flag.StringVar(&owner, "owner", "google", "Github project org/owner to look for issues")
+ flag.StringVar(&repo, "repo", "gvisor", "Github repo to look for issues")
+ flag.StringVar(&tokenFile, "oauth-token-file", "", "Path to file containing the OAUTH token to be used as credential to github")
+ flag.StringVar(&path, "path", "", "Path to scan for TODOs")
+ flag.BoolVar(&dryRun, "dry-run", false, "If set to true, no changes are made to issues")
+}
+
+func main() {
+ flag.Parse()
+
+ // Check for mandatory parameters.
+ if len(owner) == 0 {
+ fmt.Println("missing --owner option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+ if len(repo) == 0 {
+ fmt.Println("missing --repo option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+ if len(path) == 0 {
+ fmt.Println("missing --path option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ // Token is passed as a file so it doesn't show up in command line arguments.
+ var token string
+ if len(tokenFile) != 0 {
+ bytes, err := ioutil.ReadFile(tokenFile)
+ if err != nil {
+ fmt.Println(err.Error())
+ os.Exit(1)
+ }
+ token = string(bytes)
+ }
+
+ bugger, err := github.NewBugger(token, owner, repo, dryRun)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Error getting github issues:", err)
+ os.Exit(1)
+ }
+ rev := reviver.New([]string{path}, []reviver.Bugger{bugger})
+ if errs := rev.Run(); len(errs) > 0 {
+ fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs))
+ for _, err := range errs {
+ fmt.Fprintf(os.Stderr, "\t%v\n", err)
+ }
+ os.Exit(1)
+ }
+}
diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD
new file mode 100644
index 000000000..2c3675977
--- /dev/null
+++ b/tools/issue_reviver/reviver/BUILD
@@ -0,0 +1,19 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "reviver",
+ srcs = ["reviver.go"],
+ importpath = "gvisor.dev/gvisor/tools/issue_reviver/reviver",
+ visibility = [
+ "//tools/issue_reviver:__subpackages__",
+ ],
+)
+
+go_test(
+ name = "reviver_test",
+ size = "small",
+ srcs = ["reviver_test.go"],
+ embed = [":reviver"],
+)
diff --git a/tools/issue_reviver/reviver/reviver.go b/tools/issue_reviver/reviver/reviver.go
new file mode 100644
index 000000000..682db0c01
--- /dev/null
+++ b/tools/issue_reviver/reviver/reviver.go
@@ -0,0 +1,192 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package reviver scans the code looking for TODOs and pass them to registered
+// Buggers to ensure TODOs point to active issues.
+package reviver
+
+import (
+ "bufio"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "regexp"
+ "sync"
+)
+
+// This is how a TODO looks like.
+var regexTodo = regexp.MustCompile(`(\/\/|#)\s*(TODO|FIXME)\(([a-zA-Z0-9.\/]+)\):\s*(.+)`)
+
+// Bugger interface is called for every TODO found in the code. If it can handle
+// the TODO, it must return true. If it returns false, the next Bugger is
+// called. If no Bugger handles the TODO, it's dropped on the floor.
+type Bugger interface {
+ Activate(todo *Todo) (bool, error)
+}
+
+// Location saves the location where the TODO was found.
+type Location struct {
+ Comment string
+ File string
+ Line uint
+}
+
+// Todo represents a unique TODO. There can be several TODOs pointing to the
+// same issue in the code. They are all grouped together.
+type Todo struct {
+ Issue string
+ Locations []Location
+}
+
+// Reviver scans the given paths for TODOs and calls Buggers to handle them.
+type Reviver struct {
+ paths []string
+ buggers []Bugger
+
+ mu sync.Mutex
+ todos map[string]*Todo
+ errs []error
+}
+
+// New create a new Reviver.
+func New(paths []string, buggers []Bugger) *Reviver {
+ return &Reviver{
+ paths: paths,
+ buggers: buggers,
+ todos: map[string]*Todo{},
+ }
+}
+
+// Run runs. It returns all errors found during processing, it doesn't stop
+// on errors.
+func (r *Reviver) Run() []error {
+ // Process each directory in parallel.
+ wg := sync.WaitGroup{}
+ for _, path := range r.paths {
+ wg.Add(1)
+ go func(path string) {
+ defer wg.Done()
+ r.processPath(path, &wg)
+ }(path)
+ }
+
+ wg.Wait()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ fmt.Printf("Processing %d TODOs (%d errors)...\n", len(r.todos), len(r.errs))
+ dropped := 0
+ for _, todo := range r.todos {
+ ok, err := r.processTodo(todo)
+ if err != nil {
+ r.errs = append(r.errs, err)
+ }
+ if !ok {
+ dropped++
+ }
+ }
+ fmt.Printf("Processed %d TODOs, %d were skipped (%d errors)\n", len(r.todos)-dropped, dropped, len(r.errs))
+
+ return r.errs
+}
+
+func (r *Reviver) processPath(path string, wg *sync.WaitGroup) {
+ fmt.Printf("Processing dir %q\n", path)
+ fis, err := ioutil.ReadDir(path)
+ if err != nil {
+ r.addErr(fmt.Errorf("error processing dir %q: %v", path, err))
+ return
+ }
+
+ for _, fi := range fis {
+ childPath := filepath.Join(path, fi.Name())
+ switch {
+ case fi.Mode().IsDir():
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r.processPath(childPath, wg)
+ }()
+
+ case fi.Mode().IsRegular():
+ file, err := os.Open(childPath)
+ if err != nil {
+ r.addErr(err)
+ continue
+ }
+
+ scanner := bufio.NewScanner(file)
+ lineno := uint(0)
+ for scanner.Scan() {
+ lineno++
+ line := scanner.Text()
+ if todo := r.processLine(line, childPath, lineno); todo != nil {
+ r.addTodo(todo)
+ }
+ }
+ }
+ }
+}
+
+func (r *Reviver) processLine(line, path string, lineno uint) *Todo {
+ matches := regexTodo.FindStringSubmatch(line)
+ if matches == nil {
+ return nil
+ }
+ if len(matches) != 5 {
+ panic(fmt.Sprintf("regex returned wrong matches for %q: %v", line, matches))
+ }
+ return &Todo{
+ Issue: matches[3],
+ Locations: []Location{
+ {
+ File: path,
+ Line: lineno,
+ Comment: matches[4],
+ },
+ },
+ }
+}
+
+func (r *Reviver) addTodo(newTodo *Todo) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if todo := r.todos[newTodo.Issue]; todo == nil {
+ r.todos[newTodo.Issue] = newTodo
+ } else {
+ todo.Locations = append(todo.Locations, newTodo.Locations...)
+ }
+}
+
+func (r *Reviver) addErr(err error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.errs = append(r.errs, err)
+}
+
+func (r *Reviver) processTodo(todo *Todo) (bool, error) {
+ for _, bugger := range r.buggers {
+ ok, err := bugger.Activate(todo)
+ if err != nil {
+ return false, err
+ }
+ if ok {
+ return true, nil
+ }
+ }
+ return false, nil
+}
diff --git a/tools/issue_reviver/reviver/reviver_test.go b/tools/issue_reviver/reviver/reviver_test.go
new file mode 100644
index 000000000..a9fb1f9f1
--- /dev/null
+++ b/tools/issue_reviver/reviver/reviver_test.go
@@ -0,0 +1,88 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reviver
+
+import (
+ "testing"
+)
+
+func TestProcessLine(t *testing.T) {
+ for _, tc := range []struct {
+ line string
+ want *Todo
+ }{
+ {
+ line: "// TODO(foobar.com/issue/123): comment, bla. blabla.",
+ want: &Todo{
+ Issue: "foobar.com/issue/123",
+ Locations: []Location{
+ {Comment: "comment, bla. blabla."},
+ },
+ },
+ },
+ {
+ line: "// FIXME(b/123): internal bug",
+ want: &Todo{
+ Issue: "b/123",
+ Locations: []Location{
+ {Comment: "internal bug"},
+ },
+ },
+ },
+ {
+ line: "TODO(issue): not todo",
+ },
+ {
+ line: "FIXME(issue): not todo",
+ },
+ {
+ line: "// TODO (issue): not todo",
+ },
+ {
+ line: "// TODO(issue) not todo",
+ },
+ {
+ line: "// todo(issue): not todo",
+ },
+ {
+ line: "// TODO(issue):",
+ },
+ } {
+ t.Logf("Testing: %s", tc.line)
+ r := Reviver{}
+ got := r.processLine(tc.line, "test", 0)
+ if got == nil {
+ if tc.want != nil {
+ t.Errorf("failed to process line, want: %+v", tc.want)
+ }
+ } else {
+ if tc.want == nil {
+ t.Errorf("expected error, got: %+v", got)
+ continue
+ }
+ if got.Issue != tc.want.Issue {
+ t.Errorf("wrong issue, got: %v, want: %v", got.Issue, tc.want.Issue)
+ }
+ if len(got.Locations) != len(tc.want.Locations) {
+ t.Errorf("wrong number of locations, got: %v, want: %v, locations: %+v", len(got.Locations), len(tc.want.Locations), got.Locations)
+ }
+ for i, wantLoc := range tc.want.Locations {
+ if got.Locations[i].Comment != wantLoc.Comment {
+ t.Errorf("wrong comment, got: %v, want: %v", got.Locations[i].Comment, wantLoc.Comment)
+ }
+ }
+ }
+ }
+}