summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/abi.go41
-rw-r--r--pkg/abi/abi_linux.go20
-rwxr-xr-xpkg/abi/abi_state_autogen.go4
-rw-r--r--pkg/abi/flag.go85
-rw-r--r--pkg/abi/linux/aio.go20
-rw-r--r--pkg/abi/linux/ashmem.go29
-rw-r--r--pkg/abi/linux/audit.go23
-rw-r--r--pkg/abi/linux/binder.go20
-rw-r--r--pkg/abi/linux/bpf.go34
-rw-r--r--pkg/abi/linux/capability.go105
-rw-r--r--pkg/abi/linux/dev.go55
-rw-r--r--pkg/abi/linux/elf.go91
-rw-r--r--pkg/abi/linux/errors.go172
-rw-r--r--pkg/abi/linux/eventfd.go22
-rw-r--r--pkg/abi/linux/exec.go18
-rw-r--r--pkg/abi/linux/fcntl.go36
-rw-r--r--pkg/abi/linux/file.go267
-rw-r--r--pkg/abi/linux/fs.go84
-rw-r--r--pkg/abi/linux/futex.go62
-rw-r--r--pkg/abi/linux/inotify.go97
-rw-r--r--pkg/abi/linux/ioctl.go99
-rw-r--r--pkg/abi/linux/ip.go151
-rw-r--r--pkg/abi/linux/ipc.go53
-rw-r--r--pkg/abi/linux/limits.go88
-rw-r--r--pkg/abi/linux/linux.go39
-rwxr-xr-xpkg/abi/linux/linux_state_autogen.go68
-rw-r--r--pkg/abi/linux/mm.go116
-rw-r--r--pkg/abi/linux/netdevice.go86
-rw-r--r--pkg/abi/linux/netfilter.go240
-rw-r--r--pkg/abi/linux/netlink.go124
-rw-r--r--pkg/abi/linux/netlink_route.go191
-rw-r--r--pkg/abi/linux/poll.go42
-rw-r--r--pkg/abi/linux/prctl.go157
-rw-r--r--pkg/abi/linux/ptrace.go89
-rw-r--r--pkg/abi/linux/rusage.go46
-rw-r--r--pkg/abi/linux/sched.go30
-rw-r--r--pkg/abi/linux/seccomp.go65
-rw-r--r--pkg/abi/linux/sem.go52
-rw-r--r--pkg/abi/linux/shm.go86
-rw-r--r--pkg/abi/linux/signal.go232
-rw-r--r--pkg/abi/linux/socket.go385
-rw-r--r--pkg/abi/linux/splice.go23
-rw-r--r--pkg/abi/linux/tcp.go60
-rw-r--r--pkg/abi/linux/time.go228
-rw-r--r--pkg/abi/linux/timer.go23
-rw-r--r--pkg/abi/linux/tty.go344
-rw-r--r--pkg/abi/linux/uio.go18
-rw-r--r--pkg/abi/linux/utsname.go49
-rw-r--r--pkg/abi/linux/wait.go36
-rw-r--r--pkg/amutex/amutex.go120
-rwxr-xr-xpkg/amutex/amutex_state_autogen.go4
-rw-r--r--pkg/atomicbitops/atomic_bitops.go59
-rw-r--r--pkg/atomicbitops/atomic_bitops_amd64.s115
-rw-r--r--pkg/atomicbitops/atomic_bitops_common.go147
-rwxr-xr-xpkg/atomicbitops/atomicbitops_state_autogen.go4
-rw-r--r--pkg/binary/binary.go256
-rwxr-xr-xpkg/binary/binary_state_autogen.go4
-rw-r--r--pkg/bits/bits.go16
-rwxr-xr-xpkg/bits/bits32.go25
-rwxr-xr-xpkg/bits/bits64.go25
-rwxr-xr-xpkg/bits/bits_state_autogen.go4
-rw-r--r--pkg/bits/uint64_arch_amd64.go36
-rw-r--r--pkg/bits/uint64_arch_amd64_asm.s31
-rw-r--r--pkg/bits/uint64_arch_generic.go55
-rw-r--r--pkg/bpf/bpf.go129
-rwxr-xr-xpkg/bpf/bpf_state_autogen.go22
-rw-r--r--pkg/bpf/decoder.go245
-rw-r--r--pkg/bpf/input_bytes.go58
-rw-r--r--pkg/bpf/interpreter.go412
-rw-r--r--pkg/bpf/program_builder.go191
-rw-r--r--pkg/compressio/compressio.go743
-rwxr-xr-xpkg/compressio/compressio_state_autogen.go4
-rw-r--r--pkg/control/client/client.go33
-rwxr-xr-xpkg/control/client/client_state_autogen.go4
-rw-r--r--pkg/control/server/server.go160
-rwxr-xr-xpkg/control/server/server_state_autogen.go4
-rw-r--r--pkg/cpuid/cpu_amd64.s24
-rw-r--r--pkg/cpuid/cpuid.go941
-rwxr-xr-xpkg/cpuid/cpuid_state_autogen.go36
-rw-r--r--pkg/eventchannel/event.go165
-rwxr-xr-xpkg/eventchannel/eventchannel_go_proto/event.pb.go85
-rwxr-xr-xpkg/eventchannel/eventchannel_state_autogen.go4
-rw-r--r--pkg/fd/fd.go234
-rwxr-xr-xpkg/fd/fd_state_autogen.go4
-rw-r--r--pkg/fdnotifier/fdnotifier.go202
-rwxr-xr-xpkg/fdnotifier/fdnotifier_state_autogen.go4
-rw-r--r--pkg/fdnotifier/poll_unsafe.go76
-rw-r--r--pkg/gate/gate.go134
-rwxr-xr-xpkg/gate/gate_state_autogen.go4
-rwxr-xr-xpkg/ilist/ilist_state_autogen.go38
-rwxr-xr-xpkg/ilist/interface_list.go192
-rw-r--r--pkg/linewriter/linewriter.go78
-rwxr-xr-xpkg/linewriter/linewriter_state_autogen.go4
-rw-r--r--pkg/log/glog.go163
-rw-r--r--pkg/log/glog_unsafe.go32
-rw-r--r--pkg/log/json.go76
-rw-r--r--pkg/log/json_k8s.go47
-rw-r--r--pkg/log/log.go323
-rwxr-xr-xpkg/log/log_state_autogen.go4
-rw-r--r--pkg/metric/metric.go250
-rwxr-xr-xpkg/metric/metric_go_proto/metric.pb.go297
-rwxr-xr-xpkg/metric/metric_state_autogen.go4
-rw-r--r--pkg/p9/buffer.go263
-rw-r--r--pkg/p9/client.go307
-rw-r--r--pkg/p9/client_file.go632
-rw-r--r--pkg/p9/file.go256
-rw-r--r--pkg/p9/handlers.go1291
-rw-r--r--pkg/p9/messages.go2359
-rw-r--r--pkg/p9/p9.go1141
-rwxr-xr-xpkg/p9/p9_state_autogen.go4
-rw-r--r--pkg/p9/path_tree.go109
-rw-r--r--pkg/p9/pool.go68
-rw-r--r--pkg/p9/server.go575
-rw-r--r--pkg/p9/transport.go342
-rw-r--r--pkg/p9/version.go150
-rw-r--r--pkg/rand/rand.go29
-rw-r--r--pkg/rand/rand_linux.go62
-rwxr-xr-xpkg/rand/rand_state_autogen.go4
-rw-r--r--pkg/refs/refcounter.go303
-rw-r--r--pkg/refs/refcounter_state.go35
-rwxr-xr-xpkg/refs/refs_state_autogen.go77
-rwxr-xr-xpkg/refs/weak_ref_list.go173
-rw-r--r--pkg/seccomp/seccomp.go375
-rw-r--r--pkg/seccomp/seccomp_amd64.go26
-rw-r--r--pkg/seccomp/seccomp_arm64.go26
-rw-r--r--pkg/seccomp/seccomp_rules.go132
-rwxr-xr-xpkg/seccomp/seccomp_state_autogen.go4
-rw-r--r--pkg/seccomp/seccomp_unsafe.go70
-rw-r--r--pkg/secio/full_reader.go34
-rw-r--r--pkg/secio/secio.go105
-rwxr-xr-xpkg/secio/secio_state_autogen.go4
-rw-r--r--pkg/sentry/arch/aligned.go31
-rw-r--r--pkg/sentry/arch/arch.go359
-rw-r--r--pkg/sentry/arch/arch_amd64.go325
-rw-r--r--pkg/sentry/arch/arch_amd64.s135
-rwxr-xr-xpkg/sentry/arch/arch_state_autogen.go193
-rw-r--r--pkg/sentry/arch/arch_state_x86.go131
-rw-r--r--pkg/sentry/arch/arch_x86.go621
-rw-r--r--pkg/sentry/arch/auxv.go30
-rwxr-xr-xpkg/sentry/arch/registers_go_proto/registers.pb.go367
-rw-r--r--pkg/sentry/arch/signal_act.go79
-rw-r--r--pkg/sentry/arch/signal_amd64.go521
-rw-r--r--pkg/sentry/arch/signal_info.go66
-rw-r--r--pkg/sentry/arch/signal_stack.go65
-rw-r--r--pkg/sentry/arch/stack.go252
-rw-r--r--pkg/sentry/arch/syscalls_amd64.go52
-rw-r--r--pkg/sentry/context/context.go126
-rwxr-xr-xpkg/sentry/context/context_state_autogen.go4
-rw-r--r--pkg/sentry/control/control.go17
-rwxr-xr-xpkg/sentry/control/control_state_autogen.go4
-rw-r--r--pkg/sentry/control/pprof.go168
-rw-r--r--pkg/sentry/control/proc.go390
-rw-r--r--pkg/sentry/control/state.go73
-rw-r--r--pkg/sentry/device/device.go266
-rwxr-xr-xpkg/sentry/device/device_state_autogen.go52
-rw-r--r--pkg/sentry/fs/anon/anon.go42
-rwxr-xr-xpkg/sentry/fs/anon/anon_state_autogen.go4
-rw-r--r--pkg/sentry/fs/anon/device.go22
-rw-r--r--pkg/sentry/fs/ashmem/area.go308
-rwxr-xr-xpkg/sentry/fs/ashmem/ashmem_state_autogen.go123
-rw-r--r--pkg/sentry/fs/ashmem/device.go61
-rw-r--r--pkg/sentry/fs/ashmem/pin_board.go127
-rwxr-xr-xpkg/sentry/fs/ashmem/uint64_range.go62
-rwxr-xr-xpkg/sentry/fs/ashmem/uint64_set.go1270
-rw-r--r--pkg/sentry/fs/attr.go422
-rw-r--r--pkg/sentry/fs/binder/binder.go260
-rwxr-xr-xpkg/sentry/fs/binder/binder_state_autogen.go40
-rw-r--r--pkg/sentry/fs/context.go114
-rw-r--r--pkg/sentry/fs/copy_up.go433
-rw-r--r--pkg/sentry/fs/dentry.go234
-rw-r--r--pkg/sentry/fs/dev/dev.go146
-rwxr-xr-xpkg/sentry/fs/dev/dev_state_autogen.go108
-rw-r--r--pkg/sentry/fs/dev/device.go20
-rw-r--r--pkg/sentry/fs/dev/fs.go99
-rw-r--r--pkg/sentry/fs/dev/full.go81
-rw-r--r--pkg/sentry/fs/dev/null.go130
-rw-r--r--pkg/sentry/fs/dev/random.go79
-rw-r--r--pkg/sentry/fs/dirent.go1675
-rw-r--r--pkg/sentry/fs/dirent_cache.go175
-rw-r--r--pkg/sentry/fs/dirent_cache_limiter.go55
-rwxr-xr-xpkg/sentry/fs/dirent_list.go173
-rw-r--r--pkg/sentry/fs/dirent_state.go77
-rwxr-xr-xpkg/sentry/fs/event_list.go173
-rwxr-xr-xpkg/sentry/fs/fdpipe/fdpipe_state_autogen.go27
-rw-r--r--pkg/sentry/fs/fdpipe/pipe.go168
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_opener.go193
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_state.go89
-rw-r--r--pkg/sentry/fs/file.go556
-rw-r--r--pkg/sentry/fs/file_operations.go159
-rw-r--r--pkg/sentry/fs/file_overlay.go505
-rw-r--r--pkg/sentry/fs/file_state.go31
-rw-r--r--pkg/sentry/fs/filesystems.go174
-rw-r--r--pkg/sentry/fs/flags.go121
-rw-r--r--pkg/sentry/fs/fs.go161
-rwxr-xr-xpkg/sentry/fs/fs_state_autogen.go626
-rw-r--r--pkg/sentry/fs/fsutil/dirty_set.go237
-rwxr-xr-xpkg/sentry/fs/fsutil/dirty_set_impl.go1274
-rw-r--r--pkg/sentry/fs/fsutil/file.go394
-rw-r--r--pkg/sentry/fs/fsutil/file_range_set.go209
-rwxr-xr-xpkg/sentry/fs/fsutil/file_range_set_impl.go1274
-rw-r--r--pkg/sentry/fs/fsutil/frame_ref_set.go50
-rwxr-xr-xpkg/sentry/fs/fsutil/frame_ref_set_impl.go1274
-rw-r--r--pkg/sentry/fs/fsutil/fsutil.go24
-rwxr-xr-xpkg/sentry/fs/fsutil/fsutil_state_autogen.go349
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go211
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper_state.go20
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go27
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go197
-rw-r--r--pkg/sentry/fs/fsutil/inode.go503
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go1004
-rw-r--r--pkg/sentry/fs/gofer/attr.go162
-rw-r--r--pkg/sentry/fs/gofer/cache_policy.go183
-rw-r--r--pkg/sentry/fs/gofer/context_file.go190
-rw-r--r--pkg/sentry/fs/gofer/device.go20
-rw-r--r--pkg/sentry/fs/gofer/file.go333
-rw-r--r--pkg/sentry/fs/gofer/file_state.go39
-rw-r--r--pkg/sentry/fs/gofer/fs.go247
-rwxr-xr-xpkg/sentry/fs/gofer/gofer_state_autogen.go113
-rw-r--r--pkg/sentry/fs/gofer/handles.go129
-rw-r--r--pkg/sentry/fs/gofer/inode.go606
-rw-r--r--pkg/sentry/fs/gofer/inode_state.go172
-rw-r--r--pkg/sentry/fs/gofer/path.go433
-rw-r--r--pkg/sentry/fs/gofer/session.go361
-rw-r--r--pkg/sentry/fs/gofer/session_state.go115
-rw-r--r--pkg/sentry/fs/gofer/socket.go141
-rw-r--r--pkg/sentry/fs/gofer/util.go60
-rw-r--r--pkg/sentry/fs/host/control.go93
-rw-r--r--pkg/sentry/fs/host/descriptor.go120
-rw-r--r--pkg/sentry/fs/host/descriptor_state.go29
-rw-r--r--pkg/sentry/fs/host/device.go25
-rw-r--r--pkg/sentry/fs/host/file.go286
-rw-r--r--pkg/sentry/fs/host/fs.go339
-rwxr-xr-xpkg/sentry/fs/host/host_state_autogen.go142
-rw-r--r--pkg/sentry/fs/host/inode.go527
-rw-r--r--pkg/sentry/fs/host/inode_state.go79
-rw-r--r--pkg/sentry/fs/host/ioctl_unsafe.go56
-rw-r--r--pkg/sentry/fs/host/socket.go390
-rw-r--r--pkg/sentry/fs/host/socket_iovec.go113
-rw-r--r--pkg/sentry/fs/host/socket_state.go42
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go100
-rw-r--r--pkg/sentry/fs/host/tty.go351
-rw-r--r--pkg/sentry/fs/host/util.go197
-rw-r--r--pkg/sentry/fs/host/util_unsafe.go137
-rw-r--r--pkg/sentry/fs/inode.go440
-rw-r--r--pkg/sentry/fs/inode_inotify.go169
-rw-r--r--pkg/sentry/fs/inode_operations.go308
-rw-r--r--pkg/sentry/fs/inode_overlay.go676
-rw-r--r--pkg/sentry/fs/inotify.go348
-rw-r--r--pkg/sentry/fs/inotify_event.go139
-rw-r--r--pkg/sentry/fs/inotify_watch.go135
-rw-r--r--pkg/sentry/fs/lock/lock.go461
-rwxr-xr-xpkg/sentry/fs/lock/lock_range.go62
-rwxr-xr-xpkg/sentry/fs/lock/lock_set.go1270
-rw-r--r--pkg/sentry/fs/lock/lock_set_functions.go69
-rwxr-xr-xpkg/sentry/fs/lock/lock_state_autogen.go106
-rw-r--r--pkg/sentry/fs/mock.go170
-rw-r--r--pkg/sentry/fs/mount.go267
-rw-r--r--pkg/sentry/fs/mount_overlay.go136
-rw-r--r--pkg/sentry/fs/mounts.go675
-rw-r--r--pkg/sentry/fs/offset.go65
-rw-r--r--pkg/sentry/fs/overlay.go303
-rw-r--r--pkg/sentry/fs/path.go119
-rw-r--r--pkg/sentry/fs/proc/cgroup.go41
-rw-r--r--pkg/sentry/fs/proc/cpuinfo.go35
-rw-r--r--pkg/sentry/fs/proc/device/device.go23
-rwxr-xr-xpkg/sentry/fs/proc/device/device_state_autogen.go4
-rw-r--r--pkg/sentry/fs/proc/exec_args.go203
-rw-r--r--pkg/sentry/fs/proc/fds.go285
-rw-r--r--pkg/sentry/fs/proc/filesystems.go61
-rw-r--r--pkg/sentry/fs/proc/fs.go81
-rw-r--r--pkg/sentry/fs/proc/inode.go97
-rw-r--r--pkg/sentry/fs/proc/loadavg.go55
-rw-r--r--pkg/sentry/fs/proc/meminfo.go85
-rw-r--r--pkg/sentry/fs/proc/mounts.go197
-rw-r--r--pkg/sentry/fs/proc/net.go308
-rw-r--r--pkg/sentry/fs/proc/proc.go251
-rwxr-xr-xpkg/sentry/fs/proc/proc_state_autogen.go657
-rw-r--r--pkg/sentry/fs/proc/rpcinet_proc.go217
-rw-r--r--pkg/sentry/fs/proc/seqfile/seqfile.go282
-rwxr-xr-xpkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go58
-rw-r--r--pkg/sentry/fs/proc/stat.go142
-rw-r--r--pkg/sentry/fs/proc/sys.go162
-rw-r--r--pkg/sentry/fs/proc/sys_net.go355
-rw-r--r--pkg/sentry/fs/proc/sys_net_state.go42
-rw-r--r--pkg/sentry/fs/proc/task.go776
-rw-r--r--pkg/sentry/fs/proc/uid_gid_map.go179
-rw-r--r--pkg/sentry/fs/proc/uptime.go87
-rw-r--r--pkg/sentry/fs/proc/version.go78
-rw-r--r--pkg/sentry/fs/ramfs/dir.go534
-rwxr-xr-xpkg/sentry/fs/ramfs/ramfs_state_autogen.go94
-rw-r--r--pkg/sentry/fs/ramfs/socket.go85
-rw-r--r--pkg/sentry/fs/ramfs/symlink.go106
-rw-r--r--pkg/sentry/fs/ramfs/tree.go77
-rw-r--r--pkg/sentry/fs/restore.go78
-rw-r--r--pkg/sentry/fs/save.go77
-rw-r--r--pkg/sentry/fs/seek.go43
-rw-r--r--pkg/sentry/fs/splice.go187
-rw-r--r--pkg/sentry/fs/sync.go43
-rw-r--r--pkg/sentry/fs/sys/device.go20
-rw-r--r--pkg/sentry/fs/sys/devices.go91
-rw-r--r--pkg/sentry/fs/sys/fs.go65
-rw-r--r--pkg/sentry/fs/sys/sys.go64
-rwxr-xr-xpkg/sentry/fs/sys/sys_state_autogen.go34
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go148
-rwxr-xr-xpkg/sentry/fs/timerfd/timerfd_state_autogen.go25
-rw-r--r--pkg/sentry/fs/tmpfs/device.go20
-rw-r--r--pkg/sentry/fs/tmpfs/file_regular.go60
-rw-r--r--pkg/sentry/fs/tmpfs/fs.go136
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go681
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go348
-rwxr-xr-xpkg/sentry/fs/tmpfs/tmpfs_state_autogen.go108
-rw-r--r--pkg/sentry/fs/tty/dir.go339
-rw-r--r--pkg/sentry/fs/tty/fs.go104
-rw-r--r--pkg/sentry/fs/tty/line_discipline.go443
-rw-r--r--pkg/sentry/fs/tty/master.go220
-rw-r--r--pkg/sentry/fs/tty/queue.go244
-rw-r--r--pkg/sentry/fs/tty/slave.go162
-rw-r--r--pkg/sentry/fs/tty/terminal.go46
-rwxr-xr-xpkg/sentry/fs/tty/tty_state_autogen.go202
-rw-r--r--pkg/sentry/hostcpu/getcpu_amd64.s24
-rw-r--r--pkg/sentry/hostcpu/hostcpu.go67
-rwxr-xr-xpkg/sentry/hostcpu/hostcpu_state_autogen.go4
-rw-r--r--pkg/sentry/inet/context.go35
-rw-r--r--pkg/sentry/inet/inet.go104
-rwxr-xr-xpkg/sentry/inet/inet_state_autogen.go26
-rw-r--r--pkg/sentry/inet/test_stack.go83
-rw-r--r--pkg/sentry/kernel/abstract_socket_namespace.go111
-rw-r--r--pkg/sentry/kernel/auth/auth.go22
-rwxr-xr-xpkg/sentry/kernel/auth/auth_state_autogen.go151
-rw-r--r--pkg/sentry/kernel/auth/capability_set.go61
-rw-r--r--pkg/sentry/kernel/auth/context.go36
-rw-r--r--pkg/sentry/kernel/auth/credentials.go234
-rw-r--r--pkg/sentry/kernel/auth/id.go121
-rw-r--r--pkg/sentry/kernel/auth/id_map.go285
-rw-r--r--pkg/sentry/kernel/auth/id_map_functions.go45
-rwxr-xr-xpkg/sentry/kernel/auth/id_map_range.go62
-rwxr-xr-xpkg/sentry/kernel/auth/id_map_set.go1270
-rw-r--r--pkg/sentry/kernel/auth/user_namespace.go129
-rw-r--r--pkg/sentry/kernel/context.go135
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go473
-rwxr-xr-xpkg/sentry/kernel/epoll/epoll_list.go173
-rw-r--r--pkg/sentry/kernel/epoll/epoll_state.go49
-rwxr-xr-xpkg/sentry/kernel/epoll/epoll_state_autogen.go99
-rw-r--r--pkg/sentry/kernel/eventfd/eventfd.go283
-rwxr-xr-xpkg/sentry/kernel/eventfd/eventfd_state_autogen.go27
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go148
-rwxr-xr-xpkg/sentry/kernel/fasync/fasync_state_autogen.go30
-rw-r--r--pkg/sentry/kernel/fd_map.go364
-rw-r--r--pkg/sentry/kernel/fs_context.go187
-rwxr-xr-xpkg/sentry/kernel/futex/atomicptr_bucket.go27
-rw-r--r--pkg/sentry/kernel/futex/futex.go783
-rwxr-xr-xpkg/sentry/kernel/futex/futex_state_autogen.go62
-rwxr-xr-xpkg/sentry/kernel/futex/waiter_list.go173
-rw-r--r--pkg/sentry/kernel/ipc_namespace.go58
-rw-r--r--pkg/sentry/kernel/kdefs/kdefs.go20
-rwxr-xr-xpkg/sentry/kernel/kdefs/kdefs_state_autogen.go4
-rw-r--r--pkg/sentry/kernel/kernel.go1241
-rw-r--r--pkg/sentry/kernel/kernel_state.go42
-rwxr-xr-xpkg/sentry/kernel/kernel_state_autogen.go1147
-rw-r--r--pkg/sentry/kernel/pending_signals.go142
-rwxr-xr-xpkg/sentry/kernel/pending_signals_list.go173
-rw-r--r--pkg/sentry/kernel/pending_signals_state.go46
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go90
-rwxr-xr-xpkg/sentry/kernel/pipe/buffer_list.go173
-rw-r--r--pkg/sentry/kernel/pipe/device.go20
-rw-r--r--pkg/sentry/kernel/pipe/node.go196
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go429
-rwxr-xr-xpkg/sentry/kernel/pipe/pipe_state_autogen.go134
-rw-r--r--pkg/sentry/kernel/pipe/reader.go42
-rw-r--r--pkg/sentry/kernel/pipe/reader_writer.go96
-rw-r--r--pkg/sentry/kernel/pipe/writer.go42
-rw-r--r--pkg/sentry/kernel/posixtimer.go306
-rwxr-xr-xpkg/sentry/kernel/process_group_list.go173
-rw-r--r--pkg/sentry/kernel/ptrace.go1105
-rw-r--r--pkg/sentry/kernel/ptrace_amd64.go89
-rw-r--r--pkg/sentry/kernel/ptrace_arm64.go28
-rw-r--r--pkg/sentry/kernel/rseq.go120
-rw-r--r--pkg/sentry/kernel/sched/cpuset.go105
-rw-r--r--pkg/sentry/kernel/sched/sched.go16
-rwxr-xr-xpkg/sentry/kernel/sched/sched_state_autogen.go4
-rw-r--r--pkg/sentry/kernel/seccomp.go217
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go571
-rwxr-xr-xpkg/sentry/kernel/semaphore/semaphore_state_autogen.go115
-rwxr-xr-xpkg/sentry/kernel/semaphore/waiter_list.go173
-rwxr-xr-xpkg/sentry/kernel/seqatomic_taskgoroutineschedinfo.go55
-rwxr-xr-xpkg/sentry/kernel/session_list.go173
-rw-r--r--pkg/sentry/kernel/sessions.go508
-rw-r--r--pkg/sentry/kernel/shm/device.go20
-rw-r--r--pkg/sentry/kernel/shm/shm.go671
-rwxr-xr-xpkg/sentry/kernel/shm/shm_state_autogen.go74
-rw-r--r--pkg/sentry/kernel/signal.go76
-rw-r--r--pkg/sentry/kernel/signal_handlers.go89
-rw-r--r--pkg/sentry/kernel/syscalls.go307
-rw-r--r--pkg/sentry/kernel/syscalls_state.go29
-rw-r--r--pkg/sentry/kernel/syslog.go106
-rw-r--r--pkg/sentry/kernel/task.go723
-rw-r--r--pkg/sentry/kernel/task_acct.go196
-rw-r--r--pkg/sentry/kernel/task_block.go212
-rw-r--r--pkg/sentry/kernel/task_clone.go516
-rw-r--r--pkg/sentry/kernel/task_context.go174
-rw-r--r--pkg/sentry/kernel/task_exec.go262
-rw-r--r--pkg/sentry/kernel/task_exit.go1159
-rw-r--r--pkg/sentry/kernel/task_futex.go54
-rw-r--r--pkg/sentry/kernel/task_identity.go568
-rwxr-xr-xpkg/sentry/kernel/task_list.go173
-rw-r--r--pkg/sentry/kernel/task_log.go137
-rw-r--r--pkg/sentry/kernel/task_net.go35
-rw-r--r--pkg/sentry/kernel/task_run.go340
-rw-r--r--pkg/sentry/kernel/task_sched.go637
-rw-r--r--pkg/sentry/kernel/task_signals.go1110
-rw-r--r--pkg/sentry/kernel/task_start.go287
-rw-r--r--pkg/sentry/kernel/task_stop.go226
-rw-r--r--pkg/sentry/kernel/task_syscall.go447
-rw-r--r--pkg/sentry/kernel/task_usermem.go301
-rw-r--r--pkg/sentry/kernel/thread_group.go330
-rw-r--r--pkg/sentry/kernel/threads.go465
-rw-r--r--pkg/sentry/kernel/time/context.go44
-rw-r--r--pkg/sentry/kernel/time/time.go691
-rwxr-xr-xpkg/sentry/kernel/time/time_state_autogen.go56
-rw-r--r--pkg/sentry/kernel/timekeeper.go306
-rw-r--r--pkg/sentry/kernel/timekeeper_state.go41
-rwxr-xr-xpkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go119
-rw-r--r--pkg/sentry/kernel/uts_namespace.go102
-rw-r--r--pkg/sentry/kernel/vdso.go148
-rw-r--r--pkg/sentry/kernel/version.go33
-rw-r--r--pkg/sentry/limits/context.go35
-rw-r--r--pkg/sentry/limits/limits.go136
-rwxr-xr-xpkg/sentry/limits/limits_state_autogen.go36
-rw-r--r--pkg/sentry/limits/linux.go100
-rw-r--r--pkg/sentry/loader/elf.go669
-rw-r--r--pkg/sentry/loader/interpreter.go108
-rw-r--r--pkg/sentry/loader/loader.go283
-rwxr-xr-xpkg/sentry/loader/loader_state_autogen.go57
-rw-r--r--pkg/sentry/loader/vdso.go402
-rwxr-xr-xpkg/sentry/loader/vdso_bin.go5
-rw-r--r--pkg/sentry/loader/vdso_state.go48
-rwxr-xr-xpkg/sentry/memmap/mappable_range.go62
-rw-r--r--pkg/sentry/memmap/mapping_set.go253
-rwxr-xr-xpkg/sentry/memmap/mapping_set_impl.go1270
-rw-r--r--pkg/sentry/memmap/memmap.go361
-rwxr-xr-xpkg/sentry/memmap/memmap_state_autogen.go93
-rw-r--r--pkg/sentry/memutil/memutil.go16
-rwxr-xr-xpkg/sentry/memutil/memutil_state_autogen.go4
-rw-r--r--pkg/sentry/memutil/memutil_unsafe.go39
-rw-r--r--pkg/sentry/mm/address_space.go216
-rw-r--r--pkg/sentry/mm/aio_context.go387
-rw-r--r--pkg/sentry/mm/aio_context_state.go20
-rw-r--r--pkg/sentry/mm/debug.go98
-rwxr-xr-xpkg/sentry/mm/file_refcount_set.go1274
-rw-r--r--pkg/sentry/mm/io.go639
-rwxr-xr-xpkg/sentry/mm/io_list.go173
-rw-r--r--pkg/sentry/mm/lifecycle.go234
-rw-r--r--pkg/sentry/mm/metadata.go139
-rw-r--r--pkg/sentry/mm/mm.go456
-rwxr-xr-xpkg/sentry/mm/mm_state_autogen.go380
-rw-r--r--pkg/sentry/mm/pma.go1036
-rwxr-xr-xpkg/sentry/mm/pma_set.go1274
-rw-r--r--pkg/sentry/mm/procfs.go289
-rw-r--r--pkg/sentry/mm/save_restore.go57
-rw-r--r--pkg/sentry/mm/shm.go66
-rw-r--r--pkg/sentry/mm/special_mappable.go155
-rw-r--r--pkg/sentry/mm/syscalls.go1197
-rw-r--r--pkg/sentry/mm/vma.go564
-rwxr-xr-xpkg/sentry/mm/vma_set.go1274
-rw-r--r--pkg/sentry/pgalloc/context.go48
-rwxr-xr-xpkg/sentry/pgalloc/evictable_range.go62
-rwxr-xr-xpkg/sentry/pgalloc/evictable_range_set.go1270
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go1187
-rwxr-xr-xpkg/sentry/pgalloc/pgalloc_state_autogen.go146
-rw-r--r--pkg/sentry/pgalloc/pgalloc_unsafe.go40
-rw-r--r--pkg/sentry/pgalloc/save_restore.go210
-rwxr-xr-xpkg/sentry/pgalloc/usage_set.go1274
-rw-r--r--pkg/sentry/platform/context.go36
-rwxr-xr-xpkg/sentry/platform/file_range.go62
-rw-r--r--pkg/sentry/platform/interrupt/interrupt.go96
-rwxr-xr-xpkg/sentry/platform/interrupt/interrupt_state_autogen.go4
-rw-r--r--pkg/sentry/platform/kvm/address_space.go234
-rw-r--r--pkg/sentry/platform/kvm/allocator.go76
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go82
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.go141
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.s93
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go56
-rw-r--r--pkg/sentry/platform/kvm/bluepill_fault.go127
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go213
-rw-r--r--pkg/sentry/platform/kvm/context.go87
-rw-r--r--pkg/sentry/platform/kvm/kvm.go143
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64.go213
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_unsafe.go77
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go64
-rwxr-xr-xpkg/sentry/platform/kvm/kvm_state_autogen.go4
-rw-r--r--pkg/sentry/platform/kvm/machine.go525
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go357
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go161
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go160
-rw-r--r--pkg/sentry/platform/kvm/physical_map.go224
-rw-r--r--pkg/sentry/platform/kvm/virtual_map.go113
-rw-r--r--pkg/sentry/platform/mmap_min_addr.go60
-rw-r--r--pkg/sentry/platform/platform.go349
-rwxr-xr-xpkg/sentry/platform/platform_state_autogen.go24
-rw-r--r--pkg/sentry/platform/procid/procid.go21
-rw-r--r--pkg/sentry/platform/procid/procid_amd64.s30
-rw-r--r--pkg/sentry/platform/procid/procid_arm64.s29
-rwxr-xr-xpkg/sentry/platform/procid/procid_state_autogen.go4
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go238
-rwxr-xr-xpkg/sentry/platform/ptrace/ptrace_state_autogen.go4
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go166
-rw-r--r--pkg/sentry/platform/ptrace/stub_amd64.s114
-rw-r--r--pkg/sentry/platform/ptrace/stub_unsafe.go98
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go610
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go104
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go338
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go109
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_unsafe.go33
-rwxr-xr-xpkg/sentry/platform/ring0/defs_impl.go538
-rw-r--r--pkg/sentry/platform/ring0/entry_amd64.go128
-rwxr-xr-xpkg/sentry/platform/ring0/entry_impl_amd64.s383
-rw-r--r--pkg/sentry/platform/ring0/kernel.go66
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go271
-rw-r--r--pkg/sentry/platform/ring0/kernel_unsafe.go41
-rw-r--r--pkg/sentry/platform/ring0/lib_amd64.go131
-rw-r--r--pkg/sentry/platform/ring0/lib_amd64.s247
-rw-r--r--pkg/sentry/platform/ring0/pagetables/allocator.go122
-rw-r--r--pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go53
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables.go221
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go45
-rwxr-xr-xpkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go4
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_x86.go180
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids_x86.go109
-rwxr-xr-xpkg/sentry/platform/ring0/pagetables/walker_empty.go255
-rwxr-xr-xpkg/sentry/platform/ring0/pagetables/walker_lookup.go255
-rwxr-xr-xpkg/sentry/platform/ring0/pagetables/walker_map.go255
-rwxr-xr-xpkg/sentry/platform/ring0/pagetables/walker_unmap.go255
-rw-r--r--pkg/sentry/platform/ring0/ring0.go16
-rwxr-xr-xpkg/sentry/platform/ring0/ring0_state_autogen.go4
-rw-r--r--pkg/sentry/platform/safecopy/atomic_amd64.s136
-rw-r--r--pkg/sentry/platform/safecopy/atomic_arm64.s126
-rw-r--r--pkg/sentry/platform/safecopy/memclr_amd64.s147
-rw-r--r--pkg/sentry/platform/safecopy/memclr_arm64.s74
-rw-r--r--pkg/sentry/platform/safecopy/memcpy_amd64.s250
-rw-r--r--pkg/sentry/platform/safecopy/memcpy_arm64.s78
-rw-r--r--pkg/sentry/platform/safecopy/safecopy.go144
-rwxr-xr-xpkg/sentry/platform/safecopy/safecopy_state_autogen.go4
-rw-r--r--pkg/sentry/platform/safecopy/safecopy_unsafe.go335
-rw-r--r--pkg/sentry/platform/safecopy/sighandler_amd64.s133
-rw-r--r--pkg/sentry/platform/safecopy/sighandler_arm64.s143
-rw-r--r--pkg/sentry/safemem/block_unsafe.go279
-rw-r--r--pkg/sentry/safemem/io.go339
-rw-r--r--pkg/sentry/safemem/safemem.go16
-rwxr-xr-xpkg/sentry/safemem/safemem_state_autogen.go4
-rw-r--r--pkg/sentry/safemem/seq_unsafe.go299
-rw-r--r--pkg/sentry/sighandling/sighandling.go140
-rwxr-xr-xpkg/sentry/sighandling/sighandling_state_autogen.go4
-rw-r--r--pkg/sentry/sighandling/sighandling_unsafe.go74
-rw-r--r--pkg/sentry/socket/control/control.go421
-rwxr-xr-xpkg/sentry/socket/control/control_state_autogen.go36
-rw-r--r--pkg/sentry/socket/epsocket/device.go20
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go2283
-rwxr-xr-xpkg/sentry/socket/epsocket/epsocket_state_autogen.go52
-rw-r--r--pkg/sentry/socket/epsocket/provider.go140
-rw-r--r--pkg/sentry/socket/epsocket/save_restore.go27
-rw-r--r--pkg/sentry/socket/epsocket/stack.go140
-rw-r--r--pkg/sentry/socket/hostinet/device.go19
-rw-r--r--pkg/sentry/socket/hostinet/hostinet.go17
-rwxr-xr-xpkg/sentry/socket/hostinet/hostinet_state_autogen.go4
-rw-r--r--pkg/sentry/socket/hostinet/save_restore.go20
-rw-r--r--pkg/sentry/socket/hostinet/socket.go578
-rw-r--r--pkg/sentry/socket/hostinet/socket_unsafe.go138
-rw-r--r--pkg/sentry/socket/hostinet/stack.go246
-rw-r--r--pkg/sentry/socket/netlink/message.go159
-rwxr-xr-xpkg/sentry/socket/netlink/netlink_state_autogen.go36
-rw-r--r--pkg/sentry/socket/netlink/port/port.go116
-rwxr-xr-xpkg/sentry/socket/netlink/port/port_state_autogen.go22
-rw-r--r--pkg/sentry/socket/netlink/provider.go105
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go197
-rwxr-xr-xpkg/sentry/socket/netlink/route/route_state_autogen.go20
-rw-r--r--pkg/sentry/socket/netlink/socket.go618
-rw-r--r--pkg/sentry/socket/rpcinet/conn/conn.go187
-rwxr-xr-xpkg/sentry/socket/rpcinet/conn/conn_state_autogen.go4
-rw-r--r--pkg/sentry/socket/rpcinet/device.go19
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/notifier.go230
-rwxr-xr-xpkg/sentry/socket/rpcinet/notifier/notifier_state_autogen.go4
-rw-r--r--pkg/sentry/socket/rpcinet/rpcinet.go16
-rwxr-xr-xpkg/sentry/socket/rpcinet/rpcinet_state_autogen.go4
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go887
-rw-r--r--pkg/sentry/socket/rpcinet/stack.go135
-rw-r--r--pkg/sentry/socket/rpcinet/stack_unsafe.go193
-rwxr-xr-xpkg/sentry/socket/rpcinet/syscall_rpc_go_proto/syscall_rpc.pb.go3938
-rw-r--r--pkg/sentry/socket/socket.go336
-rwxr-xr-xpkg/sentry/socket/socket_state_autogen.go24
-rw-r--r--pkg/sentry/socket/unix/device.go20
-rw-r--r--pkg/sentry/socket/unix/io.go93
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go460
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned_state.go53
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go196
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go210
-rwxr-xr-xpkg/sentry/socket/unix/transport/transport_message_list.go173
-rwxr-xr-xpkg/sentry/socket/unix/transport/transport_state_autogen.go191
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go973
-rw-r--r--pkg/sentry/socket/unix/unix.go650
-rwxr-xr-xpkg/sentry/socket/unix/unix_state_autogen.go28
-rw-r--r--pkg/sentry/state/state.go118
-rw-r--r--pkg/sentry/state/state_metadata.go45
-rwxr-xr-xpkg/sentry/state/state_state_autogen.go4
-rw-r--r--pkg/sentry/state/state_unsafe.go34
-rw-r--r--pkg/sentry/strace/capability.go176
-rw-r--r--pkg/sentry/strace/clone.go113
-rw-r--r--pkg/sentry/strace/futex.go52
-rw-r--r--pkg/sentry/strace/linux64.go338
-rw-r--r--pkg/sentry/strace/open.go96
-rw-r--r--pkg/sentry/strace/poll.go72
-rw-r--r--pkg/sentry/strace/ptrace.go62
-rw-r--r--pkg/sentry/strace/signal.go148
-rw-r--r--pkg/sentry/strace/socket.go412
-rw-r--r--pkg/sentry/strace/strace.go820
-rwxr-xr-xpkg/sentry/strace/strace_go_proto/strace.pb.go247
-rwxr-xr-xpkg/sentry/strace/strace_state_autogen.go4
-rw-r--r--pkg/sentry/strace/syscalls.go267
-rw-r--r--pkg/sentry/syscalls/epoll.go174
-rw-r--r--pkg/sentry/syscalls/linux/error.go114
-rw-r--r--pkg/sentry/syscalls/linux/flags.go53
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go487
-rwxr-xr-xpkg/sentry/syscalls/linux/linux_state_autogen.go80
-rw-r--r--pkg/sentry/syscalls/linux/sigset.go69
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go416
-rw-r--r--pkg/sentry/syscalls/linux/sys_capability.go149
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go171
-rw-r--r--pkg/sentry/syscalls/linux/sys_eventfd.go65
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go2088
-rw-r--r--pkg/sentry/syscalls/linux/sys_futex.go278
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go269
-rw-r--r--pkg/sentry/syscalls/linux/sys_identity.go180
-rw-r--r--pkg/sentry/syscalls/linux/sys_inotify.go135
-rw-r--r--pkg/sentry/syscalls/linux/sys_lseek.go55
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go470
-rw-r--r--pkg/sentry/syscalls/linux/sys_mount.go146
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go79
-rw-r--r--pkg/sentry/syscalls/linux/sys_poll.go549
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go201
-rw-r--r--pkg/sentry/syscalls/linux/sys_random.go92
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go357
-rw-r--r--pkg/sentry/syscalls/linux/sys_rlimit.go224
-rw-r--r--pkg/sentry/syscalls/linux/sys_rusage.go112
-rw-r--r--pkg/sentry/syscalls/linux/sys_sched.go100
-rw-r--r--pkg/sentry/syscalls/linux/sys_seccomp.go77
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go241
-rw-r--r--pkg/sentry/syscalls/linux/sys_shm.go156
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go508
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go1117
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go293
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go259
-rw-r--r--pkg/sentry/syscalls/linux/sys_sync.go138
-rw-r--r--pkg/sentry/syscalls/linux/sys_sysinfo.go43
-rw-r--r--pkg/sentry/syscalls/linux/sys_syslog.go61
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go706
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go340
-rw-r--r--pkg/sentry/syscalls/linux/sys_timer.go203
-rw-r--r--pkg/sentry/syscalls/linux/sys_timerfd.go122
-rw-r--r--pkg/sentry/syscalls/linux/sys_tls.go53
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go89
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go361
-rw-r--r--pkg/sentry/syscalls/linux/timespec.go112
-rw-r--r--pkg/sentry/syscalls/syscalls.go61
-rwxr-xr-xpkg/sentry/syscalls/syscalls_state_autogen.go4
-rw-r--r--pkg/sentry/time/arith_arm64.go70
-rw-r--r--pkg/sentry/time/calibrated_clock.go269
-rw-r--r--pkg/sentry/time/clock_id.go40
-rw-r--r--pkg/sentry/time/clocks.go31
-rw-r--r--pkg/sentry/time/muldiv_amd64.s44
-rw-r--r--pkg/sentry/time/muldiv_arm64.s44
-rw-r--r--pkg/sentry/time/parameters.go239
-rw-r--r--pkg/sentry/time/sampler.go225
-rw-r--r--pkg/sentry/time/sampler_unsafe.go56
-rwxr-xr-xpkg/sentry/time/seqatomic_parameters.go55
-rwxr-xr-xpkg/sentry/time/time_state_autogen.go4
-rw-r--r--pkg/sentry/time/tsc_amd64.s27
-rw-r--r--pkg/sentry/time/tsc_arm64.s22
-rw-r--r--pkg/sentry/unimpl/events.go45
-rwxr-xr-xpkg/sentry/unimpl/unimpl_state_autogen.go4
-rwxr-xr-xpkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go91
-rw-r--r--pkg/sentry/uniqueid/context.go54
-rwxr-xr-xpkg/sentry/uniqueid/uniqueid_state_autogen.go4
-rw-r--r--pkg/sentry/usage/cpu.go46
-rw-r--r--pkg/sentry/usage/io.go90
-rw-r--r--pkg/sentry/usage/memory.go284
-rw-r--r--pkg/sentry/usage/memory_unsafe.go27
-rw-r--r--pkg/sentry/usage/usage.go16
-rwxr-xr-xpkg/sentry/usage/usage_state_autogen.go50
-rw-r--r--pkg/sentry/usermem/access_type.go128
-rw-r--r--pkg/sentry/usermem/addr.go108
-rwxr-xr-xpkg/sentry/usermem/addr_range.go62
-rw-r--r--pkg/sentry/usermem/addr_range_seq_unsafe.go277
-rw-r--r--pkg/sentry/usermem/bytes_io.go126
-rw-r--r--pkg/sentry/usermem/bytes_io_unsafe.go47
-rw-r--r--pkg/sentry/usermem/usermem.go587
-rw-r--r--pkg/sentry/usermem/usermem_arm64.go53
-rwxr-xr-xpkg/sentry/usermem/usermem_state_autogen.go49
-rw-r--r--pkg/sentry/usermem/usermem_unsafe.go27
-rw-r--r--pkg/sentry/usermem/usermem_x86.go38
-rw-r--r--pkg/sentry/watchdog/watchdog.go305
-rwxr-xr-xpkg/sentry/watchdog/watchdog_state_autogen.go4
-rw-r--r--pkg/sleep/commit_amd64.s35
-rw-r--r--pkg/sleep/commit_asm.go20
-rw-r--r--pkg/sleep/commit_noasm.go42
-rwxr-xr-xpkg/sleep/sleep_state_autogen.go4
-rw-r--r--pkg/sleep/sleep_unsafe.go403
-rwxr-xr-xpkg/state/addr_range.go62
-rwxr-xr-xpkg/state/addr_set.go1274
-rw-r--r--pkg/state/decode.go605
-rw-r--r--pkg/state/encode.go466
-rw-r--r--pkg/state/encode_unsafe.go81
-rw-r--r--pkg/state/map.go221
-rwxr-xr-xpkg/state/object_go_proto/object.pb.go1195
-rw-r--r--pkg/state/printer.go251
-rw-r--r--pkg/state/state.go359
-rw-r--r--pkg/state/statefile/statefile.go232
-rwxr-xr-xpkg/state/statefile/statefile_state_autogen.go4
-rw-r--r--pkg/state/stats.go152
-rw-r--r--pkg/syserr/host_linux.go46
-rw-r--r--pkg/syserr/netstack.go102
-rw-r--r--pkg/syserr/syserr.go293
-rwxr-xr-xpkg/syserr/syserr_state_autogen.go4
-rw-r--r--pkg/syserror/syserror.go153
-rwxr-xr-xpkg/syserror/syserror_state_autogen.go4
-rwxr-xr-xpkg/tcpip/buffer/buffer_state_autogen.go24
-rw-r--r--pkg/tcpip/buffer/prependable.go74
-rw-r--r--pkg/tcpip/buffer/view.go158
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins.go80
-rwxr-xr-xpkg/tcpip/hash/jenkins/jenkins_state_autogen.go4
-rw-r--r--pkg/tcpip/header/arp.go100
-rw-r--r--pkg/tcpip/header/checksum.go94
-rw-r--r--pkg/tcpip/header/eth.go74
-rw-r--r--pkg/tcpip/header/gue.go73
-rwxr-xr-xpkg/tcpip/header/header_state_autogen.go42
-rw-r--r--pkg/tcpip/header/icmpv4.go108
-rw-r--r--pkg/tcpip/header/icmpv6.go121
-rw-r--r--pkg/tcpip/header/interfaces.go92
-rw-r--r--pkg/tcpip/header/ipv4.go282
-rw-r--r--pkg/tcpip/header/ipv6.go248
-rw-r--r--pkg/tcpip/header/ipv6_fragment.go146
-rw-r--r--pkg/tcpip/header/tcp.go543
-rw-r--r--pkg/tcpip/header/udp.go110
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go372
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_unsafe.go32
-rwxr-xr-xpkg/tcpip/link/fdbased/fdbased_state_autogen.go4
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go25
-rw-r--r--pkg/tcpip/link/fdbased/mmap_amd64.go194
-rw-r--r--pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go84
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go309
-rw-r--r--pkg/tcpip/link/loopback/loopback.go87
-rwxr-xr-xpkg/tcpip/link/loopback/loopback_state_autogen.go4
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64.s40
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go60
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_unsafe.go29
-rw-r--r--pkg/tcpip/link/rawfile/errors.go70
-rwxr-xr-xpkg/tcpip/link/rawfile/rawfile_state_autogen.go4
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go182
-rw-r--r--pkg/tcpip/link/sniffer/pcap.go66
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go408
-rwxr-xr-xpkg/tcpip/link/sniffer/sniffer_state_autogen.go4
-rw-r--r--pkg/tcpip/network/arp/arp.go203
-rwxr-xr-xpkg/tcpip/network/arp/arp_state_autogen.go4
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go77
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go134
-rwxr-xr-xpkg/tcpip/network/fragmentation/fragmentation_state_autogen.go38
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go118
-rwxr-xr-xpkg/tcpip/network/fragmentation/reassembler_list.go173
-rw-r--r--pkg/tcpip/network/hash/hash.go93
-rwxr-xr-xpkg/tcpip/network/hash/hash_state_autogen.go4
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go160
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go344
-rwxr-xr-xpkg/tcpip/network/ipv4/ipv4_state_autogen.go4
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go297
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go207
-rwxr-xr-xpkg/tcpip/network/ipv6/ipv6_state_autogen.go4
-rw-r--r--pkg/tcpip/ports/ports.go209
-rwxr-xr-xpkg/tcpip/ports/ports_state_autogen.go4
-rw-r--r--pkg/tcpip/seqnum/seqnum.go67
-rwxr-xr-xpkg/tcpip/seqnum/seqnum_state_autogen.go4
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go306
-rw-r--r--pkg/tcpip/stack/nic.go728
-rw-r--r--pkg/tcpip/stack/registration.go441
-rw-r--r--pkg/tcpip/stack/route.go189
-rw-r--r--pkg/tcpip/stack/stack.go1095
-rw-r--r--pkg/tcpip/stack/stack_global_state.go19
-rwxr-xr-xpkg/tcpip/stack/stack_state_autogen.go59
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go420
-rw-r--r--pkg/tcpip/tcpip.go1055
-rwxr-xr-xpkg/tcpip/tcpip_state_autogen.go40
-rw-r--r--pkg/tcpip/time_unsafe.go45
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go710
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go90
-rwxr-xr-xpkg/tcpip/transport/icmp/icmp_packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/icmp/icmp_state_autogen.go98
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go136
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go521
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go88
-rwxr-xr-xpkg/tcpip/transport/raw/packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/raw/raw_state_autogen.go96
-rw-r--r--pkg/tcpip/transport/tcp/accept.go499
-rw-r--r--pkg/tcpip/transport/tcp/connect.go1066
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go233
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go1741
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go362
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go171
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go250
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go221
-rw-r--r--pkg/tcpip/transport/tcp/reno.go103
-rw-r--r--pkg/tcpip/transport/tcp/sack.go99
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard.go306
-rw-r--r--pkg/tcpip/transport/tcp/segment.go186
-rw-r--r--pkg/tcpip/transport/tcp/segment_heap.go46
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go79
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go82
-rw-r--r--pkg/tcpip/transport/tcp/snd.go1180
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go50
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_segment_list.go173
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_state_autogen.go400
-rw-r--r--pkg/tcpip/transport/tcp/timer.go141
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go1002
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go112
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go96
-rw-r--r--pkg/tcpip/transport/udp/protocol.go90
-rwxr-xr-xpkg/tcpip/transport/udp/udp_packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/udp/udp_state_autogen.go128
-rw-r--r--pkg/tmutex/tmutex.go81
-rwxr-xr-xpkg/tmutex/tmutex_state_autogen.go4
-rw-r--r--pkg/unet/unet.go569
-rwxr-xr-xpkg/unet/unet_state_autogen.go4
-rw-r--r--pkg/unet/unet_unsafe.go289
-rw-r--r--pkg/urpc/urpc.go636
-rwxr-xr-xpkg/urpc/urpc_state_autogen.go4
-rw-r--r--pkg/waiter/waiter.go250
-rwxr-xr-xpkg/waiter/waiter_list.go173
-rwxr-xr-xpkg/waiter/waiter_state_autogen.go67
834 files changed, 180701 insertions, 0 deletions
diff --git a/pkg/abi/abi.go b/pkg/abi/abi.go
new file mode 100644
index 000000000..d56c481c9
--- /dev/null
+++ b/pkg/abi/abi.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package abi describes the interface between a kernel and userspace.
+package abi
+
+import (
+ "fmt"
+)
+
+// OS describes the target operating system for an ABI.
+//
+// Note that OS is architecture-independent. The details of the OS ABI will
+// vary between architectures.
+type OS int
+
+const (
+ // Linux is the Linux ABI.
+ Linux OS = iota
+)
+
+// String implements fmt.Stringer.
+func (o OS) String() string {
+ switch o {
+ case Linux:
+ return "linux"
+ default:
+ return fmt.Sprintf("OS(%d)", o)
+ }
+}
diff --git a/pkg/abi/abi_linux.go b/pkg/abi/abi_linux.go
new file mode 100644
index 000000000..3059479bd
--- /dev/null
+++ b/pkg/abi/abi_linux.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package abi
+
+// Host specifies the host ABI.
+const Host = Linux
diff --git a/pkg/abi/abi_state_autogen.go b/pkg/abi/abi_state_autogen.go
new file mode 100755
index 000000000..4f94570e5
--- /dev/null
+++ b/pkg/abi/abi_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package abi
+
diff --git a/pkg/abi/flag.go b/pkg/abi/flag.go
new file mode 100644
index 000000000..dcdd66d4e
--- /dev/null
+++ b/pkg/abi/flag.go
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package abi
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+)
+
+// A FlagSet is a slice of bit-flags and their name.
+type FlagSet []struct {
+ Flag uint64
+ Name string
+}
+
+// Parse returns a pretty version of val, using the flag names for known flags.
+// Unknown flags remain numeric.
+func (s FlagSet) Parse(val uint64) string {
+ var flags []string
+
+ for _, f := range s {
+ if val&f.Flag == f.Flag {
+ flags = append(flags, f.Name)
+ val &^= f.Flag
+ }
+ }
+
+ if val != 0 {
+ flags = append(flags, "0x"+strconv.FormatUint(val, 16))
+ }
+
+ if len(flags) == 0 {
+ // Prefer 0 to an empty string.
+ return "0x0"
+ }
+
+ return strings.Join(flags, "|")
+}
+
+// ValueSet is a map of syscall values to their name. Parse will use the name
+// or the value if unknown.
+type ValueSet map[uint64]string
+
+// Parse returns the name of the value associated with `val`. Unknown values
+// are converted to hex.
+func (s ValueSet) Parse(val uint64) string {
+ if v, ok := s[val]; ok {
+ return v
+ }
+ return fmt.Sprintf("%#x", val)
+}
+
+// ParseDecimal returns the name of the value associated with `val`. Unknown
+// values are converted to decimal.
+func (s ValueSet) ParseDecimal(val uint64) string {
+ if v, ok := s[val]; ok {
+ return v
+ }
+ return fmt.Sprintf("%d", val)
+}
+
+// ParseName returns the flag value associated with 'name'. Returns false
+// if no value is found.
+func (s ValueSet) ParseName(name string) (uint64, bool) {
+ for k, v := range s {
+ if v == name {
+ return k, true
+ }
+ }
+ return math.MaxUint64, false
+}
diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go
new file mode 100644
index 000000000..3c6e0079d
--- /dev/null
+++ b/pkg/abi/linux/aio.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ // AIORingSize is sizeof(struct aio_ring).
+ AIORingSize = 32
+)
diff --git a/pkg/abi/linux/ashmem.go b/pkg/abi/linux/ashmem.go
new file mode 100644
index 000000000..2a722abe0
--- /dev/null
+++ b/pkg/abi/linux/ashmem.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Constants used by ashmem in pin-related ioctls.
+const (
+ AshmemNotPurged = 0
+ AshmemWasPurged = 1
+ AshmemIsUnpinned = 0
+ AshmemIsPinned = 1
+)
+
+// AshmemPin structure is used for pin-related ioctls.
+type AshmemPin struct {
+ Offset uint32
+ Len uint32
+}
diff --git a/pkg/abi/linux/audit.go b/pkg/abi/linux/audit.go
new file mode 100644
index 000000000..6cca69af9
--- /dev/null
+++ b/pkg/abi/linux/audit.go
@@ -0,0 +1,23 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Audit numbers identify different system call APIs, from <uapi/linux/audit.h>
+const (
+ // AUDIT_ARCH_X86_64 identifies AMD64.
+ AUDIT_ARCH_X86_64 = 0xc000003e
+ // AUDIT_ARCH_AARCH64 identifies ARM64.
+ AUDIT_ARCH_AARCH64 = 0xc00000b7
+)
diff --git a/pkg/abi/linux/binder.go b/pkg/abi/linux/binder.go
new file mode 100644
index 000000000..63b08324a
--- /dev/null
+++ b/pkg/abi/linux/binder.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// BinderVersion structure is used for BINDER_VERSION ioctl.
+type BinderVersion struct {
+ ProtocolVersion int32
+}
diff --git a/pkg/abi/linux/bpf.go b/pkg/abi/linux/bpf.go
new file mode 100644
index 000000000..aa3d3ce70
--- /dev/null
+++ b/pkg/abi/linux/bpf.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// BPFInstruction is a raw BPF virtual machine instruction.
+//
+// +stateify savable
+type BPFInstruction struct {
+ // OpCode is the operation to execute.
+ OpCode uint16
+
+ // JumpIfTrue is the number of instructions to skip if OpCode is a
+ // conditional instruction and the condition is true.
+ JumpIfTrue uint8
+
+ // JumpIfFalse is the number of instructions to skip if OpCode is a
+ // conditional instruction and the condition is false.
+ JumpIfFalse uint8
+
+ // K is a constant parameter. The meaning depends on the value of OpCode.
+ K uint32
+}
diff --git a/pkg/abi/linux/capability.go b/pkg/abi/linux/capability.go
new file mode 100644
index 000000000..c120cac64
--- /dev/null
+++ b/pkg/abi/linux/capability.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// A Capability represents the ability to perform a privileged operation.
+type Capability int
+
+// Capabilities defined by Linux. Taken from the kernel's
+// include/uapi/linux/capability.h. See capabilities(7) or that file for more
+// detailed capability descriptions.
+const (
+ CAP_CHOWN = Capability(0)
+ CAP_DAC_OVERRIDE = Capability(1)
+ CAP_DAC_READ_SEARCH = Capability(2)
+ CAP_FOWNER = Capability(3)
+ CAP_FSETID = Capability(4)
+ CAP_KILL = Capability(5)
+ CAP_SETGID = Capability(6)
+ CAP_SETUID = Capability(7)
+ CAP_SETPCAP = Capability(8)
+ CAP_LINUX_IMMUTABLE = Capability(9)
+ CAP_NET_BIND_SERVICE = Capability(10)
+ CAP_NET_BROADCAST = Capability(11)
+ CAP_NET_ADMIN = Capability(12)
+ CAP_NET_RAW = Capability(13)
+ CAP_IPC_LOCK = Capability(14)
+ CAP_IPC_OWNER = Capability(15)
+ CAP_SYS_MODULE = Capability(16)
+ CAP_SYS_RAWIO = Capability(17)
+ CAP_SYS_CHROOT = Capability(18)
+ CAP_SYS_PTRACE = Capability(19)
+ CAP_SYS_PACCT = Capability(20)
+ CAP_SYS_ADMIN = Capability(21)
+ CAP_SYS_BOOT = Capability(22)
+ CAP_SYS_NICE = Capability(23)
+ CAP_SYS_RESOURCE = Capability(24)
+ CAP_SYS_TIME = Capability(25)
+ CAP_SYS_TTY_CONFIG = Capability(26)
+ CAP_MKNOD = Capability(27)
+ CAP_LEASE = Capability(28)
+ CAP_AUDIT_WRITE = Capability(29)
+ CAP_AUDIT_CONTROL = Capability(30)
+ CAP_SETFCAP = Capability(31)
+ CAP_MAC_OVERRIDE = Capability(32)
+ CAP_MAC_ADMIN = Capability(33)
+ CAP_SYSLOG = Capability(34)
+ CAP_WAKE_ALARM = Capability(35)
+ CAP_BLOCK_SUSPEND = Capability(36)
+ CAP_AUDIT_READ = Capability(37)
+
+ // MaxCapability is the highest-numbered capability.
+ MaxCapability = CAP_AUDIT_READ
+)
+
+// Ok returns true if cp is a supported capability.
+func (cp Capability) Ok() bool {
+ return cp >= 0 && cp <= MaxCapability
+}
+
+// Version numbers used by the capget/capset syscalls, defined in Linux's
+// include/uapi/linux/capability.h.
+const (
+ // LINUX_CAPABILITY_VERSION_1 causes the data pointer to be
+ // interpreted as a pointer to a single cap_user_data_t. Since capability
+ // sets are 64 bits and the "capability sets" in cap_user_data_t are 32
+ // bits only, this causes the upper 32 bits to be implicitly 0.
+ LINUX_CAPABILITY_VERSION_1 = 0x19980330
+
+ // LINUX_CAPABILITY_VERSION_2 and LINUX_CAPABILITY_VERSION_3 cause the
+ // data pointer to be interpreted as a pointer to an array of 2
+ // cap_user_data_t, using the second to store the 32 MSB of each capability
+ // set. Versions 2 and 3 are identical, but Linux printk's a warning on use
+ // of version 2 due to a userspace API defect.
+ LINUX_CAPABILITY_VERSION_2 = 0x20071026
+ LINUX_CAPABILITY_VERSION_3 = 0x20080522
+
+ // HighestCapabilityVersion is the highest supported
+ // LINUX_CAPABILITY_VERSION_* version.
+ HighestCapabilityVersion = LINUX_CAPABILITY_VERSION_3
+)
+
+// CapUserHeader is equivalent to Linux's cap_user_header_t.
+type CapUserHeader struct {
+ Version uint32
+ Pid int32
+}
+
+// CapUserData is equivalent to Linux's cap_user_data_t.
+type CapUserData struct {
+ Effective uint32
+ Permitted uint32
+ Inheritable uint32
+}
diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go
new file mode 100644
index 000000000..421e11256
--- /dev/null
+++ b/pkg/abi/linux/dev.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// MakeDeviceID encodes a major and minor device number into a single device ID.
+//
+// Format (see linux/kdev_t.h:new_encode_dev):
+//
+// Bits 7:0 - minor bits 7:0
+// Bits 19:8 - major bits 11:0
+// Bits 31:20 - minor bits 19:8
+func MakeDeviceID(major uint16, minor uint32) uint32 {
+ return (minor & 0xff) | ((uint32(major) & 0xfff) << 8) | ((minor >> 8) << 20)
+}
+
+// DecodeDeviceID decodes a device ID into major and minor device numbers.
+func DecodeDeviceID(rdev uint32) (uint16, uint32) {
+ major := uint16((rdev >> 8) & 0xfff)
+ minor := (rdev & 0xff) | ((rdev >> 20) << 8)
+ return major, minor
+}
+
+// Character device IDs.
+//
+// See Documentations/devices.txt and uapi/linux/major.h.
+const (
+ // TTYAUX_MAJOR is the major device number for alternate TTY devices.
+ TTYAUX_MAJOR = 5
+
+ // UNIX98_PTY_MASTER_MAJOR is the initial major device number for
+ // Unix98 PTY masters.
+ UNIX98_PTY_MASTER_MAJOR = 128
+
+ // UNIX98_PTY_SLAVE_MAJOR is the initial major device number for
+ // Unix98 PTY slaves.
+ UNIX98_PTY_SLAVE_MAJOR = 136
+)
+
+// Minor device numbers for TTYAUX_MAJOR.
+const (
+ // PTMX_MINOR is the minor device number for /dev/ptmx.
+ PTMX_MINOR = 2
+)
diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go
new file mode 100644
index 000000000..fb1c679d2
--- /dev/null
+++ b/pkg/abi/linux/elf.go
@@ -0,0 +1,91 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Linux auxiliary vector entry types.
+const (
+ // AT_NULL is the end of the auxiliary vector.
+ AT_NULL = 0
+
+ // AT_IGNORE should be ignored.
+ AT_IGNORE = 1
+
+ // AT_EXECFD is the file descriptor of the program.
+ AT_EXECFD = 2
+
+ // AT_PHDR points to the program headers.
+ AT_PHDR = 3
+
+ // AT_PHENT is the size of a program header entry.
+ AT_PHENT = 4
+
+ // AT_PHNUM is the number of program headers.
+ AT_PHNUM = 5
+
+ // AT_PAGESZ is the system page size.
+ AT_PAGESZ = 6
+
+ // AT_BASE is the base address of the interpreter.
+ AT_BASE = 7
+
+ // AT_FLAGS are flags.
+ AT_FLAGS = 8
+
+ // AT_ENTRY is the program entry point.
+ AT_ENTRY = 9
+
+ // AT_NOTELF indicates that the program is not an ELF binary.
+ AT_NOTELF = 10
+
+ // AT_UID is the real UID.
+ AT_UID = 11
+
+ // AT_EUID is the effective UID.
+ AT_EUID = 12
+
+ // AT_GID is the real GID.
+ AT_GID = 13
+
+ // AT_EGID is the effective GID.
+ AT_EGID = 14
+
+ // AT_PLATFORM is a string identifying the CPU.
+ AT_PLATFORM = 15
+
+ // AT_HWCAP are arch-dependent CPU capabilities.
+ AT_HWCAP = 16
+
+ // AT_CLKTCK is the frequency used by times(2).
+ AT_CLKTCK = 17
+
+ // AT_SECURE indicate secure mode.
+ AT_SECURE = 23
+
+ // AT_BASE_PLATFORM is a string identifying the "real" platform. It may
+ // differ from AT_PLATFORM.
+ AT_BASE_PLATFORM = 24
+
+ // AT_RANDOM points to 16-bytes of random data.
+ AT_RANDOM = 25
+
+ // AT_HWCAP2 is an extension of AT_HWCAP.
+ AT_HWCAP2 = 26
+
+ // AT_EXECFN is the path used to execute the program.
+ AT_EXECFN = 31
+
+ // AT_SYSINFO_EHDR is the address of the VDSO.
+ AT_SYSINFO_EHDR = 33
+)
diff --git a/pkg/abi/linux/errors.go b/pkg/abi/linux/errors.go
new file mode 100644
index 000000000..93f85a864
--- /dev/null
+++ b/pkg/abi/linux/errors.go
@@ -0,0 +1,172 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Errno represents a Linux errno value.
+type Errno struct {
+ number int
+ name string
+}
+
+// Number returns the errno number.
+func (e *Errno) Number() int {
+ return e.number
+}
+
+// String implements fmt.Stringer.String.
+func (e *Errno) String() string {
+ return e.name
+}
+
+// Errno values from include/uapi/asm-generic/errno-base.h.
+var (
+ EPERM = &Errno{1, "operation not permitted"}
+ ENOENT = &Errno{2, "no such file or directory"}
+ ESRCH = &Errno{3, "no such process"}
+ EINTR = &Errno{4, "interrupted system call"}
+ EIO = &Errno{5, "I/O error"}
+ ENXIO = &Errno{6, "no such device or address"}
+ E2BIG = &Errno{7, "argument list too long"}
+ ENOEXEC = &Errno{8, "exec format error"}
+ EBADF = &Errno{9, "bad file number"}
+ ECHILD = &Errno{10, "no child processes"}
+ EAGAIN = &Errno{11, "try again"}
+ ENOMEM = &Errno{12, "out of memory"}
+ EACCES = &Errno{13, "permission denied"}
+ EFAULT = &Errno{14, "bad address"}
+ ENOTBLK = &Errno{15, "block device required"}
+ EBUSY = &Errno{16, "device or resource busy"}
+ EEXIST = &Errno{17, "file exists"}
+ EXDEV = &Errno{18, "cross-device link"}
+ ENODEV = &Errno{19, "no such device"}
+ ENOTDIR = &Errno{20, "not a directory"}
+ EISDIR = &Errno{21, "is a directory"}
+ EINVAL = &Errno{22, "invalid argument"}
+ ENFILE = &Errno{23, "file table overflow"}
+ EMFILE = &Errno{24, "too many open files"}
+ ENOTTY = &Errno{25, "not a typewriter"}
+ ETXTBSY = &Errno{26, "text file busy"}
+ EFBIG = &Errno{27, "file too large"}
+ ENOSPC = &Errno{28, "no space left on device"}
+ ESPIPE = &Errno{29, "illegal seek"}
+ EROFS = &Errno{30, "read-only file system"}
+ EMLINK = &Errno{31, "too many links"}
+ EPIPE = &Errno{32, "broken pipe"}
+ EDOM = &Errno{33, "math argument out of domain of func"}
+ ERANGE = &Errno{34, "math result not representable"}
+)
+
+// Errno values from include/uapi/asm-generic/errno.h.
+var (
+ EDEADLK = &Errno{35, "resource deadlock would occur"}
+ ENAMETOOLONG = &Errno{36, "file name too long"}
+ ENOLCK = &Errno{37, "no record locks available"}
+ ENOSYS = &Errno{38, "invalid system call number"}
+ ENOTEMPTY = &Errno{39, "directory not empty"}
+ ELOOP = &Errno{40, "too many symbolic links encountered"}
+ EWOULDBLOCK = &Errno{EAGAIN.number, "operation would block"}
+ ENOMSG = &Errno{42, "no message of desired type"}
+ EIDRM = &Errno{43, "identifier removed"}
+ ECHRNG = &Errno{44, "channel number out of range"}
+ EL2NSYNC = &Errno{45, "level 2 not synchronized"}
+ EL3HLT = &Errno{46, "level 3 halted"}
+ EL3RST = &Errno{47, "level 3 reset"}
+ ELNRNG = &Errno{48, "link number out of range"}
+ EUNATCH = &Errno{49, "protocol driver not attached"}
+ ENOCSI = &Errno{50, "no CSI structure available"}
+ EL2HLT = &Errno{51, "level 2 halted"}
+ EBADE = &Errno{52, "invalid exchange"}
+ EBADR = &Errno{53, "invalid request descriptor"}
+ EXFULL = &Errno{54, "exchange full"}
+ ENOANO = &Errno{55, "no anode"}
+ EBADRQC = &Errno{56, "invalid request code"}
+ EBADSLT = &Errno{57, "invalid slot"}
+ EDEADLOCK = EDEADLK
+ EBFONT = &Errno{59, "bad font file format"}
+ ENOSTR = &Errno{60, "device not a stream"}
+ ENODATA = &Errno{61, "no data available"}
+ ETIME = &Errno{62, "timer expired"}
+ ENOSR = &Errno{63, "out of streams resources"}
+ ENONET = &Errno{64, "machine is not on the network"}
+ ENOPKG = &Errno{65, "package not installed"}
+ EREMOTE = &Errno{66, "object is remote"}
+ ENOLINK = &Errno{67, "link has been severed"}
+ EADV = &Errno{68, "advertise error"}
+ ESRMNT = &Errno{69, "srmount error"}
+ ECOMM = &Errno{70, "communication error on send"}
+ EPROTO = &Errno{71, "protocol error"}
+ EMULTIHOP = &Errno{72, "multihop attempted"}
+ EDOTDOT = &Errno{73, "RFS specific error"}
+ EBADMSG = &Errno{74, "not a data message"}
+ EOVERFLOW = &Errno{75, "value too large for defined data type"}
+ ENOTUNIQ = &Errno{76, "name not unique on network"}
+ EBADFD = &Errno{77, "file descriptor in bad state"}
+ EREMCHG = &Errno{78, "remote address changed"}
+ ELIBACC = &Errno{79, "can not access a needed shared library"}
+ ELIBBAD = &Errno{80, "accessing a corrupted shared library"}
+ ELIBSCN = &Errno{81, ".lib section in a.out corrupted"}
+ ELIBMAX = &Errno{82, "attempting to link in too many shared libraries"}
+ ELIBEXEC = &Errno{83, "cannot exec a shared library directly"}
+ EILSEQ = &Errno{84, "illegal byte sequence"}
+ ERESTART = &Errno{85, "interrupted system call should be restarted"}
+ ESTRPIPE = &Errno{86, "streams pipe error"}
+ EUSERS = &Errno{87, "too many users"}
+ ENOTSOCK = &Errno{88, "socket operation on non-socket"}
+ EDESTADDRREQ = &Errno{89, "destination address required"}
+ EMSGSIZE = &Errno{90, "message too long"}
+ EPROTOTYPE = &Errno{91, "protocol wrong type for socket"}
+ ENOPROTOOPT = &Errno{92, "protocol not available"}
+ EPROTONOSUPPORT = &Errno{93, "protocol not supported"}
+ ESOCKTNOSUPPORT = &Errno{94, "socket type not supported"}
+ EOPNOTSUPP = &Errno{95, "operation not supported on transport endpoint"}
+ EPFNOSUPPORT = &Errno{96, "protocol family not supported"}
+ EAFNOSUPPORT = &Errno{97, "address family not supported by protocol"}
+ EADDRINUSE = &Errno{98, "address already in use"}
+ EADDRNOTAVAIL = &Errno{99, "cannot assign requested address"}
+ ENETDOWN = &Errno{100, "network is down"}
+ ENETUNREACH = &Errno{101, "network is unreachable"}
+ ENETRESET = &Errno{102, "network dropped connection because of reset"}
+ ECONNABORTED = &Errno{103, "software caused connection abort"}
+ ECONNRESET = &Errno{104, "connection reset by peer"}
+ ENOBUFS = &Errno{105, "no buffer space available"}
+ EISCONN = &Errno{106, "transport endpoint is already connected"}
+ ENOTCONN = &Errno{107, "transport endpoint is not connected"}
+ ESHUTDOWN = &Errno{108, "cannot send after transport endpoint shutdown"}
+ ETOOMANYREFS = &Errno{109, "too many references: cannot splice"}
+ ETIMEDOUT = &Errno{110, "connection timed out"}
+ ECONNREFUSED = &Errno{111, "connection refused"}
+ EHOSTDOWN = &Errno{112, "host is down"}
+ EHOSTUNREACH = &Errno{113, "no route to host"}
+ EALREADY = &Errno{114, "operation already in progress"}
+ EINPROGRESS = &Errno{115, "operation now in progress"}
+ ESTALE = &Errno{116, "stale file handle"}
+ EUCLEAN = &Errno{117, "structure needs cleaning"}
+ ENOTNAM = &Errno{118, "not a XENIX named type file"}
+ ENAVAIL = &Errno{119, "no XENIX semaphores available"}
+ EISNAM = &Errno{120, "is a named type file"}
+ EREMOTEIO = &Errno{121, "remote I/O error"}
+ EDQUOT = &Errno{122, "quota exceeded"}
+ ENOMEDIUM = &Errno{123, "no medium found"}
+ EMEDIUMTYPE = &Errno{124, "wrong medium type"}
+ ECANCELED = &Errno{125, "operation Canceled"}
+ ENOKEY = &Errno{126, "required key not available"}
+ EKEYEXPIRED = &Errno{127, "key has expired"}
+ EKEYREVOKED = &Errno{128, "key has been revoked"}
+ EKEYREJECTED = &Errno{129, "key was rejected by service"}
+ EOWNERDEAD = &Errno{130, "owner died"}
+ ENOTRECOVERABLE = &Errno{131, "state not recoverable"}
+ ERFKILL = &Errno{132, "operation not possible due to RF-kill"}
+ EHWPOISON = &Errno{133, "memory page has hardware error"}
+)
diff --git a/pkg/abi/linux/eventfd.go b/pkg/abi/linux/eventfd.go
new file mode 100644
index 000000000..9c479fc8f
--- /dev/null
+++ b/pkg/abi/linux/eventfd.go
@@ -0,0 +1,22 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Constants for eventfd2(2).
+const (
+ EFD_SEMAPHORE = 0x1
+ EFD_CLOEXEC = O_CLOEXEC
+ EFD_NONBLOCK = O_NONBLOCK
+)
diff --git a/pkg/abi/linux/exec.go b/pkg/abi/linux/exec.go
new file mode 100644
index 000000000..579d46c41
--- /dev/null
+++ b/pkg/abi/linux/exec.go
@@ -0,0 +1,18 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// TASK_COMM_LEN is the task command name length.
+const TASK_COMM_LEN = 16
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
new file mode 100644
index 000000000..b30350193
--- /dev/null
+++ b/pkg/abi/linux/fcntl.go
@@ -0,0 +1,36 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Comands from linux/fcntl.h.
+const (
+ F_DUPFD = 0
+ F_GETFD = 1
+ F_GETFL = 3
+ F_GETOWN = 9
+ F_SETFD = 2
+ F_SETFL = 4
+ F_SETLK = 6
+ F_SETLKW = 7
+ F_SETOWN = 8
+ F_DUPFD_CLOEXEC = 1024 + 6
+ F_SETPIPE_SZ = 1024 + 7
+ F_GETPIPE_SZ = 1024 + 8
+)
+
+// Flags for fcntl.
+const (
+ FD_CLOEXEC = 00000001
+)
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
new file mode 100644
index 000000000..81ff9fe9e
--- /dev/null
+++ b/pkg/abi/linux/file.go
@@ -0,0 +1,267 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+)
+
+// Constants for open(2).
+const (
+ O_ACCMODE = 00000003
+ O_RDONLY = 00000000
+ O_WRONLY = 00000001
+ O_RDWR = 00000002
+ O_CREAT = 00000100
+ O_EXCL = 00000200
+ O_NOCTTY = 00000400
+ O_TRUNC = 00001000
+ O_APPEND = 00002000
+ O_NONBLOCK = 00004000
+ O_ASYNC = 00020000
+ O_DIRECT = 00040000
+ O_LARGEFILE = 00100000
+ O_DIRECTORY = 00200000
+ O_NOFOLLOW = 00400000
+ O_CLOEXEC = 02000000
+ O_SYNC = 04010000
+ O_PATH = 010000000
+)
+
+// Constants for fstatat(2).
+const (
+ AT_SYMLINK_NOFOLLOW = 0x100
+)
+
+// Constants for mount(2).
+const (
+ MS_RDONLY = 0x1
+ MS_NOSUID = 0x2
+ MS_NODEV = 0x4
+ MS_NOEXEC = 0x8
+ MS_SYNCHRONOUS = 0x10
+ MS_REMOUNT = 0x20
+ MS_MANDLOCK = 0x40
+ MS_DIRSYNC = 0x80
+ MS_NOATIME = 0x400
+ MS_NODIRATIME = 0x800
+ MS_BIND = 0x1000
+ MS_MOVE = 0x2000
+ MS_REC = 0x4000
+
+ MS_POSIXACL = 0x10000
+ MS_UNBINDABLE = 0x20000
+ MS_PRIVATE = 0x40000
+ MS_SLAVE = 0x80000
+ MS_SHARED = 0x100000
+ MS_RELATIME = 0x200000
+ MS_KERNMOUNT = 0x400000
+ MS_I_VERSION = 0x800000
+ MS_STRICTATIME = 0x1000000
+
+ MS_MGC_VAL = 0xC0ED0000
+ MS_MGC_MSK = 0xffff0000
+)
+
+// Constants for umount2(2).
+const (
+ MNT_FORCE = 0x1
+ MNT_DETACH = 0x2
+ MNT_EXPIRE = 0x4
+ UMOUNT_NOFOLLOW = 0x8
+)
+
+// Constants for unlinkat(2).
+const (
+ AT_REMOVEDIR = 0x200
+)
+
+// Constants for linkat(2) and fchownat(2).
+const (
+ AT_SYMLINK_FOLLOW = 0x400
+ AT_EMPTY_PATH = 0x1000
+)
+
+// Constants for all file-related ...at(2) syscalls.
+const (
+ AT_FDCWD = -100
+)
+
+// Special values for the ns field in utimensat(2).
+const (
+ UTIME_NOW = ((1 << 30) - 1)
+ UTIME_OMIT = ((1 << 30) - 2)
+)
+
+// MaxSymlinkTraversals is the maximum number of links that will be followed by
+// the kernel to resolve a symlink.
+const MaxSymlinkTraversals = 40
+
+// Constants for flock(2).
+const (
+ LOCK_SH = 1 // shared lock
+ LOCK_EX = 2 // exclusive lock
+ LOCK_NB = 4 // or'd with one of the above to prevent blocking
+ LOCK_UN = 8 // remove lock
+)
+
+// Values for mode_t.
+const (
+ FileTypeMask = 0170000
+ ModeSocket = 0140000
+ ModeSymlink = 0120000
+ ModeRegular = 0100000
+ ModeBlockDevice = 060000
+ ModeDirectory = 040000
+ ModeCharacterDevice = 020000
+ ModeNamedPipe = 010000
+
+ ModeSetUID = 04000
+ ModeSetGID = 02000
+ ModeSticky = 01000
+
+ ModeUserAll = 0700
+ ModeUserRead = 0400
+ ModeUserWrite = 0200
+ ModeUserExec = 0100
+ ModeGroupAll = 0070
+ ModeGroupRead = 0040
+ ModeGroupWrite = 0020
+ ModeGroupExec = 0010
+ ModeOtherAll = 0007
+ ModeOtherRead = 0004
+ ModeOtherWrite = 0002
+ ModeOtherExec = 0001
+ PermissionsMask = 0777
+)
+
+// Values for preadv2/pwritev2.
+const (
+ RWF_HIPRI = 0x00000001
+ RWF_DSYNC = 0x00000002
+ RWF_SYNC = 0x00000004
+ RWF_VALID = RWF_HIPRI | RWF_DSYNC | RWF_SYNC
+)
+
+// Stat represents struct stat.
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Nlink uint64
+ Mode uint32
+ UID uint32
+ GID uint32
+ X_pad0 int32
+ Rdev uint64
+ Size int64
+ Blksize int64
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ X_unused [3]int64
+}
+
+// SizeOfStat is the size of a Stat struct.
+var SizeOfStat = binary.Size(Stat{})
+
+// FileMode represents a mode_t.
+type FileMode uint
+
+// Permissions returns just the permission bits.
+func (m FileMode) Permissions() FileMode {
+ return m & PermissionsMask
+}
+
+// FileType returns just the file type bits.
+func (m FileMode) FileType() FileMode {
+ return m & FileTypeMask
+}
+
+// ExtraBits returns everything but the file type and permission bits.
+func (m FileMode) ExtraBits() FileMode {
+ return m &^ (PermissionsMask | FileTypeMask)
+}
+
+// String returns a string representation of m.
+func (m FileMode) String() string {
+ var s []string
+ if ft := m.FileType(); ft != 0 {
+ s = append(s, fileType.Parse(uint64(ft)))
+ }
+ if eb := m.ExtraBits(); eb != 0 {
+ s = append(s, modeExtraBits.Parse(uint64(eb)))
+ }
+ s = append(s, fmt.Sprintf("0o%o", m.Permissions()))
+ return strings.Join(s, "|")
+}
+
+var modeExtraBits = abi.FlagSet{
+ {
+ Flag: ModeSetUID,
+ Name: "S_ISUID",
+ },
+ {
+ Flag: ModeSetGID,
+ Name: "S_ISGID",
+ },
+ {
+ Flag: ModeSticky,
+ Name: "S_ISVTX",
+ },
+}
+
+var fileType = abi.ValueSet{
+ ModeSocket: "S_IFSOCK",
+ ModeSymlink: "S_IFLINK",
+ ModeRegular: "S_IFREG",
+ ModeBlockDevice: "S_IFBLK",
+ ModeDirectory: "S_IFDIR",
+ ModeCharacterDevice: "S_IFCHR",
+ ModeNamedPipe: "S_IFIFO",
+}
+
+// Constants for memfd_create(2). Source: include/uapi/linux/memfd.h
+const (
+ MFD_CLOEXEC = 0x0001
+ MFD_ALLOW_SEALING = 0x0002
+)
+
+// Constants related to file seals. Source: include/uapi/{asm-generic,linux}/fcntl.h
+const (
+ F_LINUX_SPECIFIC_BASE = 1024
+ F_ADD_SEALS = F_LINUX_SPECIFIC_BASE + 9
+ F_GET_SEALS = F_LINUX_SPECIFIC_BASE + 10
+
+ F_SEAL_SEAL = 0x0001 // Prevent further seals from being set.
+ F_SEAL_SHRINK = 0x0002 // Prevent file from shrinking.
+ F_SEAL_GROW = 0x0004 // Prevent file from growing.
+ F_SEAL_WRITE = 0x0008 // Prevent writes.
+)
+
+// Constants related to fallocate(2). Source: include/uapi/linux/falloc.h
+const (
+ FALLOC_FL_KEEP_SIZE = 0x01
+ FALLOC_FL_PUNCH_HOLE = 0x02
+ FALLOC_FL_NO_HIDE_STALE = 0x04
+ FALLOC_FL_COLLAPSE_RANGE = 0x08
+ FALLOC_FL_ZERO_RANGE = 0x10
+ FALLOC_FL_INSERT_RANGE = 0x20
+ FALLOC_FL_UNSHARE_RANGE = 0x40
+)
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
new file mode 100644
index 000000000..c82ab9b5b
--- /dev/null
+++ b/pkg/abi/linux/fs.go
@@ -0,0 +1,84 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Filesystem types used in statfs(2).
+//
+// See linux/magic.h.
+const (
+ ANON_INODE_FS_MAGIC = 0x09041934
+ DEVPTS_SUPER_MAGIC = 0x00001cd1
+ OVERLAYFS_SUPER_MAGIC = 0x794c7630
+ PIPEFS_MAGIC = 0x50495045
+ PROC_SUPER_MAGIC = 0x9fa0
+ RAMFS_MAGIC = 0x09041934
+ SOCKFS_MAGIC = 0x534F434B
+ SYSFS_MAGIC = 0x62656572
+ TMPFS_MAGIC = 0x01021994
+ V9FS_MAGIC = 0x01021997
+)
+
+// Filesystem path limits, from uapi/linux/limits.h.
+const (
+ NAME_MAX = 255
+ PATH_MAX = 4096
+)
+
+// Statfs is struct statfs, from uapi/asm-generic/statfs.h.
+type Statfs struct {
+ // Type is one of the filesystem magic values, defined above.
+ Type uint64
+
+ // BlockSize is the data block size.
+ BlockSize int64
+
+ // Blocks is the number of data blocks in use.
+ Blocks uint64
+
+ // BlocksFree is the number of free blocks.
+ BlocksFree uint64
+
+ // BlocksAvailable is the number of blocks free for use by
+ // unprivileged users.
+ BlocksAvailable uint64
+
+ // Files is the number of used file nodes on the filesystem.
+ Files uint64
+
+ // FileFress is the number of free file nodes on the filesystem.
+ FilesFree uint64
+
+ // FSID is the filesystem ID.
+ FSID [2]int32
+
+ // NameLength is the maximum file name length.
+ NameLength uint64
+
+ // FragmentSize is equivalent to BlockSize.
+ FragmentSize int64
+
+ // Flags is the set of filesystem mount flags.
+ Flags uint64
+
+ // Spare is unused.
+ Spare [4]uint64
+}
+
+// Sync_file_range flags, from include/uapi/linux/fs.h
+const (
+ SYNC_FILE_RANGE_WAIT_BEFORE = 1
+ SYNC_FILE_RANGE_WRITE = 2
+ SYNC_FILE_RANGE_WAIT_AFTER = 4
+)
diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go
new file mode 100644
index 000000000..08bfde3b5
--- /dev/null
+++ b/pkg/abi/linux/futex.go
@@ -0,0 +1,62 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// From <linux/futex.h> and <sys/time.h>.
+// Flags are used in syscall futex(2).
+const (
+ FUTEX_WAIT = 0
+ FUTEX_WAKE = 1
+ FUTEX_FD = 2
+ FUTEX_REQUEUE = 3
+ FUTEX_CMP_REQUEUE = 4
+ FUTEX_WAKE_OP = 5
+ FUTEX_LOCK_PI = 6
+ FUTEX_UNLOCK_PI = 7
+ FUTEX_TRYLOCK_PI = 8
+ FUTEX_WAIT_BITSET = 9
+ FUTEX_WAKE_BITSET = 10
+ FUTEX_WAIT_REQUEUE_PI = 11
+ FUTEX_CMP_REQUEUE_PI = 12
+
+ FUTEX_PRIVATE_FLAG = 128
+ FUTEX_CLOCK_REALTIME = 256
+)
+
+// These are flags are from <linux/futex.h> and are used in FUTEX_WAKE_OP
+// to define the operations.
+const (
+ FUTEX_OP_SET = 0
+ FUTEX_OP_ADD = 1
+ FUTEX_OP_OR = 2
+ FUTEX_OP_ANDN = 3
+ FUTEX_OP_XOR = 4
+ FUTEX_OP_OPARG_SHIFT = 8
+ FUTEX_OP_CMP_EQ = 0
+ FUTEX_OP_CMP_NE = 1
+ FUTEX_OP_CMP_LT = 2
+ FUTEX_OP_CMP_LE = 3
+ FUTEX_OP_CMP_GT = 4
+ FUTEX_OP_CMP_GE = 5
+)
+
+// FUTEX_TID_MASK is the TID portion of a PI futex word.
+const FUTEX_TID_MASK = 0x3fffffff
+
+// Constants used for priority-inheritance futexes.
+const (
+ FUTEX_WAITERS = 0x80000000
+ FUTEX_OWNER_DIED = 0x40000000
+)
diff --git a/pkg/abi/linux/inotify.go b/pkg/abi/linux/inotify.go
new file mode 100644
index 000000000..2d08194ba
--- /dev/null
+++ b/pkg/abi/linux/inotify.go
@@ -0,0 +1,97 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Inotify events observable by userspace. These directly correspond to
+// filesystem operations and there may only be a single of them per inotify
+// event read from an inotify fd.
+const (
+ // IN_ACCESS indicates a file was accessed.
+ IN_ACCESS = 0x00000001
+ // IN_MODIFY indicates a file was modified.
+ IN_MODIFY = 0x00000002
+ // IN_ATTRIB indicates a watch target's metadata changed.
+ IN_ATTRIB = 0x00000004
+ // IN_CLOSE_WRITE indicates a writable file was closed.
+ IN_CLOSE_WRITE = 0x00000008
+ // IN_CLOSE_NOWRITE indicates a non-writable file was closed.
+ IN_CLOSE_NOWRITE = 0x00000010
+ // IN_OPEN indicates a file was opened.
+ IN_OPEN = 0x00000020
+ // IN_MOVED_FROM indicates a file was moved from X.
+ IN_MOVED_FROM = 0x00000040
+ // IN_MOVED_TO indicates a file was moved to Y.
+ IN_MOVED_TO = 0x00000080
+ // IN_CREATE indicates a file was created in a watched directory.
+ IN_CREATE = 0x00000100
+ // IN_DELETE indicates a file was deleted in a watched directory.
+ IN_DELETE = 0x00000200
+ // IN_DELETE_SELF indicates a watch target itself was deleted.
+ IN_DELETE_SELF = 0x00000400
+ // IN_MOVE_SELF indicates a watch target itself was moved.
+ IN_MOVE_SELF = 0x00000800
+ // IN_ALL_EVENTS is a mask for all observable userspace events.
+ IN_ALL_EVENTS = 0x00000fff
+)
+
+// Inotify control events. These may be present in their own events, or ORed
+// with other observable events.
+const (
+ // IN_UNMOUNT indicates the backing filesystem was unmounted.
+ IN_UNMOUNT = 0x00002000
+ // IN_Q_OVERFLOW indicates the event queued overflowed.
+ IN_Q_OVERFLOW = 0x00004000
+ // IN_IGNORED indicates a watch was removed, either implicitly or through
+ // inotify_rm_watch(2).
+ IN_IGNORED = 0x00008000
+ // IN_ISDIR indicates the subject of an event was a directory.
+ IN_ISDIR = 0x40000000
+)
+
+// Feature flags for inotify_add_watch(2).
+const (
+ // IN_ONLYDIR indicates that a path should be watched only if it's a
+ // directory.
+ IN_ONLYDIR = 0x01000000
+ // IN_DONT_FOLLOW indicates that the watch path shouldn't be resolved if
+ // it's a symlink.
+ IN_DONT_FOLLOW = 0x02000000
+ // IN_EXCL_UNLINK indicates events to this watch from unlinked objects
+ // should be filtered out.
+ IN_EXCL_UNLINK = 0x04000000
+ // IN_MASK_ADD indicates the provided mask should be ORed into any existing
+ // watch on the provided path.
+ IN_MASK_ADD = 0x20000000
+ // IN_ONESHOT indicates the watch should be removed after one event.
+ IN_ONESHOT = 0x80000000
+)
+
+// Feature flags for inotify_init1(2).
+const (
+ // IN_CLOEXEC is an alias for O_CLOEXEC. It indicates that the inotify
+ // fd should be closed on exec(2) and friends.
+ IN_CLOEXEC = 0x00080000
+ // IN_NONBLOCK is an alias for O_NONBLOCK. It indicates I/O syscall on the
+ // inotify fd should not block.
+ IN_NONBLOCK = 0x00000800
+)
+
+// ALL_INOTIFY_BITS contains all the bits for all possible inotify events. It's
+// defined in the Linux source at "include/linux/inotify.h".
+const ALL_INOTIFY_BITS = IN_ACCESS | IN_MODIFY | IN_ATTRIB | IN_CLOSE_WRITE |
+ IN_CLOSE_NOWRITE | IN_OPEN | IN_MOVED_FROM | IN_MOVED_TO | IN_CREATE |
+ IN_DELETE | IN_DELETE_SELF | IN_MOVE_SELF | IN_UNMOUNT | IN_Q_OVERFLOW |
+ IN_IGNORED | IN_ONLYDIR | IN_DONT_FOLLOW | IN_EXCL_UNLINK | IN_MASK_ADD |
+ IN_ISDIR | IN_ONESHOT
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
new file mode 100644
index 000000000..04bb767dc
--- /dev/null
+++ b/pkg/abi/linux/ioctl.go
@@ -0,0 +1,99 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// ioctl(2) requests provided by asm-generic/ioctls.h
+//
+// These are ordered by request number (low byte).
+const (
+ TCGETS = 0x00005401
+ TCSETS = 0x00005402
+ TCSETSW = 0x00005403
+ TCSETSF = 0x00005404
+ TCSBRK = 0x00005409
+ TIOCEXCL = 0x0000540c
+ TIOCNXCL = 0x0000540d
+ TIOCSCTTY = 0x0000540e
+ TIOCGPGRP = 0x0000540f
+ TIOCSPGRP = 0x00005410
+ TIOCOUTQ = 0x00005411
+ TIOCSTI = 0x00005412
+ TIOCGWINSZ = 0x00005413
+ TIOCSWINSZ = 0x00005414
+ TIOCMGET = 0x00005415
+ TIOCMBIS = 0x00005416
+ TIOCMBIC = 0x00005417
+ TIOCMSET = 0x00005418
+ TIOCINQ = 0x0000541b
+ FIONREAD = TIOCINQ
+ FIONBIO = 0x00005421
+ TIOCSETD = 0x00005423
+ TIOCNOTTY = 0x00005422
+ TIOCGETD = 0x00005424
+ TCSBRKP = 0x00005425
+ TIOCSBRK = 0x00005427
+ TIOCCBRK = 0x00005428
+ TIOCGSID = 0x00005429
+ TIOCGPTN = 0x80045430
+ TIOCSPTLCK = 0x40045431
+ TIOCGDEV = 0x80045432
+ TIOCVHANGUP = 0x00005437
+ TCFLSH = 0x0000540b
+ TIOCCONS = 0x0000541d
+ TIOCSSERIAL = 0x0000541f
+ TIOCGEXCL = 0x80045440
+ TIOCGPTPEER = 0x80045441
+ TIOCGICOUNT = 0x0000545d
+ FIONCLEX = 0x00005450
+ FIOCLEX = 0x00005451
+ FIOASYNC = 0x00005452
+ FIOSETOWN = 0x00008901
+ SIOCSPGRP = 0x00008902
+ FIOGETOWN = 0x00008903
+ SIOCGPGRP = 0x00008904
+)
+
+// ioctl(2) requests provided by uapi/linux/sockios.h
+const (
+ SIOCGIFMEM = 0x891f
+ SIOCGIFPFLAGS = 0x8935
+ SIOCGMIIPHY = 0x8947
+ SIOCGMIIREG = 0x8948
+)
+
+// ioctl(2) requests provided by uapi/linux/android/binder.h
+const (
+ BinderWriteReadIoctl = 0xc0306201
+ BinderSetIdleTimeoutIoctl = 0x40086203
+ BinderSetMaxThreadsIoctl = 0x40046205
+ BinderSetIdlePriorityIoctl = 0x40046206
+ BinderSetContextMgrIoctl = 0x40046207
+ BinderThreadExitIoctl = 0x40046208
+ BinderVersionIoctl = 0xc0046209
+)
+
+// ioctl(2) requests provided by drivers/staging/android/uapi/ashmem.h
+const (
+ AshmemSetNameIoctl = 0x41007701
+ AshmemGetNameIoctl = 0x81007702
+ AshmemSetSizeIoctl = 0x40087703
+ AshmemGetSizeIoctl = 0x00007704
+ AshmemSetProtMaskIoctl = 0x40087705
+ AshmemGetProtMaskIoctl = 0x00007706
+ AshmemPinIoctl = 0x40087707
+ AshmemUnpinIoctl = 0x40087708
+ AshmemGetPinStatusIoctl = 0x00007709
+ AshmemPurgeAllCachesIoctl = 0x0000770a
+)
diff --git a/pkg/abi/linux/ip.go b/pkg/abi/linux/ip.go
new file mode 100644
index 000000000..31e56ffa6
--- /dev/null
+++ b/pkg/abi/linux/ip.go
@@ -0,0 +1,151 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// IP protocols
+const (
+ IPPROTO_IP = 0
+ IPPROTO_ICMP = 1
+ IPPROTO_IGMP = 2
+ IPPROTO_IPIP = 4
+ IPPROTO_TCP = 6
+ IPPROTO_EGP = 8
+ IPPROTO_PUP = 12
+ IPPROTO_UDP = 17
+ IPPROTO_IDP = 22
+ IPPROTO_TP = 29
+ IPPROTO_DCCP = 33
+ IPPROTO_IPV6 = 41
+ IPPROTO_RSVP = 46
+ IPPROTO_GRE = 47
+ IPPROTO_ESP = 50
+ IPPROTO_AH = 51
+ IPPROTO_MTP = 92
+ IPPROTO_BEETPH = 94
+ IPPROTO_ENCAP = 98
+ IPPROTO_PIM = 103
+ IPPROTO_COMP = 108
+ IPPROTO_SCTP = 132
+ IPPROTO_UDPLITE = 136
+ IPPROTO_MPLS = 137
+ IPPROTO_RAW = 255
+)
+
+// Socket options from uapi/linux/in.h
+const (
+ IP_TOS = 1
+ IP_TTL = 2
+ IP_HDRINCL = 3
+ IP_OPTIONS = 4
+ IP_ROUTER_ALERT = 5
+ IP_RECVOPTS = 6
+ IP_RETOPTS = 7
+ IP_PKTINFO = 8
+ IP_PKTOPTIONS = 9
+ IP_MTU_DISCOVER = 10
+ IP_RECVERR = 11
+ IP_RECVTTL = 12
+ IP_RECVTOS = 13
+ IP_MTU = 14
+ IP_FREEBIND = 15
+ IP_IPSEC_POLICY = 16
+ IP_XFRM_POLICY = 17
+ IP_PASSSEC = 18
+ IP_TRANSPARENT = 19
+ IP_ORIGDSTADDR = 20
+ IP_RECVORIGDSTADDR = IP_ORIGDSTADDR
+ IP_MINTTL = 21
+ IP_NODEFRAG = 22
+ IP_CHECKSUM = 23
+ IP_BIND_ADDRESS_NO_PORT = 24
+ IP_RECVFRAGSIZE = 25
+ IP_MULTICAST_IF = 32
+ IP_MULTICAST_TTL = 33
+ IP_MULTICAST_LOOP = 34
+ IP_ADD_MEMBERSHIP = 35
+ IP_DROP_MEMBERSHIP = 36
+ IP_UNBLOCK_SOURCE = 37
+ IP_BLOCK_SOURCE = 38
+ IP_ADD_SOURCE_MEMBERSHIP = 39
+ IP_DROP_SOURCE_MEMBERSHIP = 40
+ IP_MSFILTER = 41
+ MCAST_JOIN_GROUP = 42
+ MCAST_BLOCK_SOURCE = 43
+ MCAST_UNBLOCK_SOURCE = 44
+ MCAST_LEAVE_GROUP = 45
+ MCAST_JOIN_SOURCE_GROUP = 46
+ MCAST_LEAVE_SOURCE_GROUP = 47
+ MCAST_MSFILTER = 48
+ IP_MULTICAST_ALL = 49
+ IP_UNICAST_IF = 50
+)
+
+// Socket options from uapi/linux/in6.h
+const (
+ IPV6_ADDRFORM = 1
+ IPV6_2292PKTINFO = 2
+ IPV6_2292HOPOPTS = 3
+ IPV6_2292DSTOPTS = 4
+ IPV6_2292RTHDR = 5
+ IPV6_2292PKTOPTIONS = 6
+ IPV6_CHECKSUM = 7
+ IPV6_2292HOPLIMIT = 8
+ IPV6_NEXTHOP = 9
+ IPV6_FLOWINFO = 11
+ IPV6_UNICAST_HOPS = 16
+ IPV6_MULTICAST_IF = 17
+ IPV6_MULTICAST_HOPS = 18
+ IPV6_MULTICAST_LOOP = 19
+ IPV6_ADD_MEMBERSHIP = 20
+ IPV6_DROP_MEMBERSHIP = 21
+ IPV6_ROUTER_ALERT = 22
+ IPV6_MTU_DISCOVER = 23
+ IPV6_MTU = 24
+ IPV6_RECVERR = 25
+ IPV6_V6ONLY = 26
+ IPV6_JOIN_ANYCAST = 27
+ IPV6_LEAVE_ANYCAST = 28
+ IPV6_MULTICAST_ALL = 29
+ IPV6_FLOWLABEL_MGR = 32
+ IPV6_FLOWINFO_SEND = 33
+ IPV6_IPSEC_POLICY = 34
+ IPV6_XFRM_POLICY = 35
+ IPV6_HDRINCL = 36
+ IPV6_RECVPKTINFO = 49
+ IPV6_PKTINFO = 50
+ IPV6_RECVHOPLIMIT = 51
+ IPV6_HOPLIMIT = 52
+ IPV6_RECVHOPOPTS = 53
+ IPV6_HOPOPTS = 54
+ IPV6_RTHDRDSTOPTS = 55
+ IPV6_RECVRTHDR = 56
+ IPV6_RTHDR = 57
+ IPV6_RECVDSTOPTS = 58
+ IPV6_DSTOPTS = 59
+ IPV6_RECVPATHMTU = 60
+ IPV6_PATHMTU = 61
+ IPV6_DONTFRAG = 62
+ IPV6_RECVTCLASS = 66
+ IPV6_TCLASS = 67
+ IPV6_AUTOFLOWLABEL = 70
+ IPV6_ADDR_PREFERENCES = 72
+ IPV6_MINHOPCOUNT = 73
+ IPV6_ORIGDSTADDR = 74
+ IPV6_RECVORIGDSTADDR = IPV6_ORIGDSTADDR
+ IPV6_TRANSPARENT = 75
+ IPV6_UNICAST_IF = 76
+ IPV6_RECVFRAGSIZE = 77
+ IPV6_FREEBIND = 78
+)
diff --git a/pkg/abi/linux/ipc.go b/pkg/abi/linux/ipc.go
new file mode 100644
index 000000000..2ef8d6cbb
--- /dev/null
+++ b/pkg/abi/linux/ipc.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Control commands used with semctl, shmctl, and msgctl. Source:
+// include/uapi/linux/ipc.h.
+const (
+ IPC_RMID = 0
+ IPC_SET = 1
+ IPC_STAT = 2
+ IPC_INFO = 3
+)
+
+// resource get request flags. Source: include/uapi/linux/ipc.h
+const (
+ IPC_CREAT = 00001000
+ IPC_EXCL = 00002000
+ IPC_NOWAIT = 00004000
+)
+
+const IPC_PRIVATE = 0
+
+// In Linux, amd64 does not enable CONFIG_ARCH_WANT_IPC_PARSE_VERSION, so SysV
+// IPC unconditionally uses the "new" 64-bit structures that are needed for
+// features like 32-bit UIDs.
+
+// IPCPerm is equivalent to struct ipc64_perm.
+type IPCPerm struct {
+ Key uint32
+ UID uint32
+ GID uint32
+ CUID uint32
+ CGID uint32
+ Mode uint16
+ pad1 uint16
+ Seq uint16
+ pad2 uint16
+ pad3 uint32
+ unused1 uint64
+ unused2 uint64
+}
diff --git a/pkg/abi/linux/limits.go b/pkg/abi/linux/limits.go
new file mode 100644
index 000000000..c74dfcd53
--- /dev/null
+++ b/pkg/abi/linux/limits.go
@@ -0,0 +1,88 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Resources for getrlimit(2)/setrlimit(2)/prlimit(2).
+const (
+ RLIMIT_CPU = 0
+ RLIMIT_FSIZE = 1
+ RLIMIT_DATA = 2
+ RLIMIT_STACK = 3
+ RLIMIT_CORE = 4
+ RLIMIT_RSS = 5
+ RLIMIT_NPROC = 6
+ RLIMIT_NOFILE = 7
+ RLIMIT_MEMLOCK = 8
+ RLIMIT_AS = 9
+ RLIMIT_LOCKS = 10
+ RLIMIT_SIGPENDING = 11
+ RLIMIT_MSGQUEUE = 12
+ RLIMIT_NICE = 13
+ RLIMIT_RTPRIO = 14
+ RLIMIT_RTTIME = 15
+)
+
+// RLimit corresponds to Linux's struct rlimit.
+type RLimit struct {
+ // Cur specifies the soft limit.
+ Cur uint64
+ // Max specifies the hard limit.
+ Max uint64
+}
+
+const (
+ // RLimInfinity is RLIM_INFINITY on Linux.
+ RLimInfinity = ^uint64(0)
+
+ // DefaultStackSoftLimit is called _STK_LIM in Linux.
+ DefaultStackSoftLimit = 8 * 1024 * 1024
+
+ // DefaultNprocLimit is defined in kernel/fork.c:set_max_threads, and
+ // called MAX_THREADS / 2 in Linux.
+ DefaultNprocLimit = FUTEX_TID_MASK / 2
+
+ // DefaultNofileSoftLimit is called INR_OPEN_CUR in Linux.
+ DefaultNofileSoftLimit = 1024
+
+ // DefaultNofileHardLimit is called INR_OPEN_MAX in Linux.
+ DefaultNofileHardLimit = 4096
+
+ // DefaultMemlockLimit is called MLOCK_LIMIT in Linux.
+ DefaultMemlockLimit = 64 * 1024
+
+ // DefaultMsgqueueLimit is called MQ_BYTES_MAX in Linux.
+ DefaultMsgqueueLimit = 819200
+)
+
+// InitRLimits is a map of initial rlimits set by Linux in
+// include/asm-generic/resource.h.
+var InitRLimits = map[int]RLimit{
+ RLIMIT_CPU: {RLimInfinity, RLimInfinity},
+ RLIMIT_FSIZE: {RLimInfinity, RLimInfinity},
+ RLIMIT_DATA: {RLimInfinity, RLimInfinity},
+ RLIMIT_STACK: {DefaultStackSoftLimit, RLimInfinity},
+ RLIMIT_CORE: {0, RLimInfinity},
+ RLIMIT_RSS: {RLimInfinity, RLimInfinity},
+ RLIMIT_NPROC: {DefaultNprocLimit, DefaultNprocLimit},
+ RLIMIT_NOFILE: {DefaultNofileSoftLimit, DefaultNofileHardLimit},
+ RLIMIT_MEMLOCK: {DefaultMemlockLimit, DefaultMemlockLimit},
+ RLIMIT_AS: {RLimInfinity, RLimInfinity},
+ RLIMIT_LOCKS: {RLimInfinity, RLimInfinity},
+ RLIMIT_SIGPENDING: {0, 0},
+ RLIMIT_MSGQUEUE: {DefaultMsgqueueLimit, DefaultMsgqueueLimit},
+ RLIMIT_NICE: {0, 0},
+ RLIMIT_RTPRIO: {0, 0},
+ RLIMIT_RTTIME: {RLimInfinity, RLimInfinity},
+}
diff --git a/pkg/abi/linux/linux.go b/pkg/abi/linux/linux.go
new file mode 100644
index 000000000..8a8f831cd
--- /dev/null
+++ b/pkg/abi/linux/linux.go
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package linux contains the constants and types needed to inferface with a Linux kernel.
+package linux
+
+// NumSoftIRQ is the number of software IRQs, exposed via /proc/stat.
+//
+// Defined in linux/interrupt.h.
+const NumSoftIRQ = 10
+
+// Sysinfo is the structure provided by sysinfo on linux versions > 2.3.48.
+type Sysinfo struct {
+ Uptime int64
+ Loads [3]uint64
+ TotalRAM uint64
+ FreeRAM uint64
+ SharedRAM uint64
+ BufferRAM uint64
+ TotalSwap uint64
+ FreeSwap uint64
+ Procs uint16
+ _ [6]byte // Pad Procs to 64bits.
+ TotalHigh uint64
+ FreeHigh uint64
+ Unit uint32
+ /* The _f field in the glibc version of Sysinfo has size 0 on AMD64 */
+}
diff --git a/pkg/abi/linux/linux_state_autogen.go b/pkg/abi/linux/linux_state_autogen.go
new file mode 100755
index 000000000..3495e2e74
--- /dev/null
+++ b/pkg/abi/linux/linux_state_autogen.go
@@ -0,0 +1,68 @@
+// automatically generated by stateify.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *BPFInstruction) beforeSave() {}
+func (x *BPFInstruction) save(m state.Map) {
+ x.beforeSave()
+ m.Save("OpCode", &x.OpCode)
+ m.Save("JumpIfTrue", &x.JumpIfTrue)
+ m.Save("JumpIfFalse", &x.JumpIfFalse)
+ m.Save("K", &x.K)
+}
+
+func (x *BPFInstruction) afterLoad() {}
+func (x *BPFInstruction) load(m state.Map) {
+ m.Load("OpCode", &x.OpCode)
+ m.Load("JumpIfTrue", &x.JumpIfTrue)
+ m.Load("JumpIfFalse", &x.JumpIfFalse)
+ m.Load("K", &x.K)
+}
+
+func (x *KernelTermios) beforeSave() {}
+func (x *KernelTermios) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InputFlags", &x.InputFlags)
+ m.Save("OutputFlags", &x.OutputFlags)
+ m.Save("ControlFlags", &x.ControlFlags)
+ m.Save("LocalFlags", &x.LocalFlags)
+ m.Save("LineDiscipline", &x.LineDiscipline)
+ m.Save("ControlCharacters", &x.ControlCharacters)
+ m.Save("InputSpeed", &x.InputSpeed)
+ m.Save("OutputSpeed", &x.OutputSpeed)
+}
+
+func (x *KernelTermios) afterLoad() {}
+func (x *KernelTermios) load(m state.Map) {
+ m.Load("InputFlags", &x.InputFlags)
+ m.Load("OutputFlags", &x.OutputFlags)
+ m.Load("ControlFlags", &x.ControlFlags)
+ m.Load("LocalFlags", &x.LocalFlags)
+ m.Load("LineDiscipline", &x.LineDiscipline)
+ m.Load("ControlCharacters", &x.ControlCharacters)
+ m.Load("InputSpeed", &x.InputSpeed)
+ m.Load("OutputSpeed", &x.OutputSpeed)
+}
+
+func (x *WindowSize) beforeSave() {}
+func (x *WindowSize) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Rows", &x.Rows)
+ m.Save("Cols", &x.Cols)
+}
+
+func (x *WindowSize) afterLoad() {}
+func (x *WindowSize) load(m state.Map) {
+ m.Load("Rows", &x.Rows)
+ m.Load("Cols", &x.Cols)
+}
+
+func init() {
+ state.Register("linux.BPFInstruction", (*BPFInstruction)(nil), state.Fns{Save: (*BPFInstruction).save, Load: (*BPFInstruction).load})
+ state.Register("linux.KernelTermios", (*KernelTermios)(nil), state.Fns{Save: (*KernelTermios).save, Load: (*KernelTermios).load})
+ state.Register("linux.WindowSize", (*WindowSize)(nil), state.Fns{Save: (*WindowSize).save, Load: (*WindowSize).load})
+}
diff --git a/pkg/abi/linux/mm.go b/pkg/abi/linux/mm.go
new file mode 100644
index 000000000..0b02f938a
--- /dev/null
+++ b/pkg/abi/linux/mm.go
@@ -0,0 +1,116 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Protections for mmap(2).
+const (
+ PROT_NONE = 0
+ PROT_READ = 1 << 0
+ PROT_WRITE = 1 << 1
+ PROT_EXEC = 1 << 2
+ PROT_SEM = 1 << 3
+ PROT_GROWSDOWN = 1 << 24
+ PROT_GROWSUP = 1 << 25
+)
+
+// Flags for mmap(2).
+const (
+ MAP_SHARED = 1 << 0
+ MAP_PRIVATE = 1 << 1
+ MAP_FIXED = 1 << 4
+ MAP_ANONYMOUS = 1 << 5
+ MAP_32BIT = 1 << 6 // arch/x86/include/uapi/asm/mman.h
+ MAP_GROWSDOWN = 1 << 8
+ MAP_DENYWRITE = 1 << 11
+ MAP_EXECUTABLE = 1 << 12
+ MAP_LOCKED = 1 << 13
+ MAP_NORESERVE = 1 << 14
+ MAP_POPULATE = 1 << 15
+ MAP_NONBLOCK = 1 << 16
+ MAP_STACK = 1 << 17
+ MAP_HUGETLB = 1 << 18
+)
+
+// Flags for mremap(2).
+const (
+ MREMAP_MAYMOVE = 1 << 0
+ MREMAP_FIXED = 1 << 1
+)
+
+// Flags for mlock2(2).
+const (
+ MLOCK_ONFAULT = 0x01
+)
+
+// Flags for mlockall(2).
+const (
+ MCL_CURRENT = 1
+ MCL_FUTURE = 2
+ MCL_ONFAULT = 4
+)
+
+// Advice for madvise(2).
+const (
+ MADV_NORMAL = 0
+ MADV_RANDOM = 1
+ MADV_SEQUENTIAL = 2
+ MADV_WILLNEED = 3
+ MADV_DONTNEED = 4
+ MADV_REMOVE = 9
+ MADV_DONTFORK = 10
+ MADV_DOFORK = 11
+ MADV_MERGEABLE = 12
+ MADV_UNMERGEABLE = 13
+ MADV_HUGEPAGE = 14
+ MADV_NOHUGEPAGE = 15
+ MADV_DONTDUMP = 16
+ MADV_DODUMP = 17
+ MADV_HWPOISON = 100
+ MADV_SOFT_OFFLINE = 101
+ MADV_NOMAJFAULT = 200
+ MADV_DONTCHGME = 201
+)
+
+// Flags for msync(2).
+const (
+ MS_ASYNC = 1 << 0
+ MS_INVALIDATE = 1 << 1
+ MS_SYNC = 1 << 2
+)
+
+// Policies for get_mempolicy(2)/set_mempolicy(2).
+const (
+ MPOL_DEFAULT = 0
+ MPOL_PREFERRED = 1
+ MPOL_BIND = 2
+ MPOL_INTERLEAVE = 3
+ MPOL_LOCAL = 4
+ MPOL_MAX = 5
+)
+
+// Flags for get_mempolicy(2).
+const (
+ MPOL_F_NODE = 1 << 0
+ MPOL_F_ADDR = 1 << 1
+ MPOL_F_MEMS_ALLOWED = 1 << 2
+)
+
+// Flags for set_mempolicy(2).
+const (
+ MPOL_F_RELATIVE_NODES = 1 << 14
+ MPOL_F_STATIC_NODES = 1 << 15
+
+ MPOL_MODE_FLAGS = (MPOL_F_STATIC_NODES | MPOL_F_RELATIVE_NODES)
+)
diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go
new file mode 100644
index 000000000..aef1acf75
--- /dev/null
+++ b/pkg/abi/linux/netdevice.go
@@ -0,0 +1,86 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import "gvisor.googlesource.com/gvisor/pkg/binary"
+
+const (
+ // IFNAMSIZ is the size of the name field for IFReq.
+ IFNAMSIZ = 16
+)
+
+// IFReq is an interface request.
+type IFReq struct {
+ // IFName is an encoded name, normally null-terminated. This should be
+ // accessed via the Name and SetName functions.
+ IFName [IFNAMSIZ]byte
+
+ // Data is the union of the following structures:
+ //
+ // struct sockaddr ifr_addr;
+ // struct sockaddr ifr_dstaddr;
+ // struct sockaddr ifr_broadaddr;
+ // struct sockaddr ifr_netmask;
+ // struct sockaddr ifr_hwaddr;
+ // short ifr_flags;
+ // int ifr_ifindex;
+ // int ifr_metric;
+ // int ifr_mtu;
+ // struct ifmap ifr_map;
+ // char ifr_slave[IFNAMSIZ];
+ // char ifr_newname[IFNAMSIZ];
+ // char *ifr_data;
+ Data [24]byte
+}
+
+// Name returns the name.
+func (ifr *IFReq) Name() string {
+ for c := 0; c < len(ifr.IFName); c++ {
+ if ifr.IFName[c] == 0 {
+ return string(ifr.IFName[:c])
+ }
+ }
+ return string(ifr.IFName[:])
+}
+
+// SetName sets the name.
+func (ifr *IFReq) SetName(name string) {
+ n := copy(ifr.IFName[:], []byte(name))
+ for i := n; i < len(ifr.IFName); i++ {
+ ifr.IFName[i] = 0
+ }
+}
+
+// SizeOfIFReq is the binary size of an IFReq struct (40 bytes).
+var SizeOfIFReq = binary.Size(IFReq{})
+
+// IFMap contains interface hardware parameters.
+type IFMap struct {
+ MemStart uint64
+ MemEnd uint64
+ BaseAddr int16
+ IRQ byte
+ DMA byte
+ Port byte
+ _ [3]byte // Pad to sizeof(struct ifmap).
+}
+
+// IFConf is used to return a list of interfaces and their addresses. See
+// netdevice(7) and struct ifconf for more detail on its use.
+type IFConf struct {
+ Len int32
+ _ [4]byte // Pad to sizeof(struct ifconf).
+ Ptr uint64
+}
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
new file mode 100644
index 000000000..7f399142b
--- /dev/null
+++ b/pkg/abi/linux/netfilter.go
@@ -0,0 +1,240 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// This file contains structures required to support netfilter, specifically
+// the iptables tool.
+
+// Hooks into the network stack. These correspond to values in
+// include/uapi/linux/netfilter.h.
+const (
+ NF_INET_PRE_ROUTING = 0
+ NF_INET_LOCAL_IN = 1
+ NF_INET_FORWARD = 2
+ NF_INET_LOCAL_OUT = 3
+ NF_INET_POST_ROUTING = 4
+ NF_INET_NUMHOOKS = 5
+)
+
+// Verdicts that can be returned by targets. These correspond to values in
+// include/uapi/linux/netfilter.h
+const (
+ NF_DROP = 0
+ NF_ACCEPT = 1
+ NF_STOLEN = 2
+ NF_QUEUE = 3
+ NF_REPEAT = 4
+ NF_STOP = 5
+ NF_MAX_VERDICT = NF_STOP
+ // NF_RETURN is defined in include/uapi/linux/netfilter/x_tables.h.
+ NF_RETURN = -NF_REPEAT - 1
+)
+
+// Socket options. These correspond to values in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+const (
+ IPT_BASE_CTL = 64
+ IPT_SO_SET_REPLACE = IPT_BASE_CTL
+ IPT_SO_SET_ADD_COUNTERS = IPT_BASE_CTL + 1
+ IPT_SO_SET_MAX = IPT_SO_SET_ADD_COUNTERS
+
+ IPT_SO_GET_INFO = IPT_BASE_CTL
+ IPT_SO_GET_ENTRIES = IPT_BASE_CTL + 1
+ IPT_SO_GET_REVISION_MATCH = IPT_BASE_CTL + 2
+ IPT_SO_GET_REVISION_TARGET = IPT_BASE_CTL + 3
+ IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET
+)
+
+// Name lengths. These correspond to values in
+// include/uapi/linux/netfilter/x_tables.h.
+const (
+ XT_FUNCTION_MAXNAMELEN = 30
+ XT_EXTENSION_MAXNAMELEN = 29
+ XT_TABLE_MAXNAMELEN = 32
+)
+
+// IPTEntry is an iptable rule. It corresponds to struct ipt_entry in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTEntry struct {
+ // IP is used to filter packets based on the IP header.
+ IP IPTIP
+
+ // NFCache relates to kernel-internal caching and isn't used by
+ // userspace.
+ NFCache uint32
+
+ // TargetOffset is the byte offset from the beginning of this IPTEntry
+ // to the start of the entry's target.
+ TargetOffset uint16
+
+ // NextOffset is the byte offset from the beginning of this IPTEntry to
+ // the start of the next entry. It is thus also the size of the entry.
+ NextOffset uint16
+
+ // Comeback is a return pointer. It is not used by userspace.
+ Comeback uint32
+
+ // Counters holds the packet and byte counts for this rule.
+ Counters XTCounters
+
+ // Elems holds the data for all this rule's matches followed by the
+ // target. It is variable length -- users have to iterate over any
+ // matches and use TargetOffset and NextOffset to make sense of the
+ // data.
+ //
+ // Elems is omitted here because it would cause IPTEntry to be an extra
+ // byte larger (see http://www.catb.org/esr/structure-packing/).
+ //
+ // Elems [0]byte
+}
+
+// IPTIP contains information for matching a packet's IP header.
+// It corresponds to struct ipt_ip in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTIP struct {
+ // Src is the source IP address.
+ Src InetAddr
+
+ // Dst is the destination IP address.
+ Dst InetAddr
+
+ // SrcMask is the source IP mask.
+ SrcMask InetAddr
+
+ // DstMask is the destination IP mask.
+ DstMask InetAddr
+
+ // InputInterface is the input network interface.
+ InputInterface [IFNAMSIZ]byte
+
+ // OutputInterface is the output network interface.
+ OutputInterface [IFNAMSIZ]byte
+
+ // InputInterfaceMask is the intput interface mask.
+ InputInterfaceMast [IFNAMSIZ]byte
+
+ // OuputInterfaceMask is the output interface mask.
+ OuputInterfaceMask [IFNAMSIZ]byte
+
+ // Protocol is the transport protocol.
+ Protocol uint16
+
+ // Flags define matching behavior for the IP header.
+ Flags uint8
+
+ // InverseFlags invert the meaning of fields in struct IPTIP.
+ InverseFlags uint8
+}
+
+// XTCounters holds packet and byte counts for a rule. It corresponds to struct
+// xt_counters in include/uapi/linux/netfilter/x_tables.h.
+type XTCounters struct {
+ // Pcnt is the packet count.
+ Pcnt uint64
+
+ // Bcnt is the byte count.
+ Bcnt uint64
+}
+
+// XTEntryMatch holds a match for a rule. For example, a user using the
+// addrtype iptables match extension would put the data for that match into an
+// XTEntryMatch. iptables-extensions(8) has a list of possible matches.
+//
+// XTEntryMatch corresponds to struct xt_entry_match in
+// include/uapi/linux/netfilter/x_tables.h. That struct contains a union
+// exposing different data to the user and kernel, but this struct holds only
+// the user data.
+type XTEntryMatch struct {
+ MatchSize uint16
+ Name [XT_EXTENSION_MAXNAMELEN]byte
+ Revision uint8
+ // Data is omitted here because it would cause XTEntryTarget to be an
+ // extra byte larger (see http://www.catb.org/esr/structure-packing/).
+ // Data [0]byte
+}
+
+// XTEntryTarget holds a target for a rule. For example, it can specify that
+// packets matching the rule should DROP, ACCEPT, or use an extension target.
+// iptables-extension(8) has a list of possible targets.
+//
+// XTEntryTarget corresponds to struct xt_entry_target in
+// include/uapi/linux/netfilter/x_tables.h. That struct contains a union
+// exposing different data to the user and kernel, but this struct holds only
+// the user data.
+type XTEntryTarget struct {
+ MatchSize uint16
+ Name [XT_EXTENSION_MAXNAMELEN]byte
+ Revision uint8
+ // Data is omitted here because it would cause XTEntryTarget to be an
+ // extra byte larger (see http://www.catb.org/esr/structure-packing/).
+ // Data [0]byte
+}
+
+// XTStandardTarget is a builtin target, one of ACCEPT, DROP, JUMP, QUEUE, or
+// RETURN. It corresponds to struct xt_standard_target in
+// include/uapi/linux/netfilter/x_tables.h.
+type XTStandardTarget struct {
+ Target XTEntryTarget
+ Verdict int32
+}
+
+// XTErrorTarget triggers an error when reached. It is also used to mark the
+// beginning of user-defined chains by putting the name of the chain in
+// ErrorName. It corresponds to struct xt_error_target in
+// include/uapi/linux/netfilter/x_tables.h.
+type XTErrorTarget struct {
+ Target XTEntryTarget
+ ErrorName [XT_FUNCTION_MAXNAMELEN]byte
+}
+
+// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
+// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTGetinfo struct {
+ Name [XT_TABLE_MAXNAMELEN]byte
+ ValidHooks uint32
+ HookEntry [NF_INET_NUMHOOKS]uint32
+ Underflow [NF_INET_NUMHOOKS]uint32
+ NumEntries uint32
+ Size uint32
+}
+
+// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
+// corresponds to struct ipt_get_entries in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTGetEntries struct {
+ Name [XT_TABLE_MAXNAMELEN]byte
+ Size uint32
+ // Entrytable is omitted here because it would cause IPTGetEntries to
+ // be an extra byte longer (see
+ // http://www.catb.org/esr/structure-packing/).
+ // Entrytable [0]IPTEntry
+}
+
+// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
+// corresponds to struct ipt_replace in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+type IPTReplace struct {
+ Name [XT_TABLE_MAXNAMELEN]byte
+ ValidHooks uint32
+ NumEntries uint32
+ Size uint32
+ HookEntry [NF_INET_NUMHOOKS]uint32
+ Underflow [NF_INET_NUMHOOKS]uint32
+ NumCounters uint32
+ Counters *XTCounters
+ // Entries is omitted here because it would cause IPTReplace to be an
+ // extra byte longer (see http://www.catb.org/esr/structure-packing/).
+ // Entries [0]IPTEntry
+}
diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go
new file mode 100644
index 000000000..5e718c363
--- /dev/null
+++ b/pkg/abi/linux/netlink.go
@@ -0,0 +1,124 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Netlink protocols, from uapi/linux/netlink.h.
+const (
+ NETLINK_ROUTE = 0
+ NETLINK_UNUSED = 1
+ NETLINK_USERSOCK = 2
+ NETLINK_FIREWALL = 3
+ NETLINK_SOCK_DIAG = 4
+ NETLINK_NFLOG = 5
+ NETLINK_XFRM = 6
+ NETLINK_SELINUX = 7
+ NETLINK_ISCSI = 8
+ NETLINK_AUDIT = 9
+ NETLINK_FIB_LOOKUP = 10
+ NETLINK_CONNECTOR = 11
+ NETLINK_NETFILTER = 12
+ NETLINK_IP6_FW = 13
+ NETLINK_DNRTMSG = 14
+ NETLINK_KOBJECT_UEVENT = 15
+ NETLINK_GENERIC = 16
+ NETLINK_SCSITRANSPORT = 18
+ NETLINK_ECRYPTFS = 19
+ NETLINK_RDMA = 20
+ NETLINK_CRYPTO = 21
+)
+
+// SockAddrNetlink is struct sockaddr_nl, from uapi/linux/netlink.h.
+type SockAddrNetlink struct {
+ Family uint16
+ Padding uint16
+ PortID uint32
+ Groups uint32
+}
+
+// SockAddrNetlinkSize is the size of SockAddrNetlink.
+const SockAddrNetlinkSize = 12
+
+// NetlinkMessageHeader is struct nlmsghdr, from uapi/linux/netlink.h.
+type NetlinkMessageHeader struct {
+ Length uint32
+ Type uint16
+ Flags uint16
+ Seq uint32
+ PortID uint32
+}
+
+// NetlinkMessageHeaderSize is the size of NetlinkMessageHeader.
+const NetlinkMessageHeaderSize = 16
+
+// Netlink message header flags, from uapi/linux/netlink.h.
+const (
+ NLM_F_REQUEST = 0x1
+ NLM_F_MULTI = 0x2
+ NLM_F_ACK = 0x4
+ NLM_F_ECHO = 0x8
+ NLM_F_DUMP_INTR = 0x10
+ NLM_F_ROOT = 0x100
+ NLM_F_MATCH = 0x200
+ NLM_F_ATOMIC = 0x400
+ NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH
+ NLM_F_REPLACE = 0x100
+ NLM_F_EXCL = 0x200
+ NLM_F_CREATE = 0x400
+ NLM_F_APPEND = 0x800
+)
+
+// Standard netlink message types, from uapi/linux/netlink.h.
+const (
+ NLMSG_NOOP = 0x1
+ NLMSG_ERROR = 0x2
+ NLMSG_DONE = 0x3
+ NLMSG_OVERRUN = 0x4
+
+ // NLMSG_MIN_TYPE is the first value for protocol-level types.
+ NLMSG_MIN_TYPE = 0x10
+)
+
+// NLMSG_ALIGNTO is the alignment of netlink messages, from
+// uapi/linux/netlink.h.
+const NLMSG_ALIGNTO = 4
+
+// NetlinkAttrHeader is the header of a netlink attribute, followed by payload.
+//
+// This is struct nlattr, from uapi/linux/netlink.h.
+type NetlinkAttrHeader struct {
+ Length uint16
+ Type uint16
+}
+
+// NetlinkAttrHeaderSize is the size of NetlinkAttrHeader.
+const NetlinkAttrHeaderSize = 4
+
+// NLA_ALIGNTO is the alignment of netlink attributes, from
+// uapi/linux/netlink.h.
+const NLA_ALIGNTO = 4
+
+// Socket options, from uapi/linux/netlink.h.
+const (
+ NETLINK_ADD_MEMBERSHIP = 1
+ NETLINK_DROP_MEMBERSHIP = 2
+ NETLINK_PKTINFO = 3
+ NETLINK_BROADCAST_ERROR = 4
+ NETLINK_NO_ENOBUFS = 5
+ NETLINK_LISTEN_ALL_NSID = 8
+ NETLINK_LIST_MEMBERSHIPS = 9
+ NETLINK_CAP_ACK = 10
+ NETLINK_EXT_ACK = 11
+ NETLINK_DUMP_STRICT_CHK = 12
+)
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
new file mode 100644
index 000000000..630dc339a
--- /dev/null
+++ b/pkg/abi/linux/netlink_route.go
@@ -0,0 +1,191 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Netlink message types for NETLINK_ROUTE sockets, from uapi/linux/rtnetlink.h.
+const (
+ RTM_NEWLINK = 16
+ RTM_DELLINK = 17
+ RTM_GETLINK = 18
+ RTM_SETLINK = 19
+
+ RTM_NEWADDR = 20
+ RTM_DELADDR = 21
+ RTM_GETADDR = 22
+
+ RTM_NEWROUTE = 24
+ RTM_DELROUTE = 25
+ RTM_GETROUTE = 26
+
+ RTM_NEWNEIGH = 28
+ RTM_DELNEIGH = 29
+ RTM_GETNEIGH = 30
+
+ RTM_NEWRULE = 32
+ RTM_DELRULE = 33
+ RTM_GETRULE = 34
+
+ RTM_NEWQDISC = 36
+ RTM_DELQDISC = 37
+ RTM_GETQDISC = 38
+
+ RTM_NEWTCLASS = 40
+ RTM_DELTCLASS = 41
+ RTM_GETTCLASS = 42
+
+ RTM_NEWTFILTER = 44
+ RTM_DELTFILTER = 45
+ RTM_GETTFILTER = 46
+
+ RTM_NEWACTION = 48
+ RTM_DELACTION = 49
+ RTM_GETACTION = 50
+
+ RTM_NEWPREFIX = 52
+
+ RTM_GETMULTICAST = 58
+
+ RTM_GETANYCAST = 62
+
+ RTM_NEWNEIGHTBL = 64
+ RTM_GETNEIGHTBL = 66
+ RTM_SETNEIGHTBL = 67
+
+ RTM_NEWNDUSEROPT = 68
+
+ RTM_NEWADDRLABEL = 72
+ RTM_DELADDRLABEL = 73
+ RTM_GETADDRLABEL = 74
+
+ RTM_GETDCB = 78
+ RTM_SETDCB = 79
+
+ RTM_NEWNETCONF = 80
+ RTM_GETNETCONF = 82
+
+ RTM_NEWMDB = 84
+ RTM_DELMDB = 85
+ RTM_GETMDB = 86
+
+ RTM_NEWNSID = 88
+ RTM_DELNSID = 89
+ RTM_GETNSID = 90
+)
+
+// InterfaceInfoMessage is struct ifinfomsg, from uapi/linux/rtnetlink.h.
+type InterfaceInfoMessage struct {
+ Family uint8
+ Padding uint8
+ Type uint16
+ Index int32
+ Flags uint32
+ Change uint32
+}
+
+// Interface flags, from uapi/linux/if.h.
+const (
+ IFF_UP = 1 << 0
+ IFF_BROADCAST = 1 << 1
+ IFF_DEBUG = 1 << 2
+ IFF_LOOPBACK = 1 << 3
+ IFF_POINTOPOINT = 1 << 4
+ IFF_NOTRAILERS = 1 << 5
+ IFF_RUNNING = 1 << 6
+ IFF_NOARP = 1 << 7
+ IFF_PROMISC = 1 << 8
+ IFF_ALLMULTI = 1 << 9
+ IFF_MASTER = 1 << 10
+ IFF_SLAVE = 1 << 11
+ IFF_MULTICAST = 1 << 12
+ IFF_PORTSEL = 1 << 13
+ IFF_AUTOMEDIA = 1 << 14
+ IFF_DYNAMIC = 1 << 15
+ IFF_LOWER_UP = 1 << 16
+ IFF_DORMANT = 1 << 17
+ IFF_ECHO = 1 << 18
+)
+
+// Interface link attributes, from uapi/linux/if_link.h.
+const (
+ IFLA_UNSPEC = 0
+ IFLA_ADDRESS = 1
+ IFLA_BROADCAST = 2
+ IFLA_IFNAME = 3
+ IFLA_MTU = 4
+ IFLA_LINK = 5
+ IFLA_QDISC = 6
+ IFLA_STATS = 7
+ IFLA_COST = 8
+ IFLA_PRIORITY = 9
+ IFLA_MASTER = 10
+ IFLA_WIRELESS = 11
+ IFLA_PROTINFO = 12
+ IFLA_TXQLEN = 13
+ IFLA_MAP = 14
+ IFLA_WEIGHT = 15
+ IFLA_OPERSTATE = 16
+ IFLA_LINKMODE = 17
+ IFLA_LINKINFO = 18
+ IFLA_NET_NS_PID = 19
+ IFLA_IFALIAS = 20
+ IFLA_NUM_VF = 21
+ IFLA_VFINFO_LIST = 22
+ IFLA_STATS64 = 23
+ IFLA_VF_PORTS = 24
+ IFLA_PORT_SELF = 25
+ IFLA_AF_SPEC = 26
+ IFLA_GROUP = 27
+ IFLA_NET_NS_FD = 28
+ IFLA_EXT_MASK = 29
+ IFLA_PROMISCUITY = 30
+ IFLA_NUM_TX_QUEUES = 31
+ IFLA_NUM_RX_QUEUES = 32
+ IFLA_CARRIER = 33
+ IFLA_PHYS_PORT_ID = 34
+ IFLA_CARRIER_CHANGES = 35
+ IFLA_PHYS_SWITCH_ID = 36
+ IFLA_LINK_NETNSID = 37
+ IFLA_PHYS_PORT_NAME = 38
+ IFLA_PROTO_DOWN = 39
+ IFLA_GSO_MAX_SEGS = 40
+ IFLA_GSO_MAX_SIZE = 41
+)
+
+// InterfaceAddrMessage is struct ifaddrmsg, from uapi/linux/if_addr.h.
+type InterfaceAddrMessage struct {
+ Family uint8
+ PrefixLen uint8
+ Flags uint8
+ Scope uint8
+ Index uint32
+}
+
+// Interface attributes, from uapi/linux/if_addr.h.
+const (
+ IFA_UNSPEC = 0
+ IFA_ADDRESS = 1
+ IFA_LOCAL = 2
+ IFA_LABEL = 3
+ IFA_BROADCAST = 4
+ IFA_ANYCAST = 5
+ IFA_CACHEINFO = 6
+ IFA_MULTICAST = 7
+ IFA_FLAGS = 8
+)
+
+// Device types, from uapi/linux/if_arp.h.
+const (
+ ARPHRD_LOOPBACK = 772
+)
diff --git a/pkg/abi/linux/poll.go b/pkg/abi/linux/poll.go
new file mode 100644
index 000000000..c04d26e4c
--- /dev/null
+++ b/pkg/abi/linux/poll.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// PollFD is struct pollfd, used by poll(2)/ppoll(2), from uapi/asm-generic/poll.h.
+type PollFD struct {
+ FD int32
+ Events int16
+ REvents int16
+}
+
+// Poll event flags, used by poll(2)/ppoll(2) and/or
+// epoll_ctl(2)/epoll_wait(2), from uapi/asm-generic/poll.h.
+const (
+ POLLIN = 0x0001
+ POLLPRI = 0x0002
+ POLLOUT = 0x0004
+ POLLERR = 0x0008
+ POLLHUP = 0x0010
+ POLLNVAL = 0x0020
+ POLLRDNORM = 0x0040
+ POLLRDBAND = 0x0080
+ POLLWRNORM = 0x0100
+ POLLWRBAND = 0x0200
+ POLLMSG = 0x0400
+ POLLREMOVE = 0x1000
+ POLLRDHUP = 0x2000
+ POLLFREE = 0x4000
+ POLL_BUSY_LOOP = 0x8000
+)
diff --git a/pkg/abi/linux/prctl.go b/pkg/abi/linux/prctl.go
new file mode 100644
index 000000000..0428282dd
--- /dev/null
+++ b/pkg/abi/linux/prctl.go
@@ -0,0 +1,157 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// PR_* flags, from <linux/pcrtl.h> for prctl(2).
+const (
+ // PR_SET_PDEATHSIG sets the process' death signal.
+ PR_SET_PDEATHSIG = 1
+
+ // PR_GET_PDEATHSIG gets the process' death signal.
+ PR_GET_PDEATHSIG = 2
+
+ // PR_GET_DUMPABLE gets the process' dumpable flag.
+ PR_GET_DUMPABLE = 3
+
+ // PR_SET_DUMPABLE sets the process' dumpable flag.
+ PR_SET_DUMPABLE = 4
+
+ // PR_GET_KEEPCAPS gets the value of the keep capabilities flag.
+ PR_GET_KEEPCAPS = 7
+
+ // PR_SET_KEEPCAPS sets the value of the keep capabilities flag.
+ PR_SET_KEEPCAPS = 8
+
+ // PR_GET_TIMING gets the process' timing method.
+ PR_GET_TIMING = 13
+
+ // PR_SET_TIMING sets the process' timing method.
+ PR_SET_TIMING = 14
+
+ // PR_SET_NAME sets the process' name.
+ PR_SET_NAME = 15
+
+ // PR_GET_NAME gets the process' name.
+ PR_GET_NAME = 16
+
+ // PR_GET_SECCOMP gets a process' seccomp mode.
+ PR_GET_SECCOMP = 21
+
+ // PR_SET_SECCOMP sets a process' seccomp mode.
+ PR_SET_SECCOMP = 22
+
+ // PR_CAPBSET_READ gets the capability bounding set.
+ PR_CAPBSET_READ = 23
+
+ // PR_CAPBSET_DROP sets the capability bounding set.
+ PR_CAPBSET_DROP = 24
+
+ // PR_GET_TSC gets the value of the flag determining whether the
+ // timestamp counter can be read.
+ PR_GET_TSC = 25
+
+ // PR_SET_TSC sets the value of the flag determining whether the
+ // timestamp counter can be read.
+ PR_SET_TSC = 26
+
+ // PR_SET_TIMERSLACK sets the process' time slack.
+ PR_SET_TIMERSLACK = 29
+
+ // PR_GET_TIMERSLACK gets the process' time slack.
+ PR_GET_TIMERSLACK = 30
+
+ // PR_TASK_PERF_EVENTS_DISABLE disables all performance counters
+ // attached to the calling process.
+ PR_TASK_PERF_EVENTS_DISABLE = 31
+
+ // PR_TASK_PERF_EVENTS_ENABLE enables all performance counters attached
+ // to the calling process.
+ PR_TASK_PERF_EVENTS_ENABLE = 32
+
+ // PR_MCE_KILL sets the machine check memory corruption kill policy for
+ // the calling thread.
+ PR_MCE_KILL = 33
+
+ // PR_MCE_KILL_GET gets the machine check memory corruption kill policy
+ // for the calling thread.
+ PR_MCE_KILL_GET = 34
+
+ // PR_SET_MM modifies certain kernel memory map descriptor fields of
+ // the calling process. See prctl(2) for more information.
+ PR_SET_MM = 35
+
+ PR_SET_MM_START_CODE = 1
+ PR_SET_MM_END_CODE = 2
+ PR_SET_MM_START_DATA = 3
+ PR_SET_MM_END_DATA = 4
+ PR_SET_MM_START_STACK = 5
+ PR_SET_MM_START_BRK = 6
+ PR_SET_MM_BRK = 7
+ PR_SET_MM_ARG_START = 8
+ PR_SET_MM_ARG_END = 9
+ PR_SET_MM_ENV_START = 10
+ PR_SET_MM_ENV_END = 11
+ PR_SET_MM_AUXV = 12
+ // PR_SET_MM_EXE_FILE supersedes the /proc/pid/exe symbolic link with a
+ // new one pointing to a new executable file identified by the file
+ // descriptor provided in arg3 argument. See prctl(2) for more
+ // information.
+ PR_SET_MM_EXE_FILE = 13
+ PR_SET_MM_MAP = 14
+ PR_SET_MM_MAP_SIZE = 15
+
+ // PR_SET_CHILD_SUBREAPER sets the "child subreaper" attribute of the
+ // calling process.
+ PR_SET_CHILD_SUBREAPER = 36
+
+ // PR_GET_CHILD_SUBREAPER gets the "child subreaper" attribute of the
+ // calling process.
+ PR_GET_CHILD_SUBREAPER = 37
+
+ // PR_SET_NO_NEW_PRIVS sets the calling thread's no_new_privs bit.
+ PR_SET_NO_NEW_PRIVS = 38
+
+ // PR_GET_NO_NEW_PRIVS gets the calling thread's no_new_privs bit.
+ PR_GET_NO_NEW_PRIVS = 39
+
+ // PR_GET_TID_ADDRESS retrieves the clear_child_tid address.
+ PR_GET_TID_ADDRESS = 40
+
+ // PR_SET_THP_DISABLE sets the state of the "THP disable" flag for the
+ // calling thread.
+ PR_SET_THP_DISABLE = 41
+
+ // PR_GET_THP_DISABLE gets the state of the "THP disable" flag for the
+ // calling thread.
+ PR_GET_THP_DISABLE = 42
+
+ // PR_MPX_ENABLE_MANAGEMENT enables kernel management of Memory
+ // Protection eXtensions (MPX) bounds tables.
+ PR_MPX_ENABLE_MANAGEMENT = 43
+
+ // PR_MPX_DISABLE_MANAGEMENT disables kernel management of Memory
+ // Protection eXtensions (MPX) bounds tables.
+ PR_MPX_DISABLE_MANAGEMENT = 44
+)
+
+// From <asm/prctl.h>
+// Flags are used in syscall arch_prctl(2).
+const (
+ ARCH_SET_GS = 0x1001
+ ARCH_SET_FS = 0x1002
+ ARCH_GET_FS = 0x1003
+ ARCH_GET_GS = 0x1004
+ ARCH_SET_CPUID = 0x1012
+)
diff --git a/pkg/abi/linux/ptrace.go b/pkg/abi/linux/ptrace.go
new file mode 100644
index 000000000..23e605ab2
--- /dev/null
+++ b/pkg/abi/linux/ptrace.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// ptrace commands from include/uapi/linux/ptrace.h.
+const (
+ PTRACE_TRACEME = 0
+ PTRACE_PEEKTEXT = 1
+ PTRACE_PEEKDATA = 2
+ PTRACE_PEEKUSR = 3
+ PTRACE_POKETEXT = 4
+ PTRACE_POKEDATA = 5
+ PTRACE_POKEUSR = 6
+ PTRACE_CONT = 7
+ PTRACE_KILL = 8
+ PTRACE_SINGLESTEP = 9
+ PTRACE_ATTACH = 16
+ PTRACE_DETACH = 17
+ PTRACE_SYSCALL = 24
+ PTRACE_SETOPTIONS = 0x4200
+ PTRACE_GETEVENTMSG = 0x4201
+ PTRACE_GETSIGINFO = 0x4202
+ PTRACE_SETSIGINFO = 0x4203
+ PTRACE_GETREGSET = 0x4204
+ PTRACE_SETREGSET = 0x4205
+ PTRACE_SEIZE = 0x4206
+ PTRACE_INTERRUPT = 0x4207
+ PTRACE_LISTEN = 0x4208
+ PTRACE_PEEKSIGINFO = 0x4209
+ PTRACE_GETSIGMASK = 0x420a
+ PTRACE_SETSIGMASK = 0x420b
+ PTRACE_SECCOMP_GET_FILTER = 0x420c
+ PTRACE_SECCOMP_GET_METADATA = 0x420d
+)
+
+// ptrace commands from arch/x86/include/uapi/asm/ptrace-abi.h.
+const (
+ PTRACE_GETREGS = 12
+ PTRACE_SETREGS = 13
+ PTRACE_GETFPREGS = 14
+ PTRACE_SETFPREGS = 15
+ PTRACE_GETFPXREGS = 18
+ PTRACE_SETFPXREGS = 19
+ PTRACE_OLDSETOPTIONS = 21
+ PTRACE_GET_THREAD_AREA = 25
+ PTRACE_SET_THREAD_AREA = 26
+ PTRACE_ARCH_PRCTL = 30
+ PTRACE_SYSEMU = 31
+ PTRACE_SYSEMU_SINGLESTEP = 32
+ PTRACE_SINGLEBLOCK = 33
+)
+
+// ptrace event codes from include/uapi/linux/ptrace.h.
+const (
+ PTRACE_EVENT_FORK = 1
+ PTRACE_EVENT_VFORK = 2
+ PTRACE_EVENT_CLONE = 3
+ PTRACE_EVENT_EXEC = 4
+ PTRACE_EVENT_VFORK_DONE = 5
+ PTRACE_EVENT_EXIT = 6
+ PTRACE_EVENT_SECCOMP = 7
+ PTRACE_EVENT_STOP = 128
+)
+
+// PTRACE_SETOPTIONS options from include/uapi/linux/ptrace.h.
+const (
+ PTRACE_O_TRACESYSGOOD = 1
+ PTRACE_O_TRACEFORK = 1 << PTRACE_EVENT_FORK
+ PTRACE_O_TRACEVFORK = 1 << PTRACE_EVENT_VFORK
+ PTRACE_O_TRACECLONE = 1 << PTRACE_EVENT_CLONE
+ PTRACE_O_TRACEEXEC = 1 << PTRACE_EVENT_EXEC
+ PTRACE_O_TRACEVFORKDONE = 1 << PTRACE_EVENT_VFORK_DONE
+ PTRACE_O_TRACEEXIT = 1 << PTRACE_EVENT_EXIT
+ PTRACE_O_TRACESECCOMP = 1 << PTRACE_EVENT_SECCOMP
+ PTRACE_O_EXITKILL = 1 << 20
+ PTRACE_O_SUSPEND_SECCOMP = 1 << 21
+)
diff --git a/pkg/abi/linux/rusage.go b/pkg/abi/linux/rusage.go
new file mode 100644
index 000000000..d8302dc85
--- /dev/null
+++ b/pkg/abi/linux/rusage.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Flags that may be used with wait4(2) and getrusage(2).
+const (
+ // wait4(2) uses this to aggregate RUSAGE_SELF and RUSAGE_CHILDREN.
+ RUSAGE_BOTH = -0x2
+
+ // getrusage(2) flags.
+ RUSAGE_CHILDREN = -0x1
+ RUSAGE_SELF = 0x0
+ RUSAGE_THREAD = 0x1
+)
+
+// Rusage represents the Linux struct rusage.
+type Rusage struct {
+ UTime Timeval
+ STime Timeval
+ MaxRSS int64
+ IXRSS int64
+ IDRSS int64
+ ISRSS int64
+ MinFlt int64
+ MajFlt int64
+ NSwap int64
+ InBlock int64
+ OuBlock int64
+ MsgSnd int64
+ MsgRcv int64
+ NSignals int64
+ NVCSw int64
+ NIvCSw int64
+}
diff --git a/pkg/abi/linux/sched.go b/pkg/abi/linux/sched.go
new file mode 100644
index 000000000..193d9a242
--- /dev/null
+++ b/pkg/abi/linux/sched.go
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Scheduling policies, exposed by sched_getscheduler(2)/sched_setscheduler(2).
+const (
+ SCHED_NORMAL = 0
+ SCHED_FIFO = 1
+ SCHED_RR = 2
+ SCHED_BATCH = 3
+ SCHED_IDLE = 5
+ SCHED_DEADLINE = 6
+ SCHED_MICROQ = 16
+
+ // SCHED_RESET_ON_FORK is a flag that indicates that the process is
+ // reverted back to SCHED_NORMAL on fork.
+ SCHED_RESET_ON_FORK = 0x40000000
+)
diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go
new file mode 100644
index 000000000..4eeb5cd7a
--- /dev/null
+++ b/pkg/abi/linux/seccomp.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import "fmt"
+
+// Seccomp constants taken from <linux/seccomp.h>.
+const (
+ SECCOMP_MODE_NONE = 0
+ SECCOMP_MODE_FILTER = 2
+
+ SECCOMP_RET_ACTION_FULL = 0xffff0000
+ SECCOMP_RET_ACTION = 0x7fff0000
+ SECCOMP_RET_DATA = 0x0000ffff
+
+ SECCOMP_SET_MODE_FILTER = 1
+ SECCOMP_FILTER_FLAG_TSYNC = 1
+ SECCOMP_GET_ACTION_AVAIL = 2
+)
+
+type BPFAction uint32
+
+const (
+ SECCOMP_RET_KILL_PROCESS BPFAction = 0x80000000
+ SECCOMP_RET_KILL_THREAD = 0x00000000
+ SECCOMP_RET_TRAP = 0x00030000
+ SECCOMP_RET_ERRNO = 0x00050000
+ SECCOMP_RET_TRACE = 0x7ff00000
+ SECCOMP_RET_ALLOW = 0x7fff0000
+)
+
+func (a BPFAction) String() string {
+ switch a & SECCOMP_RET_ACTION_FULL {
+ case SECCOMP_RET_KILL_PROCESS:
+ return "kill process"
+ case SECCOMP_RET_KILL_THREAD:
+ return "kill thread"
+ case SECCOMP_RET_TRAP:
+ return fmt.Sprintf("trap (%d)", a.Data())
+ case SECCOMP_RET_ERRNO:
+ return fmt.Sprintf("errno (%d)", a.Data())
+ case SECCOMP_RET_TRACE:
+ return fmt.Sprintf("trace (%d)", a.Data())
+ case SECCOMP_RET_ALLOW:
+ return "allow"
+ }
+ return fmt.Sprintf("invalid action: %#x", a)
+}
+
+// Data returns the SECCOMP_RET_DATA portion of the action.
+func (a BPFAction) Data() uint16 {
+ return uint16(a & SECCOMP_RET_DATA)
+}
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
new file mode 100644
index 000000000..de422c519
--- /dev/null
+++ b/pkg/abi/linux/sem.go
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// semctl Command Definitions. Source: include/uapi/linux/sem.h
+const (
+ GETPID = 11
+ GETVAL = 12
+ GETALL = 13
+ GETNCNT = 14
+ GETZCNT = 15
+ SETVAL = 16
+ SETALL = 17
+)
+
+// ipcs ctl cmds. Source: include/uapi/linux/sem.h
+const (
+ SEM_STAT = 18
+ SEM_INFO = 19
+ SEM_STAT_ANY = 20
+)
+
+const SEM_UNDO = 0x1000
+
+// SemidDS is equivalent to struct semid64_ds.
+type SemidDS struct {
+ SemPerm IPCPerm
+ SemOTime TimeT
+ SemCTime TimeT
+ SemNSems uint64
+ unused3 uint64
+ unused4 uint64
+}
+
+// Sembuf is equivalent to struct sembuf.
+type Sembuf struct {
+ SemNum uint16
+ SemOp int16
+ SemFlg int16
+}
diff --git a/pkg/abi/linux/shm.go b/pkg/abi/linux/shm.go
new file mode 100644
index 000000000..e45aadb10
--- /dev/null
+++ b/pkg/abi/linux/shm.go
@@ -0,0 +1,86 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import "math"
+
+// shmat(2) flags. Source: include/uapi/linux/shm.h
+const (
+ SHM_RDONLY = 010000 // Read-only access.
+ SHM_RND = 020000 // Round attach address to SHMLBA boundary.
+ SHM_REMAP = 040000 // Take-over region on attach.
+ SHM_EXEC = 0100000 // Execution access.
+)
+
+// IPCPerm.Mode upper byte flags. Source: include/linux/shm.h
+const (
+ SHM_DEST = 01000 // Segment will be destroyed on last detach.
+ SHM_LOCKED = 02000 // Segment will not be swapped.
+ SHM_HUGETLB = 04000 // Segment will use huge TLB pages.
+ SHM_NORESERVE = 010000 // Don't check for reservations.
+)
+
+// Additional Linux-only flags for shmctl(2). Source: include/uapi/linux/shm.h
+const (
+ SHM_LOCK = 11
+ SHM_UNLOCK = 12
+ SHM_STAT = 13
+ SHM_INFO = 14
+)
+
+// SHM defaults as specified by linux. Source: include/uapi/linux/shm.h
+const (
+ SHMMIN = 1
+ SHMMNI = 4096
+ SHMMAX = math.MaxUint64 - 1<<24
+ SHMALL = math.MaxUint64 - 1<<24
+ SHMSEG = 4096
+)
+
+// ShmidDS is equivalent to struct shmid64_ds. Source:
+// include/uapi/asm-generic/shmbuf.h
+type ShmidDS struct {
+ ShmPerm IPCPerm
+ ShmSegsz uint64
+ ShmAtime TimeT
+ ShmDtime TimeT
+ ShmCtime TimeT
+ ShmCpid int32
+ ShmLpid int32
+ ShmNattach uint64
+
+ Unused4 uint64
+ Unused5 uint64
+}
+
+// ShmParams is equivalent to struct shminfo. Source: include/uapi/linux/shm.h
+type ShmParams struct {
+ ShmMax uint64
+ ShmMin uint64
+ ShmMni uint64
+ ShmSeg uint64
+ ShmAll uint64
+}
+
+// ShmInfo is equivalent to struct shm_info. Source: include/uapi/linux/shm.h
+type ShmInfo struct {
+ UsedIDs int32 // Number of currently existing segments.
+ _ [4]byte
+ ShmTot uint64 // Total number of shared memory pages.
+ ShmRss uint64 // Number of resident shared memory pages.
+ ShmSwp uint64 // Number of swapped shared memory pages.
+ SwapAttempts uint64 // Unused since Linux 2.4.
+ SwapSuccesses uint64 // Unused since Linux 2.4.
+}
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
new file mode 100644
index 000000000..9cbd77dda
--- /dev/null
+++ b/pkg/abi/linux/signal.go
@@ -0,0 +1,232 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/bits"
+)
+
+const (
+ // SignalMaximum is the highest valid signal number.
+ SignalMaximum = 64
+
+ // FirstStdSignal is the lowest standard signal number.
+ FirstStdSignal = 1
+
+ // LastStdSignal is the highest standard signal number.
+ LastStdSignal = 31
+
+ // FirstRTSignal is the lowest real-time signal number.
+ //
+ // 32 (SIGCANCEL) and 33 (SIGSETXID) are used internally by glibc.
+ FirstRTSignal = 32
+
+ // LastRTSignal is the highest real-time signal number.
+ LastRTSignal = 64
+
+ // NumStdSignals is the number of standard signals.
+ NumStdSignals = LastStdSignal - FirstStdSignal + 1
+
+ // NumRTSignals is the number of realtime signals.
+ NumRTSignals = LastRTSignal - FirstRTSignal + 1
+)
+
+// Signal is a signal number.
+type Signal int
+
+// IsValid returns true if s is a valid standard or realtime signal. (0 is not
+// considered valid; interfaces special-casing signal number 0 should check for
+// 0 first before asserting validity.)
+func (s Signal) IsValid() bool {
+ return s > 0 && s <= SignalMaximum
+}
+
+// IsStandard returns true if s is a standard signal.
+//
+// Preconditions: s.IsValid().
+func (s Signal) IsStandard() bool {
+ return s <= LastStdSignal
+}
+
+// IsRealtime returns true if s is a realtime signal.
+//
+// Preconditions: s.IsValid().
+func (s Signal) IsRealtime() bool {
+ return s >= FirstRTSignal
+}
+
+// Index returns the index for signal s into arrays of both standard and
+// realtime signals (e.g. signal masks).
+//
+// Preconditions: s.IsValid().
+func (s Signal) Index() int {
+ return int(s - 1)
+}
+
+// Signals.
+const (
+ SIGABRT = Signal(6)
+ SIGALRM = Signal(14)
+ SIGBUS = Signal(7)
+ SIGCHLD = Signal(17)
+ SIGCLD = Signal(17)
+ SIGCONT = Signal(18)
+ SIGFPE = Signal(8)
+ SIGHUP = Signal(1)
+ SIGILL = Signal(4)
+ SIGINT = Signal(2)
+ SIGIO = Signal(29)
+ SIGIOT = Signal(6)
+ SIGKILL = Signal(9)
+ SIGPIPE = Signal(13)
+ SIGPOLL = Signal(29)
+ SIGPROF = Signal(27)
+ SIGPWR = Signal(30)
+ SIGQUIT = Signal(3)
+ SIGSEGV = Signal(11)
+ SIGSTKFLT = Signal(16)
+ SIGSTOP = Signal(19)
+ SIGSYS = Signal(31)
+ SIGTERM = Signal(15)
+ SIGTRAP = Signal(5)
+ SIGTSTP = Signal(20)
+ SIGTTIN = Signal(21)
+ SIGTTOU = Signal(22)
+ SIGUNUSED = Signal(31)
+ SIGURG = Signal(23)
+ SIGUSR1 = Signal(10)
+ SIGUSR2 = Signal(12)
+ SIGVTALRM = Signal(26)
+ SIGWINCH = Signal(28)
+ SIGXCPU = Signal(24)
+ SIGXFSZ = Signal(25)
+)
+
+// SignalSet is a signal mask with a bit corresponding to each signal.
+type SignalSet uint64
+
+// SignalSetSize is the size in bytes of a SignalSet.
+const SignalSetSize = 8
+
+// MakeSignalSet returns SignalSet with the bit corresponding to each of the
+// given signals set.
+func MakeSignalSet(sigs ...Signal) SignalSet {
+ indices := make([]int, len(sigs))
+ for i, sig := range sigs {
+ indices[i] = sig.Index()
+ }
+ return SignalSet(bits.Mask64(indices...))
+}
+
+// SignalSetOf returns a SignalSet with a single signal set.
+func SignalSetOf(sig Signal) SignalSet {
+ return SignalSet(bits.MaskOf64(sig.Index()))
+}
+
+// ForEachSignal invokes f for each signal set in the given mask.
+func ForEachSignal(mask SignalSet, f func(sig Signal)) {
+ bits.ForEachSetBit64(uint64(mask), func(i int) {
+ f(Signal(i + 1))
+ })
+}
+
+// 'how' values for rt_sigprocmask(2).
+const (
+ // SIG_BLOCK blocks the signals in the set.
+ SIG_BLOCK = 0
+
+ // SIG_UNBLOCK blocks the signals in the set.
+ SIG_UNBLOCK = 1
+
+ // SIG_SETMASK sets the signal mask to set.
+ SIG_SETMASK = 2
+)
+
+// Signal actions for rt_sigaction(2), from uapi/asm-generic/signal-defs.h.
+const (
+ // SIG_DFL performs the default action.
+ SIG_DFL = 0
+
+ // SIG_IGN ignores the signal.
+ SIG_IGN = 1
+)
+
+// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h
+const (
+ SA_NOCLDSTOP = 0x00000001
+ SA_NOCLDWAIT = 0x00000002
+ SA_SIGINFO = 0x00000004
+ SA_RESTORER = 0x04000000
+ SA_ONSTACK = 0x08000000
+ SA_RESTART = 0x10000000
+ SA_NODEFER = 0x40000000
+ SA_RESETHAND = 0x80000000
+ SA_NOMASK = SA_NODEFER
+ SA_ONESHOT = SA_RESETHAND
+)
+
+// Signal info types.
+const (
+ SI_MASK = 0xffff0000
+ SI_KILL = 0 << 16
+ SI_TIMER = 1 << 16
+ SI_POLL = 2 << 16
+ SI_FAULT = 3 << 16
+ SI_CHLD = 4 << 16
+ SI_RT = 5 << 16
+ SI_MESGQ = 6 << 16
+ SI_SYS = 7 << 16
+)
+
+// SIGPOLL si_codes.
+const (
+ // POLL_IN indicates that data input available.
+ POLL_IN = SI_POLL | 1
+
+ // POLL_OUT indicates that output buffers available.
+ POLL_OUT = SI_POLL | 2
+
+ // POLL_MSG indicates that an input message available.
+ POLL_MSG = SI_POLL | 3
+
+ // POLL_ERR indicates that there was an i/o error.
+ POLL_ERR = SI_POLL | 4
+
+ // POLL_PRI indicates that a high priority input available.
+ POLL_PRI = SI_POLL | 5
+
+ // POLL_HUP indicates that a device disconnected.
+ POLL_HUP = SI_POLL | 6
+)
+
+// Sigevent represents struct sigevent.
+type Sigevent struct {
+ Value uint64 // union sigval {int, void*}
+ Signo int32
+ Notify int32
+
+ // struct sigevent here contains 48-byte union _sigev_un. However, only
+ // member _tid is significant to the kernel.
+ Tid int32
+ UnRemainder [44]byte
+}
+
+// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify.
+const (
+ SIGEV_SIGNAL = 0
+ SIGEV_NONE = 1
+ SIGEV_THREAD = 2
+ SIGEV_THREAD_ID = 4
+)
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
new file mode 100644
index 000000000..417840731
--- /dev/null
+++ b/pkg/abi/linux/socket.go
@@ -0,0 +1,385 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import "gvisor.googlesource.com/gvisor/pkg/binary"
+
+// Address families, from linux/socket.h.
+const (
+ AF_UNSPEC = 0
+ AF_UNIX = 1
+ AF_INET = 2
+ AF_AX25 = 3
+ AF_IPX = 4
+ AF_APPLETALK = 5
+ AF_NETROM = 6
+ AF_BRIDGE = 7
+ AF_ATMPVC = 8
+ AF_X25 = 9
+ AF_INET6 = 10
+ AF_ROSE = 11
+ AF_DECnet = 12
+ AF_NETBEUI = 13
+ AF_SECURITY = 14
+ AF_KEY = 15
+ AF_NETLINK = 16
+ AF_PACKET = 17
+ AF_ASH = 18
+ AF_ECONET = 19
+ AF_ATMSVC = 20
+ AF_RDS = 21
+ AF_SNA = 22
+ AF_IRDA = 23
+ AF_PPPOX = 24
+ AF_WANPIPE = 25
+ AF_LLC = 26
+ AF_IB = 27
+ AF_MPLS = 28
+ AF_CAN = 29
+ AF_TIPC = 30
+ AF_BLUETOOTH = 31
+ AF_IUCV = 32
+ AF_RXRPC = 33
+ AF_ISDN = 34
+ AF_PHONET = 35
+ AF_IEEE802154 = 36
+ AF_CAIF = 37
+ AF_ALG = 38
+ AF_NFC = 39
+ AF_VSOCK = 40
+)
+
+// sendmsg(2)/recvmsg(2) flags, from linux/socket.h.
+const (
+ MSG_OOB = 0x1
+ MSG_PEEK = 0x2
+ MSG_DONTROUTE = 0x4
+ MSG_TRYHARD = 0x4
+ MSG_CTRUNC = 0x8
+ MSG_PROBE = 0x10
+ MSG_TRUNC = 0x20
+ MSG_DONTWAIT = 0x40
+ MSG_EOR = 0x80
+ MSG_WAITALL = 0x100
+ MSG_FIN = 0x200
+ MSG_EOF = MSG_FIN
+ MSG_SYN = 0x400
+ MSG_CONFIRM = 0x800
+ MSG_RST = 0x1000
+ MSG_ERRQUEUE = 0x2000
+ MSG_NOSIGNAL = 0x4000
+ MSG_MORE = 0x8000
+ MSG_WAITFORONE = 0x10000
+ MSG_SENDPAGE_NOTLAST = 0x20000
+ MSG_REINJECT = 0x8000000
+ MSG_ZEROCOPY = 0x4000000
+ MSG_FASTOPEN = 0x20000000
+ MSG_CMSG_CLOEXEC = 0x40000000
+)
+
+// Set/get socket option levels, from socket.h.
+const (
+ SOL_IP = 0
+ SOL_SOCKET = 1
+ SOL_TCP = 6
+ SOL_UDP = 17
+ SOL_IPV6 = 41
+ SOL_ICMPV6 = 58
+ SOL_RAW = 255
+ SOL_PACKET = 263
+ SOL_NETLINK = 270
+)
+
+// Socket types, from linux/net.h.
+const (
+ SOCK_STREAM = 1
+ SOCK_DGRAM = 2
+ SOCK_RAW = 3
+ SOCK_RDM = 4
+ SOCK_SEQPACKET = 5
+ SOCK_DCCP = 6
+ SOCK_PACKET = 10
+)
+
+// SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are
+// flags. From linux/net.h.
+const SOCK_TYPE_MASK = 0xf
+
+// socket(2)/socketpair(2)/accept4(2) flags, from linux/net.h.
+const (
+ SOCK_CLOEXEC = O_CLOEXEC
+ SOCK_NONBLOCK = O_NONBLOCK
+)
+
+// shutdown(2) how commands, from <linux/net.h>.
+const (
+ SHUT_RD = 0
+ SHUT_WR = 1
+ SHUT_RDWR = 2
+)
+
+// Socket options from socket.h.
+const (
+ SO_DEBUG = 1
+ SO_REUSEADDR = 2
+ SO_TYPE = 3
+ SO_ERROR = 4
+ SO_DONTROUTE = 5
+ SO_BROADCAST = 6
+ SO_SNDBUF = 7
+ SO_RCVBUF = 8
+ SO_KEEPALIVE = 9
+ SO_OOBINLINE = 10
+ SO_NO_CHECK = 11
+ SO_PRIORITY = 12
+ SO_LINGER = 13
+ SO_BSDCOMPAT = 14
+ SO_REUSEPORT = 15
+ SO_PASSCRED = 16
+ SO_PEERCRED = 17
+ SO_RCVLOWAT = 18
+ SO_SNDLOWAT = 19
+ SO_RCVTIMEO = 20
+ SO_SNDTIMEO = 21
+ SO_BINDTODEVICE = 25
+ SO_ATTACH_FILTER = 26
+ SO_DETACH_FILTER = 27
+ SO_GET_FILTER = SO_ATTACH_FILTER
+ SO_PEERNAME = 28
+ SO_TIMESTAMP = 29
+ SO_ACCEPTCONN = 30
+ SO_PEERSEC = 31
+ SO_SNDBUFFORCE = 32
+ SO_RCVBUFFORCE = 33
+ SO_PASSSEC = 34
+ SO_TIMESTAMPNS = 35
+ SO_MARK = 36
+ SO_TIMESTAMPING = 37
+ SO_PROTOCOL = 38
+ SO_DOMAIN = 39
+ SO_RXQ_OVFL = 40
+ SO_WIFI_STATUS = 41
+ SO_PEEK_OFF = 42
+ SO_NOFCS = 43
+ SO_LOCK_FILTER = 44
+ SO_SELECT_ERR_QUEUE = 45
+ SO_BUSY_POLL = 46
+ SO_MAX_PACING_RATE = 47
+ SO_BPF_EXTENSIONS = 48
+ SO_INCOMING_CPU = 49
+ SO_ATTACH_BPF = 50
+ SO_ATTACH_REUSEPORT_CBPF = 51
+ SO_ATTACH_REUSEPORT_EBPF = 52
+ SO_CNX_ADVICE = 53
+ SO_MEMINFO = 55
+ SO_INCOMING_NAPI_ID = 56
+ SO_COOKIE = 57
+ SO_PEERGROUPS = 59
+ SO_ZEROCOPY = 60
+ SO_TXTIME = 61
+)
+
+// enum socket_state, from uapi/linux/net.h.
+const (
+ SS_FREE = 0 // Not allocated.
+ SS_UNCONNECTED = 1 // Unconnected to any socket.
+ SS_CONNECTING = 2 // In process of connecting.
+ SS_CONNECTED = 3 // Connected to socket.
+ SS_DISCONNECTING = 4 // In process of disconnecting.
+)
+
+// SockAddrMax is the maximum size of a struct sockaddr, from
+// uapi/linux/socket.h.
+const SockAddrMax = 128
+
+// InetAddr is struct in_addr, from uapi/linux/in.h.
+type InetAddr [4]byte
+
+// SockAddrInet is struct sockaddr_in, from uapi/linux/in.h.
+type SockAddrInet struct {
+ Family uint16
+ Port uint16
+ Addr InetAddr
+ Zero [8]uint8 // pad to sizeof(struct sockaddr).
+}
+
+// InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h.
+type InetMulticastRequest struct {
+ MulticastAddr InetAddr
+ InterfaceAddr InetAddr
+}
+
+// InetMulticastRequestWithNIC is struct ip_mreqn, from uapi/linux/in.h.
+type InetMulticastRequestWithNIC struct {
+ InetMulticastRequest
+ InterfaceIndex int32
+}
+
+// SockAddrInet6 is struct sockaddr_in6, from uapi/linux/in6.h.
+type SockAddrInet6 struct {
+ Family uint16
+ Port uint16
+ Flowinfo uint32
+ Addr [16]byte
+ Scope_id uint32
+}
+
+// UnixPathMax is the maximum length of the path in an AF_UNIX socket.
+//
+// From uapi/linux/un.h.
+const UnixPathMax = 108
+
+// SockAddrUnix is struct sockaddr_un, from uapi/linux/un.h.
+type SockAddrUnix struct {
+ Family uint16
+ Path [UnixPathMax]int8
+}
+
+// Linger is struct linger, from include/linux/socket.h.
+type Linger struct {
+ OnOff int32
+ Linger int32
+}
+
+// SizeOfLinger is the binary size of a Linger struct.
+const SizeOfLinger = 8
+
+// TCPInfo is a collection of TCP statistics.
+//
+// From uapi/linux/tcp.h.
+type TCPInfo struct {
+ State uint8
+ CaState uint8
+ Retransmits uint8
+ Probes uint8
+ Backoff uint8
+ Options uint8
+ // WindowScale is the combination of snd_wscale (first 4 bits) and rcv_wscale (second 4 bits)
+ WindowScale uint8
+ // DeliveryRateAppLimited is a boolean and only the first bit is meaningful.
+ DeliveryRateAppLimited uint8
+
+ RTO uint32
+ ATO uint32
+ SndMss uint32
+ RcvMss uint32
+
+ Unacked uint32
+ Sacked uint32
+ Lost uint32
+ Retrans uint32
+ Fackets uint32
+
+ // Times.
+ LastDataSent uint32
+ LastAckSent uint32
+ LastDataRecv uint32
+ LastAckRecv uint32
+
+ // Metrics.
+ PMTU uint32
+ RcvSsthresh uint32
+ RTT uint32
+ RTTVar uint32
+ SndSsthresh uint32
+ SndCwnd uint32
+ Advmss uint32
+ Reordering uint32
+
+ RcvRTT uint32
+ RcvSpace uint32
+
+ TotalRetrans uint32
+
+ PacingRate uint64
+ MaxPacingRate uint64
+ // BytesAcked is RFC4898 tcpEStatsAppHCThruOctetsAcked.
+ BytesAcked uint64
+ // BytesReceived is RFC4898 tcpEStatsAppHCThruOctetsReceived.
+ BytesReceived uint64
+ // SegsOut is RFC4898 tcpEStatsPerfSegsOut.
+ SegsOut uint32
+ // SegsIn is RFC4898 tcpEStatsPerfSegsIn.
+ SegsIn uint32
+
+ NotSentBytes uint32
+ MinRTT uint32
+ // DataSegsIn is RFC4898 tcpEStatsDataSegsIn.
+ DataSegsIn uint32
+ // DataSegsOut is RFC4898 tcpEStatsDataSegsOut.
+ DataSegsOut uint32
+
+ DeliveryRate uint64
+
+ // BusyTime is the time in microseconds busy sending data.
+ BusyTime uint64
+ // RwndLimited is the time in microseconds limited by receive window.
+ RwndLimited uint64
+ // SndBufLimited is the time in microseconds limited by send buffer.
+ SndBufLimited uint64
+}
+
+// SizeOfTCPInfo is the binary size of a TCPInfo struct.
+const SizeOfTCPInfo = 104
+
+// Control message types, from linux/socket.h.
+const (
+ SCM_CREDENTIALS = 0x2
+ SCM_RIGHTS = 0x1
+)
+
+// A ControlMessageHeader is the header for a socket control message.
+//
+// ControlMessageHeader represents struct cmsghdr from linux/socket.h.
+type ControlMessageHeader struct {
+ Length uint64
+ Level int32
+ Type int32
+}
+
+// SizeOfControlMessageHeader is the binary size of a ControlMessageHeader
+// struct.
+var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
+
+// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
+//
+// ControlMessageCredentials represents struct ucred from linux/socket.h.
+type ControlMessageCredentials struct {
+ PID int32
+ UID uint32
+ GID uint32
+}
+
+// SizeOfControlMessageCredentials is the binary size of a
+// ControlMessageCredentials struct.
+var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{}))
+
+// A ControlMessageRights is an SCM_RIGHTS socket control message.
+type ControlMessageRights []int32
+
+// SizeOfControlMessageRight is the size of a single element in
+// ControlMessageRights.
+const SizeOfControlMessageRight = 4
+
+// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
+// From net/scm.h.
+const SCM_MAX_FD = 253
+
+// SO_ACCEPTCON is defined as __SO_ACCEPTCON in
+// include/uapi/linux/net.h, which represents a listening socket
+// state. Note that this is distinct from SO_ACCEPTCONN, which is a
+// socket option for querying whether a socket is in a listening
+// state.
+const SO_ACCEPTCON = 1 << 16
diff --git a/pkg/abi/linux/splice.go b/pkg/abi/linux/splice.go
new file mode 100644
index 000000000..650eb87e8
--- /dev/null
+++ b/pkg/abi/linux/splice.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Constants for splice(2), sendfile(2) and tee(2).
+const (
+ SPLICE_F_MOVE = 1 << iota
+ SPLICE_F_NONBLOCK
+ SPLICE_F_MORE
+ SPLICE_F_GIFT
+)
diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go
new file mode 100644
index 000000000..174d470e2
--- /dev/null
+++ b/pkg/abi/linux/tcp.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Socket options from uapi/linux/tcp.h.
+const (
+ TCP_NODELAY = 1
+ TCP_MAXSEG = 2
+ TCP_CORK = 3
+ TCP_KEEPIDLE = 4
+ TCP_KEEPINTVL = 5
+ TCP_KEEPCNT = 6
+ TCP_SYNCNT = 7
+ TCP_LINGER2 = 8
+ TCP_DEFER_ACCEPT = 9
+ TCP_WINDOW_CLAMP = 10
+ TCP_INFO = 11
+ TCP_QUICKACK = 12
+ TCP_CONGESTION = 13
+ TCP_MD5SIG = 14
+ TCP_THIN_LINEAR_TIMEOUTS = 16
+ TCP_THIN_DUPACK = 17
+ TCP_USER_TIMEOUT = 18
+ TCP_REPAIR = 19
+ TCP_REPAIR_QUEUE = 20
+ TCP_QUEUE_SEQ = 21
+ TCP_REPAIR_OPTIONS = 22
+ TCP_FASTOPEN = 23
+ TCP_TIMESTAMP = 24
+ TCP_NOTSENT_LOWAT = 25
+ TCP_CC_INFO = 26
+ TCP_SAVE_SYN = 27
+ TCP_SAVED_SYN = 28
+ TCP_REPAIR_WINDOW = 29
+ TCP_FASTOPEN_CONNECT = 30
+ TCP_ULP = 31
+ TCP_MD5SIG_EXT = 32
+ TCP_FASTOPEN_KEY = 33
+ TCP_FASTOPEN_NO_COOKIE = 34
+ TCP_ZEROCOPY_RECEIVE = 35
+ TCP_INQ = 36
+)
+
+// Socket constants from include/net/tcp.h.
+const (
+ MAX_TCP_KEEPIDLE = 32767
+ MAX_TCP_KEEPINTVL = 32767
+)
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
new file mode 100644
index 000000000..fa9ee27e1
--- /dev/null
+++ b/pkg/abi/linux/time.go
@@ -0,0 +1,228 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "math"
+ "time"
+)
+
+const (
+ // ClockTick is the length of time represented by a single clock tick, as
+ // used by times(2) and /proc/[pid]/stat.
+ ClockTick = time.Second / CLOCKS_PER_SEC
+
+ // CLOCKS_PER_SEC is the number of ClockTicks per second.
+ //
+ // Linux defines this to be 100 on most architectures, irrespective of
+ // CONFIG_HZ. Userspace obtains the value through sysconf(_SC_CLK_TCK),
+ // which uses the AT_CLKTCK entry in the auxiliary vector if one is
+ // provided, and assumes 100 otherwise (glibc:
+ // sysdeps/posix/sysconf.c:__sysconf() =>
+ // sysdeps/unix/sysv/linux/getclktck.c, elf/dl-support.c:_dl_aux_init()).
+ //
+ // Not to be confused with POSIX CLOCKS_PER_SEC, as used by clock(3); "XSI
+ // requires that [POSIX] CLOCKS_PER_SEC equals 1000000 independent of the
+ // actual resolution" - clock(3).
+ CLOCKS_PER_SEC = 100
+)
+
+// CPU clock types for use with clock_gettime(2) et al.
+//
+// The 29 most significant bits of a 32 bit clock ID are either a PID or a FD.
+//
+// Bits 1 and 0 give the type: PROF=0, VIRT=1, SCHED=2, or FD=3.
+//
+// Bit 2 indicates whether a cpu clock refers to a thread or a process.
+const (
+ CPUCLOCK_PROF = 0
+ CPUCLOCK_VIRT = 1
+ CPUCLOCK_SCHED = 2
+ CPUCLOCK_MAX = 3
+ CLOCKFD = CPUCLOCK_MAX
+
+ CPUCLOCK_CLOCK_MASK = 3
+ CPUCLOCK_PERTHREAD_MASK = 4
+)
+
+// Clock identifiers for use with clock_gettime(2), clock_getres(2),
+// clock_nanosleep(2).
+const (
+ CLOCK_REALTIME = 0
+ CLOCK_MONOTONIC = 1
+ CLOCK_PROCESS_CPUTIME_ID = 2
+ CLOCK_THREAD_CPUTIME_ID = 3
+ CLOCK_MONOTONIC_RAW = 4
+ CLOCK_REALTIME_COARSE = 5
+ CLOCK_MONOTONIC_COARSE = 6
+ CLOCK_BOOTTIME = 7
+ CLOCK_REALTIME_ALARM = 8
+ CLOCK_BOOTTIME_ALARM = 9
+)
+
+// Flags for clock_nanosleep(2).
+const (
+ TIMER_ABSTIME = 1
+)
+
+// Flags for timerfd syscalls (timerfd_create(2), timerfd_settime(2)).
+const (
+ // TFD_CLOEXEC is a timerfd_create flag.
+ TFD_CLOEXEC = O_CLOEXEC
+
+ // TFD_NONBLOCK is a timerfd_create flag.
+ TFD_NONBLOCK = O_NONBLOCK
+
+ // TFD_TIMER_ABSTIME is a timerfd_settime flag.
+ TFD_TIMER_ABSTIME = 1
+)
+
+// The safe number of seconds you can represent by int64.
+const maxSecInDuration = math.MaxInt64 / int64(time.Second)
+
+// TimeT represents time_t in <time.h>. It represents time in seconds.
+type TimeT int64
+
+// NsecToTimeT translates nanoseconds to TimeT (seconds).
+func NsecToTimeT(nsec int64) TimeT {
+ return TimeT(nsec / 1e9)
+}
+
+// Timespec represents struct timespec in <time.h>.
+type Timespec struct {
+ Sec int64
+ Nsec int64
+}
+
+// Unix returns the second and nanosecond.
+func (ts Timespec) Unix() (sec int64, nsec int64) {
+ return int64(ts.Sec), int64(ts.Nsec)
+}
+
+// ToTime returns the Go time.Time representation.
+func (ts Timespec) ToTime() time.Time {
+ return time.Unix(ts.Sec, ts.Nsec)
+}
+
+// ToNsec returns the nanosecond representation.
+func (ts Timespec) ToNsec() int64 {
+ return int64(ts.Sec)*1e9 + int64(ts.Nsec)
+}
+
+// ToNsecCapped returns the safe nanosecond representation.
+func (ts Timespec) ToNsecCapped() int64 {
+ if ts.Sec > maxSecInDuration {
+ return math.MaxInt64
+ }
+ return ts.ToNsec()
+}
+
+// ToDuration returns the safe nanosecond representation as time.Duration.
+func (ts Timespec) ToDuration() time.Duration {
+ return time.Duration(ts.ToNsecCapped())
+}
+
+// Valid returns whether the timespec contains valid values.
+func (ts Timespec) Valid() bool {
+ return !(ts.Sec < 0 || ts.Nsec < 0 || ts.Nsec >= int64(time.Second))
+}
+
+// NsecToTimespec translates nanoseconds to Timespec.
+func NsecToTimespec(nsec int64) (ts Timespec) {
+ ts.Sec = nsec / 1e9
+ ts.Nsec = nsec % 1e9
+ return
+}
+
+// DurationToTimespec translates time.Duration to Timespec.
+func DurationToTimespec(dur time.Duration) Timespec {
+ return NsecToTimespec(dur.Nanoseconds())
+}
+
+// SizeOfTimeval is the size of a Timeval struct in bytes.
+const SizeOfTimeval = 16
+
+// Timeval represents struct timeval in <time.h>.
+type Timeval struct {
+ Sec int64
+ Usec int64
+}
+
+// ToNsecCapped returns the safe nanosecond representation.
+func (tv Timeval) ToNsecCapped() int64 {
+ if tv.Sec > maxSecInDuration {
+ return math.MaxInt64
+ }
+ return int64(tv.Sec)*1e9 + int64(tv.Usec)*1e3
+}
+
+// ToDuration returns the safe nanosecond representation as a time.Duration.
+func (tv Timeval) ToDuration() time.Duration {
+ return time.Duration(tv.ToNsecCapped())
+}
+
+// ToTime returns the Go time.Time representation.
+func (tv Timeval) ToTime() time.Time {
+ return time.Unix(tv.Sec, tv.Usec*1e3)
+}
+
+// NsecToTimeval translates nanosecond to Timeval.
+func NsecToTimeval(nsec int64) (tv Timeval) {
+ nsec += 999 // round up to microsecond
+ tv.Sec = nsec / 1e9
+ tv.Usec = nsec % 1e9 / 1e3
+ return
+}
+
+// DurationToTimeval translates time.Duration to Timeval.
+func DurationToTimeval(dur time.Duration) Timeval {
+ return NsecToTimeval(dur.Nanoseconds())
+}
+
+// Itimerspec represents struct itimerspec in <time.h>.
+type Itimerspec struct {
+ Interval Timespec
+ Value Timespec
+}
+
+// ItimerVal mimics the following struct in <sys/time.h>
+// struct itimerval {
+// struct timeval it_interval; /* next value */
+// struct timeval it_value; /* current value */
+// };
+type ItimerVal struct {
+ Interval Timeval
+ Value Timeval
+}
+
+// ClockT represents type clock_t.
+type ClockT int64
+
+// ClockTFromDuration converts time.Duration to clock_t.
+func ClockTFromDuration(d time.Duration) ClockT {
+ return ClockT(d / ClockTick)
+}
+
+// Tms represents struct tms, used by times(2).
+type Tms struct {
+ UTime ClockT
+ STime ClockT
+ CUTime ClockT
+ CSTime ClockT
+}
+
+// TimerID represents type timer_t, which identifies a POSIX per-process
+// interval timer.
+type TimerID int32
diff --git a/pkg/abi/linux/timer.go b/pkg/abi/linux/timer.go
new file mode 100644
index 000000000..e32d09e10
--- /dev/null
+++ b/pkg/abi/linux/timer.go
@@ -0,0 +1,23 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// itimer types for getitimer(2) and setitimer(2), from
+// include/uapi/linux/time.h.
+const (
+ ITIMER_REAL = 0
+ ITIMER_VIRTUAL = 1
+ ITIMER_PROF = 2
+)
diff --git a/pkg/abi/linux/tty.go b/pkg/abi/linux/tty.go
new file mode 100644
index 000000000..8ac02aee8
--- /dev/null
+++ b/pkg/abi/linux/tty.go
@@ -0,0 +1,344 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ // NumControlCharacters is the number of control characters in Termios.
+ NumControlCharacters = 19
+ // disabledChar is used to indicate that a control character is
+ // disabled.
+ disabledChar = 0
+)
+
+// Winsize is struct winsize, defined in uapi/asm-generic/termios.h.
+type Winsize struct {
+ Row uint16
+ Col uint16
+ Xpixel uint16
+ Ypixel uint16
+}
+
+// Termios is struct termios, defined in uapi/asm-generic/termbits.h.
+type Termios struct {
+ InputFlags uint32
+ OutputFlags uint32
+ ControlFlags uint32
+ LocalFlags uint32
+ LineDiscipline uint8
+ ControlCharacters [NumControlCharacters]uint8
+}
+
+// KernelTermios is struct ktermios/struct termios2, defined in
+// uapi/asm-generic/termbits.h.
+//
+// +stateify savable
+type KernelTermios struct {
+ InputFlags uint32
+ OutputFlags uint32
+ ControlFlags uint32
+ LocalFlags uint32
+ LineDiscipline uint8
+ ControlCharacters [NumControlCharacters]uint8
+ InputSpeed uint32
+ OutputSpeed uint32
+}
+
+// IEnabled returns whether flag is enabled in termios input flags.
+func (t *KernelTermios) IEnabled(flag uint32) bool {
+ return t.InputFlags&flag == flag
+}
+
+// OEnabled returns whether flag is enabled in termios output flags.
+func (t *KernelTermios) OEnabled(flag uint32) bool {
+ return t.OutputFlags&flag == flag
+}
+
+// CEnabled returns whether flag is enabled in termios control flags.
+func (t *KernelTermios) CEnabled(flag uint32) bool {
+ return t.ControlFlags&flag == flag
+}
+
+// LEnabled returns whether flag is enabled in termios local flags.
+func (t *KernelTermios) LEnabled(flag uint32) bool {
+ return t.LocalFlags&flag == flag
+}
+
+// ToTermios copies fields that are shared with Termios into a new Termios
+// struct.
+func (t *KernelTermios) ToTermios() Termios {
+ return Termios{
+ InputFlags: t.InputFlags,
+ OutputFlags: t.OutputFlags,
+ ControlFlags: t.ControlFlags,
+ LocalFlags: t.LocalFlags,
+ LineDiscipline: t.LineDiscipline,
+ ControlCharacters: t.ControlCharacters,
+ }
+}
+
+// FromTermios copies fields that are shared with Termios into this
+// KernelTermios struct.
+func (t *KernelTermios) FromTermios(term Termios) {
+ t.InputFlags = term.InputFlags
+ t.OutputFlags = term.OutputFlags
+ t.ControlFlags = term.ControlFlags
+ t.LocalFlags = term.LocalFlags
+ t.LineDiscipline = term.LineDiscipline
+ t.ControlCharacters = term.ControlCharacters
+}
+
+// IsTerminating returns whether c is a line terminating character.
+func (t *KernelTermios) IsTerminating(cBytes []byte) bool {
+ // All terminating characters are 1 byte.
+ if len(cBytes) != 1 {
+ return false
+ }
+ c := cBytes[0]
+
+ // Is this the user-set EOF character?
+ if t.IsEOF(c) {
+ return true
+ }
+
+ switch c {
+ case disabledChar:
+ return false
+ case '\n', t.ControlCharacters[VEOL]:
+ return true
+ case t.ControlCharacters[VEOL2]:
+ return t.LEnabled(IEXTEN)
+ }
+ return false
+}
+
+// IsEOF returns whether c is the EOF character.
+func (t *KernelTermios) IsEOF(c byte) bool {
+ return c == t.ControlCharacters[VEOF] && t.ControlCharacters[VEOF] != disabledChar
+}
+
+// Input flags.
+const (
+ IGNBRK = 0000001
+ BRKINT = 0000002
+ IGNPAR = 0000004
+ PARMRK = 0000010
+ INPCK = 0000020
+ ISTRIP = 0000040
+ INLCR = 0000100
+ IGNCR = 0000200
+ ICRNL = 0000400
+ IUCLC = 0001000
+ IXON = 0002000
+ IXANY = 0004000
+ IXOFF = 0010000
+ IMAXBEL = 0020000
+ IUTF8 = 0040000
+)
+
+// Output flags.
+const (
+ OPOST = 0000001
+ OLCUC = 0000002
+ ONLCR = 0000004
+ OCRNL = 0000010
+ ONOCR = 0000020
+ ONLRET = 0000040
+ OFILL = 0000100
+ OFDEL = 0000200
+ NLDLY = 0000400
+ NL0 = 0000000
+ NL1 = 0000400
+ CRDLY = 0003000
+ CR0 = 0000000
+ CR1 = 0001000
+ CR2 = 0002000
+ CR3 = 0003000
+ TABDLY = 0014000
+ TAB0 = 0000000
+ TAB1 = 0004000
+ TAB2 = 0010000
+ TAB3 = 0014000
+ XTABS = 0014000
+ BSDLY = 0020000
+ BS0 = 0000000
+ BS1 = 0020000
+ VTDLY = 0040000
+ VT0 = 0000000
+ VT1 = 0040000
+ FFDLY = 0100000
+ FF0 = 0000000
+ FF1 = 0100000
+)
+
+// Control flags.
+const (
+ CBAUD = 0010017
+ B0 = 0000000
+ B50 = 0000001
+ B75 = 0000002
+ B110 = 0000003
+ B134 = 0000004
+ B150 = 0000005
+ B200 = 0000006
+ B300 = 0000007
+ B600 = 0000010
+ B1200 = 0000011
+ B1800 = 0000012
+ B2400 = 0000013
+ B4800 = 0000014
+ B9600 = 0000015
+ B19200 = 0000016
+ B38400 = 0000017
+ EXTA = B19200
+ EXTB = B38400
+ CSIZE = 0000060
+ CS5 = 0000000
+ CS6 = 0000020
+ CS7 = 0000040
+ CS8 = 0000060
+ CSTOPB = 0000100
+ CREAD = 0000200
+ PARENB = 0000400
+ PARODD = 0001000
+ HUPCL = 0002000
+ CLOCAL = 0004000
+ CBAUDEX = 0010000
+ BOTHER = 0010000
+ B57600 = 0010001
+ B115200 = 0010002
+ B230400 = 0010003
+ B460800 = 0010004
+ B500000 = 0010005
+ B576000 = 0010006
+ B921600 = 0010007
+ B1000000 = 0010010
+ B1152000 = 0010011
+ B1500000 = 0010012
+ B2000000 = 0010013
+ B2500000 = 0010014
+ B3000000 = 0010015
+ B3500000 = 0010016
+ B4000000 = 0010017
+ CIBAUD = 002003600000
+ CMSPAR = 010000000000
+ CRTSCTS = 020000000000
+
+ // IBSHIFT is the shift from CBAUD to CIBAUD.
+ IBSHIFT = 16
+)
+
+// Local flags.
+const (
+ ISIG = 0000001
+ ICANON = 0000002
+ XCASE = 0000004
+ ECHO = 0000010
+ ECHOE = 0000020
+ ECHOK = 0000040
+ ECHONL = 0000100
+ NOFLSH = 0000200
+ TOSTOP = 0000400
+ ECHOCTL = 0001000
+ ECHOPRT = 0002000
+ ECHOKE = 0004000
+ FLUSHO = 0010000
+ PENDIN = 0040000
+ IEXTEN = 0100000
+ EXTPROC = 0200000
+)
+
+// Control Character indices.
+const (
+ VINTR = 0
+ VQUIT = 1
+ VERASE = 2
+ VKILL = 3
+ VEOF = 4
+ VTIME = 5
+ VMIN = 6
+ VSWTC = 7
+ VSTART = 8
+ VSTOP = 9
+ VSUSP = 10
+ VEOL = 11
+ VREPRINT = 12
+ VDISCARD = 13
+ VWERASE = 14
+ VLNEXT = 15
+ VEOL2 = 16
+)
+
+// ControlCharacter returns the termios-style control character for the passed
+// character.
+//
+// e.g., for Ctrl-C, i.e., ^C, call ControlCharacter('C').
+//
+// Standard control characters are ASCII bytes 0 through 31.
+func ControlCharacter(c byte) uint8 {
+ // A is 1, B is 2, etc.
+ return uint8(c - 'A' + 1)
+}
+
+// DefaultControlCharacters is the default set of Termios control characters.
+var DefaultControlCharacters = [NumControlCharacters]uint8{
+ ControlCharacter('C'), // VINTR = ^C
+ ControlCharacter('\\'), // VQUIT = ^\
+ '\x7f', // VERASE = DEL
+ ControlCharacter('U'), // VKILL = ^U
+ ControlCharacter('D'), // VEOF = ^D
+ 0, // VTIME
+ 1, // VMIN
+ 0, // VSWTC
+ ControlCharacter('Q'), // VSTART = ^Q
+ ControlCharacter('S'), // VSTOP = ^S
+ ControlCharacter('Z'), // VSUSP = ^Z
+ 0, // VEOL
+ ControlCharacter('R'), // VREPRINT = ^R
+ ControlCharacter('O'), // VDISCARD = ^O
+ ControlCharacter('W'), // VWERASE = ^W
+ ControlCharacter('V'), // VLNEXT = ^V
+ 0, // VEOL2
+}
+
+// MasterTermios is the terminal configuration of the master end of a Unix98
+// pseudoterminal.
+var MasterTermios = KernelTermios{
+ ControlFlags: B38400 | CS8 | CREAD,
+ ControlCharacters: DefaultControlCharacters,
+ InputSpeed: 38400,
+ OutputSpeed: 38400,
+}
+
+// DefaultSlaveTermios is the default terminal configuration of the slave end
+// of a Unix98 pseudoterminal.
+var DefaultSlaveTermios = KernelTermios{
+ InputFlags: ICRNL | IXON,
+ OutputFlags: OPOST | ONLCR,
+ ControlFlags: B38400 | CS8 | CREAD,
+ LocalFlags: ISIG | ICANON | ECHO | ECHOE | ECHOK | ECHOCTL | ECHOKE | IEXTEN,
+ ControlCharacters: DefaultControlCharacters,
+ InputSpeed: 38400,
+ OutputSpeed: 38400,
+}
+
+// WindowSize corresponds to struct winsize defined in
+// include/uapi/asm-generic/termios.h.
+//
+// +stateify savable
+type WindowSize struct {
+ Rows uint16
+ Cols uint16
+ _ [4]byte // Padding for 2 unused shorts.
+}
diff --git a/pkg/abi/linux/uio.go b/pkg/abi/linux/uio.go
new file mode 100644
index 000000000..1fd1e9802
--- /dev/null
+++ b/pkg/abi/linux/uio.go
@@ -0,0 +1,18 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// UIO_MAXIOV is the maximum number of struct iovecs in a struct iovec array.
+const UIO_MAXIOV = 1024
diff --git a/pkg/abi/linux/utsname.go b/pkg/abi/linux/utsname.go
new file mode 100644
index 000000000..60f220a67
--- /dev/null
+++ b/pkg/abi/linux/utsname.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // UTSLen is the maximum length of strings contained in fields of
+ // UtsName.
+ UTSLen = 64
+)
+
+// UtsName represents struct utsname, the struct returned by uname(2).
+type UtsName struct {
+ Sysname [UTSLen + 1]byte
+ Nodename [UTSLen + 1]byte
+ Release [UTSLen + 1]byte
+ Version [UTSLen + 1]byte
+ Machine [UTSLen + 1]byte
+ Domainname [UTSLen + 1]byte
+}
+
+// utsNameString converts a UtsName entry to a string without NULs.
+func utsNameString(s [UTSLen + 1]byte) string {
+ // The NUL bytes will remain even in a cast to string. We must
+ // explicitly strip them.
+ return string(bytes.TrimRight(s[:], "\x00"))
+}
+
+func (u UtsName) String() string {
+ return fmt.Sprintf("{Sysname: %s, Nodename: %s, Release: %s, Version: %s, Machine: %s, Domainname: %s}",
+ utsNameString(u.Sysname), utsNameString(u.Nodename), utsNameString(u.Release),
+ utsNameString(u.Version), utsNameString(u.Machine), utsNameString(u.Domainname))
+}
diff --git a/pkg/abi/linux/wait.go b/pkg/abi/linux/wait.go
new file mode 100644
index 000000000..4bdc280d1
--- /dev/null
+++ b/pkg/abi/linux/wait.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Options for waitpid(2), wait4(2), and/or waitid(2), from
+// include/uapi/linux/wait.h.
+const (
+ WNOHANG = 0x00000001
+ WUNTRACED = 0x00000002
+ WSTOPPED = WUNTRACED
+ WEXITED = 0x00000004
+ WCONTINUED = 0x00000008
+ WNOWAIT = 0x01000000
+ WNOTHREAD = 0x20000000
+ WALL = 0x40000000
+ WCLONE = 0x80000000
+)
+
+// ID types for waitid(2), from include/uapi/linux/wait.h.
+const (
+ P_ALL = 0x0
+ P_PID = 0x1
+ P_PGID = 0x2
+)
diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go
new file mode 100644
index 000000000..4f7759b87
--- /dev/null
+++ b/pkg/amutex/amutex.go
@@ -0,0 +1,120 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package amutex provides the implementation of an abortable mutex. It allows
+// the Lock() function to be canceled while it waits to acquire the mutex.
+package amutex
+
+import (
+ "sync/atomic"
+)
+
+// Sleeper must be implemented by users of the abortable mutex to allow for
+// cancelation of waits.
+type Sleeper interface {
+ // SleepStart is called by the AbortableMutex.Lock() function when the
+ // mutex is contended and the goroutine is about to sleep.
+ //
+ // A channel can be returned that causes the sleep to be canceled if
+ // it's readable. If no cancelation is desired, nil can be returned.
+ SleepStart() <-chan struct{}
+
+ // SleepFinish is called by AbortableMutex.Lock() once a contended mutex
+ // is acquired or the wait is aborted.
+ SleepFinish(success bool)
+
+ // Interrupted returns true if the wait is aborted.
+ Interrupted() bool
+}
+
+// NoopSleeper is a stateless no-op implementation of Sleeper for anonymous
+// embedding in other types that do not support cancelation.
+type NoopSleeper struct{}
+
+// SleepStart implements Sleeper.SleepStart.
+func (NoopSleeper) SleepStart() <-chan struct{} {
+ return nil
+}
+
+// SleepFinish implements Sleeper.SleepFinish.
+func (NoopSleeper) SleepFinish(success bool) {}
+
+// Interrupted implements Sleeper.Interrupted.
+func (NoopSleeper) Interrupted() bool { return false }
+
+// AbortableMutex is an abortable mutex. It allows Lock() to be aborted while it
+// waits to acquire the mutex.
+type AbortableMutex struct {
+ v int32
+ ch chan struct{}
+}
+
+// Init initializes the abortable mutex.
+func (m *AbortableMutex) Init() {
+ m.v = 1
+ m.ch = make(chan struct{}, 1)
+}
+
+// Lock attempts to acquire the mutex, returning true on success. If something
+// is written to the "c" while Lock waits, the wait is aborted and false is
+// returned instead.
+func (m *AbortableMutex) Lock(s Sleeper) bool {
+ // Uncontended case.
+ if atomic.AddInt32(&m.v, -1) == 0 {
+ return true
+ }
+
+ var c <-chan struct{}
+ if s != nil {
+ c = s.SleepStart()
+ }
+
+ for {
+ // Try to acquire the mutex again, at the same time making sure
+ // that m.v is negative, which indicates to the owner of the
+ // lock that it is contended, which ill force it to try to wake
+ // someone up when it releases the mutex.
+ if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 {
+ if s != nil {
+ s.SleepFinish(true)
+ }
+ return true
+ }
+
+ // Wait for the owner to wake us up before trying again, or for
+ // the wait to be aborted by the provided channel.
+ select {
+ case <-m.ch:
+ case <-c:
+ // s must be non-nil, otherwise c would be nil and we'd
+ // never reach this path.
+ s.SleepFinish(false)
+ return false
+ }
+ }
+}
+
+// Unlock releases the mutex.
+func (m *AbortableMutex) Unlock() {
+ if atomic.SwapInt32(&m.v, 1) == 0 {
+ // There were no pending waiters.
+ return
+ }
+
+ // Wake some waiter up.
+ select {
+ case m.ch <- struct{}{}:
+ default:
+ }
+}
diff --git a/pkg/amutex/amutex_state_autogen.go b/pkg/amutex/amutex_state_autogen.go
new file mode 100755
index 000000000..5651ae4e9
--- /dev/null
+++ b/pkg/amutex/amutex_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package amutex
+
diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomic_bitops.go
new file mode 100644
index 000000000..63aa2b7f1
--- /dev/null
+++ b/pkg/atomicbitops/atomic_bitops.go
@@ -0,0 +1,59 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+// Package atomicbitops provides basic bitwise operations in an atomic way.
+// The implementation on amd64 leverages the LOCK prefix directly instead of
+// relying on the generic cas primitives.
+//
+// WARNING: the bitwise ops provided in this package doesn't imply any memory
+// ordering. Using them to construct locks must employ proper memory barriers.
+package atomicbitops
+
+// AndUint32 atomically applies bitwise and operation to *addr with val.
+func AndUint32(addr *uint32, val uint32)
+
+// OrUint32 atomically applies bitwise or operation to *addr with val.
+func OrUint32(addr *uint32, val uint32)
+
+// XorUint32 atomically applies bitwise xor operation to *addr with val.
+func XorUint32(addr *uint32, val uint32)
+
+// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
+// the value previously stored at addr.
+func CompareAndSwapUint32(addr *uint32, old, new uint32) uint32
+
+// AndUint64 atomically applies bitwise and operation to *addr with val.
+func AndUint64(addr *uint64, val uint64)
+
+// OrUint64 atomically applies bitwise or operation to *addr with val.
+func OrUint64(addr *uint64, val uint64)
+
+// XorUint64 atomically applies bitwise xor operation to *addr with val.
+func XorUint64(addr *uint64, val uint64)
+
+// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
+// the value previously stored at addr.
+func CompareAndSwapUint64(addr *uint64, old, new uint64) uint64
+
+// IncUnlessZeroInt32 increments the value stored at the given address and
+// returns true; unless the value stored in the pointer is zero, in which case
+// it is left unmodified and false is returned.
+func IncUnlessZeroInt32(addr *int32) bool
+
+// DecUnlessOneInt32 decrements the value stored at the given address and
+// returns true; unless the value stored in the pointer is 1, in which case it
+// is left unmodified and false is returned.
+func DecUnlessOneInt32(addr *int32) bool
diff --git a/pkg/atomicbitops/atomic_bitops_amd64.s b/pkg/atomicbitops/atomic_bitops_amd64.s
new file mode 100644
index 000000000..db0972001
--- /dev/null
+++ b/pkg/atomicbitops/atomic_bitops_amd64.s
@@ -0,0 +1,115 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+#include "textflag.h"
+
+TEXT ·AndUint32(SB),$0-12
+ MOVQ addr+0(FP), BP
+ MOVL val+8(FP), AX
+ LOCK
+ ANDL AX, 0(BP)
+ RET
+
+TEXT ·OrUint32(SB),$0-12
+ MOVQ addr+0(FP), BP
+ MOVL val+8(FP), AX
+ LOCK
+ ORL AX, 0(BP)
+ RET
+
+TEXT ·XorUint32(SB),$0-12
+ MOVQ addr+0(FP), BP
+ MOVL val+8(FP), AX
+ LOCK
+ XORL AX, 0(BP)
+ RET
+
+TEXT ·CompareAndSwapUint32(SB),$0-20
+ MOVQ addr+0(FP), DI
+ MOVL old+8(FP), AX
+ MOVL new+12(FP), DX
+ LOCK
+ CMPXCHGL DX, 0(DI)
+ MOVL AX, ret+16(FP)
+ RET
+
+TEXT ·AndUint64(SB),$0-16
+ MOVQ addr+0(FP), BP
+ MOVQ val+8(FP), AX
+ LOCK
+ ANDQ AX, 0(BP)
+ RET
+
+TEXT ·OrUint64(SB),$0-16
+ MOVQ addr+0(FP), BP
+ MOVQ val+8(FP), AX
+ LOCK
+ ORQ AX, 0(BP)
+ RET
+
+TEXT ·XorUint64(SB),$0-16
+ MOVQ addr+0(FP), BP
+ MOVQ val+8(FP), AX
+ LOCK
+ XORQ AX, 0(BP)
+ RET
+
+TEXT ·CompareAndSwapUint64(SB),$0-32
+ MOVQ addr+0(FP), DI
+ MOVQ old+8(FP), AX
+ MOVQ new+16(FP), DX
+ LOCK
+ CMPXCHGQ DX, 0(DI)
+ MOVQ AX, ret+24(FP)
+ RET
+
+TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
+ MOVQ addr+0(FP), DI
+ MOVL 0(DI), AX
+
+retry:
+ TESTL AX, AX
+ JZ fail
+ LEAL 1(AX), DX
+ LOCK
+ CMPXCHGL DX, 0(DI)
+ JNZ retry
+
+ SETEQ ret+8(FP)
+ RET
+
+fail:
+ MOVB AX, ret+8(FP)
+ RET
+
+TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
+ MOVQ addr+0(FP), DI
+ MOVL 0(DI), AX
+
+retry:
+ LEAL -1(AX), DX
+ TESTL DX, DX
+ JZ fail
+ LOCK
+ CMPXCHGL DX, 0(DI)
+ JNZ retry
+
+ SETEQ ret+8(FP)
+ RET
+
+fail:
+ MOVB DX, ret+8(FP)
+ RET
diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomic_bitops_common.go
new file mode 100644
index 000000000..b2a943dcb
--- /dev/null
+++ b/pkg/atomicbitops/atomic_bitops_common.go
@@ -0,0 +1,147 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !amd64
+
+package atomicbitops
+
+import (
+ "sync/atomic"
+)
+
+// AndUint32 atomically applies bitwise and operation to *addr with val.
+func AndUint32(addr *uint32, val uint32) {
+ for {
+ o := atomic.LoadUint32(addr)
+ n := o & val
+ if atomic.CompareAndSwapUint32(addr, o, n) {
+ break
+ }
+ }
+}
+
+// OrUint32 atomically applies bitwise or operation to *addr with val.
+func OrUint32(addr *uint32, val uint32) {
+ for {
+ o := atomic.LoadUint32(addr)
+ n := o | val
+ if atomic.CompareAndSwapUint32(addr, o, n) {
+ break
+ }
+ }
+}
+
+// XorUint32 atomically applies bitwise xor operation to *addr with val.
+func XorUint32(addr *uint32, val uint32) {
+ for {
+ o := atomic.LoadUint32(addr)
+ n := o ^ val
+ if atomic.CompareAndSwapUint32(addr, o, n) {
+ break
+ }
+ }
+}
+
+// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
+// the value previously stored at addr.
+func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
+ for {
+ prev = atomic.LoadUint32(addr)
+ if prev != old {
+ return
+ }
+ if atomic.CompareAndSwapUint32(addr, old, new) {
+ return
+ }
+ }
+}
+
+// AndUint64 atomically applies bitwise and operation to *addr with val.
+func AndUint64(addr *uint64, val uint64) {
+ for {
+ o := atomic.LoadUint64(addr)
+ n := o & val
+ if atomic.CompareAndSwapUint64(addr, o, n) {
+ break
+ }
+ }
+}
+
+// OrUint64 atomically applies bitwise or operation to *addr with val.
+func OrUint64(addr *uint64, val uint64) {
+ for {
+ o := atomic.LoadUint64(addr)
+ n := o | val
+ if atomic.CompareAndSwapUint64(addr, o, n) {
+ break
+ }
+ }
+}
+
+// XorUint64 atomically applies bitwise xor operation to *addr with val.
+func XorUint64(addr *uint64, val uint64) {
+ for {
+ o := atomic.LoadUint64(addr)
+ n := o ^ val
+ if atomic.CompareAndSwapUint64(addr, o, n) {
+ break
+ }
+ }
+}
+
+// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
+// the value previously stored at addr.
+func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
+ for {
+ prev = atomic.LoadUint64(addr)
+ if prev != old {
+ return
+ }
+ if atomic.CompareAndSwapUint64(addr, old, new) {
+ return
+ }
+ }
+}
+
+// IncUnlessZeroInt32 increments the value stored at the given address and
+// returns true; unless the value stored in the pointer is zero, in which case
+// it is left unmodified and false is returned.
+func IncUnlessZeroInt32(addr *int32) bool {
+ for {
+ v := atomic.LoadInt32(addr)
+ if v == 0 {
+ return false
+ }
+
+ if atomic.CompareAndSwapInt32(addr, v, v+1) {
+ return true
+ }
+ }
+}
+
+// DecUnlessOneInt32 decrements the value stored at the given address and
+// returns true; unless the value stored in the pointer is 1, in which case it
+// is left unmodified and false is returned.
+func DecUnlessOneInt32(addr *int32) bool {
+ for {
+ v := atomic.LoadInt32(addr)
+ if v == 1 {
+ return false
+ }
+
+ if atomic.CompareAndSwapInt32(addr, v, v-1) {
+ return true
+ }
+ }
+}
diff --git a/pkg/atomicbitops/atomicbitops_state_autogen.go b/pkg/atomicbitops/atomicbitops_state_autogen.go
new file mode 100755
index 000000000..a74ea7d50
--- /dev/null
+++ b/pkg/atomicbitops/atomicbitops_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package atomicbitops
+
diff --git a/pkg/binary/binary.go b/pkg/binary/binary.go
new file mode 100644
index 000000000..631785f7b
--- /dev/null
+++ b/pkg/binary/binary.go
@@ -0,0 +1,256 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package binary translates between select fixed-sized types and a binary
+// representation.
+package binary
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "reflect"
+)
+
+// LittleEndian is the same as encoding/binary.LittleEndian.
+//
+// It is included here as a convenience.
+var LittleEndian = binary.LittleEndian
+
+// BigEndian is the same as encoding/binary.BigEndian.
+//
+// It is included here as a convenience.
+var BigEndian = binary.BigEndian
+
+// AppendUint16 appends the binary representation of a uint16 to buf.
+func AppendUint16(buf []byte, order binary.ByteOrder, num uint16) []byte {
+ buf = append(buf, make([]byte, 2)...)
+ order.PutUint16(buf[len(buf)-2:], num)
+ return buf
+}
+
+// AppendUint32 appends the binary representation of a uint32 to buf.
+func AppendUint32(buf []byte, order binary.ByteOrder, num uint32) []byte {
+ buf = append(buf, make([]byte, 4)...)
+ order.PutUint32(buf[len(buf)-4:], num)
+ return buf
+}
+
+// AppendUint64 appends the binary representation of a uint64 to buf.
+func AppendUint64(buf []byte, order binary.ByteOrder, num uint64) []byte {
+ buf = append(buf, make([]byte, 8)...)
+ order.PutUint64(buf[len(buf)-8:], num)
+ return buf
+}
+
+// Marshal appends a binary representation of data to buf.
+//
+// data must only contain fixed-length signed and unsigned ints, arrays,
+// slices, structs and compositions of said types. data may be a pointer,
+// but cannot contain pointers.
+func Marshal(buf []byte, order binary.ByteOrder, data interface{}) []byte {
+ return marshal(buf, order, reflect.Indirect(reflect.ValueOf(data)))
+}
+
+func marshal(buf []byte, order binary.ByteOrder, data reflect.Value) []byte {
+ switch data.Kind() {
+ case reflect.Int8:
+ buf = append(buf, byte(int8(data.Int())))
+ case reflect.Int16:
+ buf = AppendUint16(buf, order, uint16(int16(data.Int())))
+ case reflect.Int32:
+ buf = AppendUint32(buf, order, uint32(int32(data.Int())))
+ case reflect.Int64:
+ buf = AppendUint64(buf, order, uint64(data.Int()))
+
+ case reflect.Uint8:
+ buf = append(buf, byte(data.Uint()))
+ case reflect.Uint16:
+ buf = AppendUint16(buf, order, uint16(data.Uint()))
+ case reflect.Uint32:
+ buf = AppendUint32(buf, order, uint32(data.Uint()))
+ case reflect.Uint64:
+ buf = AppendUint64(buf, order, data.Uint())
+
+ case reflect.Array, reflect.Slice:
+ for i, l := 0, data.Len(); i < l; i++ {
+ buf = marshal(buf, order, data.Index(i))
+ }
+
+ case reflect.Struct:
+ for i, l := 0, data.NumField(); i < l; i++ {
+ buf = marshal(buf, order, data.Field(i))
+ }
+
+ default:
+ panic("invalid type: " + data.Type().String())
+ }
+ return buf
+}
+
+// Unmarshal unpacks buf into data.
+//
+// data must be a slice or a pointer and buf must have a length of exactly
+// Size(data). data must only contain fixed-length signed and unsigned ints,
+// arrays, slices, structs and compositions of said types.
+func Unmarshal(buf []byte, order binary.ByteOrder, data interface{}) {
+ value := reflect.ValueOf(data)
+ switch value.Kind() {
+ case reflect.Ptr:
+ value = value.Elem()
+ case reflect.Slice:
+ default:
+ panic("invalid type: " + value.Type().String())
+ }
+ buf = unmarshal(buf, order, value)
+ if len(buf) != 0 {
+ panic(fmt.Sprintf("buffer too long by %d bytes", len(buf)))
+ }
+}
+
+func unmarshal(buf []byte, order binary.ByteOrder, data reflect.Value) []byte {
+ switch data.Kind() {
+ case reflect.Int8:
+ data.SetInt(int64(int8(buf[0])))
+ buf = buf[1:]
+ case reflect.Int16:
+ data.SetInt(int64(int16(order.Uint16(buf))))
+ buf = buf[2:]
+ case reflect.Int32:
+ data.SetInt(int64(int32(order.Uint32(buf))))
+ buf = buf[4:]
+ case reflect.Int64:
+ data.SetInt(int64(order.Uint64(buf)))
+ buf = buf[8:]
+
+ case reflect.Uint8:
+ data.SetUint(uint64(buf[0]))
+ buf = buf[1:]
+ case reflect.Uint16:
+ data.SetUint(uint64(order.Uint16(buf)))
+ buf = buf[2:]
+ case reflect.Uint32:
+ data.SetUint(uint64(order.Uint32(buf)))
+ buf = buf[4:]
+ case reflect.Uint64:
+ data.SetUint(order.Uint64(buf))
+ buf = buf[8:]
+
+ case reflect.Array, reflect.Slice:
+ for i, l := 0, data.Len(); i < l; i++ {
+ buf = unmarshal(buf, order, data.Index(i))
+ }
+
+ case reflect.Struct:
+ for i, l := 0, data.NumField(); i < l; i++ {
+ if field := data.Field(i); field.CanSet() {
+ buf = unmarshal(buf, order, field)
+ } else {
+ buf = buf[sizeof(field):]
+ }
+ }
+
+ default:
+ panic("invalid type: " + data.Type().String())
+ }
+ return buf
+}
+
+// Size calculates the buffer sized needed by Marshal or Unmarshal.
+//
+// Size only support the types supported by Marshal.
+func Size(v interface{}) uintptr {
+ return sizeof(reflect.Indirect(reflect.ValueOf(v)))
+}
+
+func sizeof(data reflect.Value) uintptr {
+ switch data.Kind() {
+ case reflect.Int8, reflect.Uint8:
+ return 1
+ case reflect.Int16, reflect.Uint16:
+ return 2
+ case reflect.Int32, reflect.Uint32:
+ return 4
+ case reflect.Int64, reflect.Uint64:
+ return 8
+
+ case reflect.Array, reflect.Slice:
+ var size uintptr
+ for i, l := 0, data.Len(); i < l; i++ {
+ size += sizeof(data.Index(i))
+ }
+ return size
+
+ case reflect.Struct:
+ var size uintptr
+ for i, l := 0, data.NumField(); i < l; i++ {
+ size += sizeof(data.Field(i))
+ }
+ return size
+
+ default:
+ panic("invalid type: " + data.Type().String())
+ }
+}
+
+// ReadUint16 reads a uint16 from r.
+func ReadUint16(r io.Reader, order binary.ByteOrder) (uint16, error) {
+ buf := make([]byte, 2)
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return 0, err
+ }
+ return order.Uint16(buf), nil
+}
+
+// ReadUint32 reads a uint32 from r.
+func ReadUint32(r io.Reader, order binary.ByteOrder) (uint32, error) {
+ buf := make([]byte, 4)
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return 0, err
+ }
+ return order.Uint32(buf), nil
+}
+
+// ReadUint64 reads a uint64 from r.
+func ReadUint64(r io.Reader, order binary.ByteOrder) (uint64, error) {
+ buf := make([]byte, 8)
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return 0, err
+ }
+ return order.Uint64(buf), nil
+}
+
+// WriteUint16 writes a uint16 to w.
+func WriteUint16(w io.Writer, order binary.ByteOrder, num uint16) error {
+ buf := make([]byte, 2)
+ order.PutUint16(buf, num)
+ _, err := w.Write(buf)
+ return err
+}
+
+// WriteUint32 writes a uint32 to w.
+func WriteUint32(w io.Writer, order binary.ByteOrder, num uint32) error {
+ buf := make([]byte, 4)
+ order.PutUint32(buf, num)
+ _, err := w.Write(buf)
+ return err
+}
+
+// WriteUint64 writes a uint64 to w.
+func WriteUint64(w io.Writer, order binary.ByteOrder, num uint64) error {
+ buf := make([]byte, 8)
+ order.PutUint64(buf, num)
+ _, err := w.Write(buf)
+ return err
+}
diff --git a/pkg/binary/binary_state_autogen.go b/pkg/binary/binary_state_autogen.go
new file mode 100755
index 000000000..e29aeb344
--- /dev/null
+++ b/pkg/binary/binary_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package binary
+
diff --git a/pkg/bits/bits.go b/pkg/bits/bits.go
new file mode 100644
index 000000000..a26433ad6
--- /dev/null
+++ b/pkg/bits/bits.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package bits includes all bit related types and operations.
+package bits
diff --git a/pkg/bits/bits32.go b/pkg/bits/bits32.go
new file mode 100755
index 000000000..4e9e45dce
--- /dev/null
+++ b/pkg/bits/bits32.go
@@ -0,0 +1,25 @@
+package bits
+
+// IsOn returns true if *all* bits set in 'bits' are set in 'mask'.
+func IsOn32(mask, bits uint32) bool {
+ return mask&bits == bits
+}
+
+// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'.
+func IsAnyOn32(mask, bits uint32) bool {
+ return mask&bits != 0
+}
+
+// Mask returns a T with all of the given bits set.
+func Mask32(is ...int) uint32 {
+ ret := uint32(0)
+ for _, i := range is {
+ ret |= MaskOf32(i)
+ }
+ return ret
+}
+
+// MaskOf is like Mask, but sets only a single bit (more efficiently).
+func MaskOf32(i int) uint32 {
+ return uint32(1) << uint32(i)
+}
diff --git a/pkg/bits/bits64.go b/pkg/bits/bits64.go
new file mode 100755
index 000000000..f49158792
--- /dev/null
+++ b/pkg/bits/bits64.go
@@ -0,0 +1,25 @@
+package bits
+
+// IsOn returns true if *all* bits set in 'bits' are set in 'mask'.
+func IsOn64(mask, bits uint64) bool {
+ return mask&bits == bits
+}
+
+// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'.
+func IsAnyOn64(mask, bits uint64) bool {
+ return mask&bits != 0
+}
+
+// Mask returns a T with all of the given bits set.
+func Mask64(is ...int) uint64 {
+ ret := uint64(0)
+ for _, i := range is {
+ ret |= MaskOf64(i)
+ }
+ return ret
+}
+
+// MaskOf is like Mask, but sets only a single bit (more efficiently).
+func MaskOf64(i int) uint64 {
+ return uint64(1) << uint64(i)
+}
diff --git a/pkg/bits/bits_state_autogen.go b/pkg/bits/bits_state_autogen.go
new file mode 100755
index 000000000..2abb1291b
--- /dev/null
+++ b/pkg/bits/bits_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package bits
+
diff --git a/pkg/bits/uint64_arch_amd64.go b/pkg/bits/uint64_arch_amd64.go
new file mode 100644
index 000000000..faccaa61a
--- /dev/null
+++ b/pkg/bits/uint64_arch_amd64.go
@@ -0,0 +1,36 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package bits
+
+// TrailingZeros64 returns the number of bits before the least significant 1
+// bit in x; in other words, it returns the index of the least significant 1
+// bit in x. If x is 0, TrailingZeros64 returns 64.
+func TrailingZeros64(x uint64) int
+
+// MostSignificantOne64 returns the index of the most significant 1 bit in
+// x. If x is 0, MostSignificantOne64 returns 64.
+func MostSignificantOne64(x uint64) int
+
+// ForEachSetBit64 calls f once for each set bit in x, with argument i equal to
+// the set bit's index.
+func ForEachSetBit64(x uint64, f func(i int)) {
+ for x != 0 {
+ i := TrailingZeros64(x)
+ f(i)
+ x &^= MaskOf64(i)
+ }
+}
diff --git a/pkg/bits/uint64_arch_amd64_asm.s b/pkg/bits/uint64_arch_amd64_asm.s
new file mode 100644
index 000000000..8ff364181
--- /dev/null
+++ b/pkg/bits/uint64_arch_amd64_asm.s
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+TEXT ·TrailingZeros64(SB),$0-16
+ BSFQ x+0(FP), AX
+ JNZ end
+ MOVQ $64, AX
+end:
+ MOVQ AX, ret+8(FP)
+ RET
+
+TEXT ·MostSignificantOne64(SB),$0-16
+ BSRQ x+0(FP), AX
+ JNZ end
+ MOVQ $64, AX
+end:
+ MOVQ AX, ret+8(FP)
+ RET
diff --git a/pkg/bits/uint64_arch_generic.go b/pkg/bits/uint64_arch_generic.go
new file mode 100644
index 000000000..7dd2d1480
--- /dev/null
+++ b/pkg/bits/uint64_arch_generic.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !amd64
+
+package bits
+
+// TrailingZeros64 returns the number of bits before the least significant 1
+// bit in x; in other words, it returns the index of the least significant 1
+// bit in x. If x is 0, TrailingZeros64 returns 64.
+func TrailingZeros64(x uint64) int {
+ if x == 0 {
+ return 64
+ }
+ i := 0
+ for ; x&1 == 0; i++ {
+ x >>= 1
+ }
+ return i
+}
+
+// MostSignificantOne64 returns the index of the most significant 1 bit in
+// x. If x is 0, MostSignificantOne64 returns 64.
+func MostSignificantOne64(x uint64) int {
+ if x == 0 {
+ return 64
+ }
+ i := 63
+ for ; x&(1<<63) == 0; i-- {
+ x <<= 1
+ }
+ return i
+}
+
+// ForEachSetBit64 calls f once for each set bit in x, with argument i equal to
+// the set bit's index.
+func ForEachSetBit64(x uint64, f func(i int)) {
+ for i := 0; x != 0; i++ {
+ if x&1 != 0 {
+ f(i)
+ }
+ x >>= 1
+ }
+}
diff --git a/pkg/bpf/bpf.go b/pkg/bpf/bpf.go
new file mode 100644
index 000000000..eb546f48f
--- /dev/null
+++ b/pkg/bpf/bpf.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package bpf provides tools for working with Berkeley Packet Filter (BPF)
+// programs. More information on BPF can be found at
+// https://www.freebsd.org/cgi/man.cgi?bpf(4)
+package bpf
+
+import "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+
+const (
+ // MaxInstructions is the maximum number of instructions in a BPF program,
+ // and is equal to Linux's BPF_MAXINSNS.
+ MaxInstructions = 4096
+
+ // ScratchMemRegisters is the number of M registers in a BPF virtual machine,
+ // and is equal to Linux's BPF_MEMWORDS.
+ ScratchMemRegisters = 16
+)
+
+// Parts of a linux.BPFInstruction.OpCode. Compare to the Linux kernel's
+// include/uapi/linux/filter.h.
+//
+// In the comments below:
+//
+// - A, X, and M[] are BPF virtual machine registers.
+//
+// - K refers to the instruction field linux.BPFInstruction.K.
+//
+// - Bits are counted from the LSB position.
+const (
+ // Instruction class, stored in bits 0-2.
+ Ld = 0x00 // load into A
+ Ldx = 0x01 // load into X
+ St = 0x02 // store from A
+ Stx = 0x03 // store from X
+ Alu = 0x04 // arithmetic
+ Jmp = 0x05 // jump
+ Ret = 0x06 // return
+ Misc = 0x07
+ instructionClassMask = 0x07
+
+ // Size of a load, stored in bits 3-4.
+ W = 0x00 // 32 bits
+ H = 0x08 // 16 bits
+ B = 0x10 // 8 bits
+ loadSizeMask = 0x18
+
+ // Source operand for a load, stored in bits 5-7.
+ // Address mode numbers in the comments come from Linux's
+ // Documentation/networking/filter.txt.
+ Imm = 0x00 // immediate value K (mode 4)
+ Abs = 0x20 // data in input at byte offset K (mode 1)
+ Ind = 0x40 // data in input at byte offset X+K (mode 2)
+ Mem = 0x60 // M[K] (mode 3)
+ Len = 0x80 // length of the input in bytes ("BPF extension len")
+ Msh = 0xa0 // 4 * lower nibble of input at byte offset K (mode 5)
+ loadModeMask = 0xe0
+
+ // Source operands for arithmetic, jump, and return instructions.
+ // Arithmetic and jump instructions can use K or X as source operands.
+ // Return instructions can use K or A as source operands.
+ K = 0x00 // still mode 4
+ X = 0x08 // mode 0
+ A = 0x10 // mode 9
+ srcAluJmpMask = 0x08
+ srcRetMask = 0x18
+
+ // Arithmetic instructions, stored in bits 4-7.
+ Add = 0x00
+ Sub = 0x10 // A - src
+ Mul = 0x20
+ Div = 0x30 // A / src
+ Or = 0x40
+ And = 0x50
+ Lsh = 0x60 // A << src
+ Rsh = 0x70 // A >> src
+ Neg = 0x80 // -A (src ignored)
+ Mod = 0x90 // A % src
+ Xor = 0xa0
+ aluMask = 0xf0
+
+ // Jump instructions, stored in bits 4-7.
+ Ja = 0x00 // unconditional (uses K for jump offset)
+ Jeq = 0x10 // if A == src
+ Jgt = 0x20 // if A > src
+ Jge = 0x30 // if A >= src
+ Jset = 0x40 // if (A & src) != 0
+ jmpMask = 0xf0
+
+ // Miscellaneous instructions, stored in bits 3-7.
+ Tax = 0x00 // A = X
+ Txa = 0x80 // X = A
+ miscMask = 0xf8
+
+ // Masks for bits that should be zero.
+ unusedBitsMask = 0xff00 // all valid instructions use only bits 0-7
+ storeUnusedBitsMask = 0xf8 // stores only use instruction class
+ retUnusedBitsMask = 0xe0 // returns only use instruction class and source operand
+)
+
+// Stmt returns a linux.BPFInstruction representing a BPF non-jump instruction.
+func Stmt(code uint16, k uint32) linux.BPFInstruction {
+ return linux.BPFInstruction{
+ OpCode: code,
+ K: k,
+ }
+}
+
+// Jump returns a linux.BPFInstruction representing a BPF jump instruction.
+func Jump(code uint16, k uint32, jt, jf uint8) linux.BPFInstruction {
+ return linux.BPFInstruction{
+ OpCode: code,
+ JumpIfTrue: jt,
+ JumpIfFalse: jf,
+ K: k,
+ }
+}
diff --git a/pkg/bpf/bpf_state_autogen.go b/pkg/bpf/bpf_state_autogen.go
new file mode 100755
index 000000000..05effb7b6
--- /dev/null
+++ b/pkg/bpf/bpf_state_autogen.go
@@ -0,0 +1,22 @@
+// automatically generated by stateify.
+
+package bpf
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Program) beforeSave() {}
+func (x *Program) save(m state.Map) {
+ x.beforeSave()
+ m.Save("instructions", &x.instructions)
+}
+
+func (x *Program) afterLoad() {}
+func (x *Program) load(m state.Map) {
+ m.Load("instructions", &x.instructions)
+}
+
+func init() {
+ state.Register("bpf.Program", (*Program)(nil), state.Fns{Save: (*Program).save, Load: (*Program).load})
+}
diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go
new file mode 100644
index 000000000..45c192215
--- /dev/null
+++ b/pkg/bpf/decoder.go
@@ -0,0 +1,245 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// DecodeProgram translates an array of BPF instructions into text format.
+func DecodeProgram(program []linux.BPFInstruction) (string, error) {
+ var ret bytes.Buffer
+ for line, s := range program {
+ ret.WriteString(fmt.Sprintf("%v: ", line))
+ if err := decode(s, line, &ret); err != nil {
+ return "", err
+ }
+ ret.WriteString("\n")
+ }
+ return ret.String(), nil
+}
+
+// Decode translates BPF instruction into text format.
+func Decode(inst linux.BPFInstruction) (string, error) {
+ var ret bytes.Buffer
+ err := decode(inst, -1, &ret)
+ return ret.String(), err
+}
+
+func decode(inst linux.BPFInstruction, line int, w *bytes.Buffer) error {
+ var err error
+ switch inst.OpCode & instructionClassMask {
+ case Ld:
+ err = decodeLd(inst, w)
+ case Ldx:
+ err = decodeLdx(inst, w)
+ case St:
+ w.WriteString(fmt.Sprintf("M[%v] <- A", inst.K))
+ case Stx:
+ w.WriteString(fmt.Sprintf("M[%v] <- X", inst.K))
+ case Alu:
+ err = decodeAlu(inst, w)
+ case Jmp:
+ err = decodeJmp(inst, line, w)
+ case Ret:
+ err = decodeRet(inst, w)
+ case Misc:
+ err = decodeMisc(inst, w)
+ default:
+ return fmt.Errorf("invalid BPF instruction: %v", inst)
+ }
+ return err
+}
+
+// A <- P[k:4]
+func decodeLd(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ w.WriteString("A <- ")
+
+ switch inst.OpCode & loadModeMask {
+ case Imm:
+ w.WriteString(fmt.Sprintf("%v", inst.K))
+ case Abs:
+ w.WriteString(fmt.Sprintf("P[%v:", inst.K))
+ if err := decodeLdSize(inst, w); err != nil {
+ return err
+ }
+ w.WriteString("]")
+ case Ind:
+ w.WriteString(fmt.Sprintf("P[X+%v:", inst.K))
+ if err := decodeLdSize(inst, w); err != nil {
+ return err
+ }
+ w.WriteString("]")
+ case Mem:
+ w.WriteString(fmt.Sprintf("M[%v]", inst.K))
+ case Len:
+ w.WriteString("len")
+ default:
+ return fmt.Errorf("invalid BPF LD instruction: %v", inst)
+ }
+ return nil
+}
+
+func decodeLdSize(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ switch inst.OpCode & loadSizeMask {
+ case W:
+ w.WriteString("4")
+ case H:
+ w.WriteString("2")
+ case B:
+ w.WriteString("1")
+ default:
+ return fmt.Errorf("Invalid BPF LD size: %v", inst)
+ }
+ return nil
+}
+
+// X <- P[k:4]
+func decodeLdx(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ w.WriteString("X <- ")
+
+ switch inst.OpCode & loadModeMask {
+ case Imm:
+ w.WriteString(fmt.Sprintf("%v", inst.K))
+ case Mem:
+ w.WriteString(fmt.Sprintf("M[%v]", inst.K))
+ case Len:
+ w.WriteString("len")
+ case Msh:
+ w.WriteString(fmt.Sprintf("4*(P[%v:1]&0xf)", inst.K))
+ default:
+ return fmt.Errorf("invalid BPF LDX instruction: %v", inst)
+ }
+ return nil
+}
+
+// A <- A + k
+func decodeAlu(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ code := inst.OpCode & aluMask
+ if code == Neg {
+ w.WriteString("A <- -A")
+ return nil
+ }
+
+ w.WriteString("A <- A ")
+ switch code {
+ case Add:
+ w.WriteString("+ ")
+ case Sub:
+ w.WriteString("- ")
+ case Mul:
+ w.WriteString("* ")
+ case Div:
+ w.WriteString("/ ")
+ case Or:
+ w.WriteString("| ")
+ case And:
+ w.WriteString("& ")
+ case Lsh:
+ w.WriteString("<< ")
+ case Rsh:
+ w.WriteString(">> ")
+ case Mod:
+ w.WriteString("% ")
+ case Xor:
+ w.WriteString("^ ")
+ default:
+ return fmt.Errorf("invalid BPF ALU instruction: %v", inst)
+ }
+ return decodeSource(inst, w)
+}
+
+func decodeSource(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ switch inst.OpCode & srcAluJmpMask {
+ case K:
+ w.WriteString(fmt.Sprintf("%v", inst.K))
+ case X:
+ w.WriteString("X")
+ default:
+ return fmt.Errorf("invalid BPF ALU/JMP source instruction: %v", inst)
+ }
+ return nil
+}
+
+// pc += (A > k) ? jt : jf
+func decodeJmp(inst linux.BPFInstruction, line int, w *bytes.Buffer) error {
+ code := inst.OpCode & jmpMask
+
+ w.WriteString("pc += ")
+ if code == Ja {
+ w.WriteString(printJmpTarget(inst.K, line))
+ } else {
+ w.WriteString("(A ")
+ switch code {
+ case Jeq:
+ w.WriteString("== ")
+ case Jgt:
+ w.WriteString("> ")
+ case Jge:
+ w.WriteString(">= ")
+ case Jset:
+ w.WriteString("& ")
+ default:
+ return fmt.Errorf("invalid BPF ALU instruction: %v", inst)
+ }
+ if err := decodeSource(inst, w); err != nil {
+ return err
+ }
+ w.WriteString(
+ fmt.Sprintf(") ? %s : %s",
+ printJmpTarget(uint32(inst.JumpIfTrue), line),
+ printJmpTarget(uint32(inst.JumpIfFalse), line)))
+ }
+ return nil
+}
+
+func printJmpTarget(target uint32, line int) string {
+ if line == -1 {
+ return fmt.Sprintf("%v", target)
+ }
+ return fmt.Sprintf("%v [%v]", target, int(target)+line+1)
+}
+
+// ret k
+func decodeRet(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ w.WriteString("ret ")
+
+ code := inst.OpCode & srcRetMask
+ switch code {
+ case K:
+ w.WriteString(fmt.Sprintf("%v", inst.K))
+ case A:
+ w.WriteString("A")
+ default:
+ return fmt.Errorf("invalid BPF RET source instruction: %v", inst)
+ }
+ return nil
+}
+
+func decodeMisc(inst linux.BPFInstruction, w *bytes.Buffer) error {
+ code := inst.OpCode & miscMask
+ switch code {
+ case Tax:
+ w.WriteString("X <- A")
+ case Txa:
+ w.WriteString("A <- X")
+ default:
+ return fmt.Errorf("invalid BPF ALU/JMP source instruction: %v", inst)
+ }
+ return nil
+}
diff --git a/pkg/bpf/input_bytes.go b/pkg/bpf/input_bytes.go
new file mode 100644
index 000000000..86b216cfc
--- /dev/null
+++ b/pkg/bpf/input_bytes.go
@@ -0,0 +1,58 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "encoding/binary"
+)
+
+// InputBytes implements the Input interface by providing access to a byte
+// slice. Unaligned loads are supported.
+type InputBytes struct {
+ // Data is the data accessed through the Input interface.
+ Data []byte
+
+ // Order is the byte order the data is accessed with.
+ Order binary.ByteOrder
+}
+
+// Load32 implements Input.Load32.
+func (i InputBytes) Load32(off uint32) (uint32, bool) {
+ if uint64(off)+4 > uint64(len(i.Data)) {
+ return 0, false
+ }
+ return i.Order.Uint32(i.Data[int(off):]), true
+}
+
+// Load16 implements Input.Load16.
+func (i InputBytes) Load16(off uint32) (uint16, bool) {
+ if uint64(off)+2 > uint64(len(i.Data)) {
+ return 0, false
+ }
+ return i.Order.Uint16(i.Data[int(off):]), true
+}
+
+// Load8 implements Input.Load8.
+func (i InputBytes) Load8(off uint32) (uint8, bool) {
+ if uint64(off)+1 > uint64(len(i.Data)) {
+ return 0, false
+ }
+ return i.Data[int(off)], true
+}
+
+// Length implements Input.Length.
+func (i InputBytes) Length() uint32 {
+ return uint32(len(i.Data))
+}
diff --git a/pkg/bpf/interpreter.go b/pkg/bpf/interpreter.go
new file mode 100644
index 000000000..86de523a2
--- /dev/null
+++ b/pkg/bpf/interpreter.go
@@ -0,0 +1,412 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// Possible values for ProgramError.Code.
+const (
+ // DivisionByZero indicates that a program contains, or executed, a
+ // division or modulo by zero.
+ DivisionByZero = iota
+
+ // InvalidEndOfProgram indicates that the last instruction of a program is
+ // not a return.
+ InvalidEndOfProgram
+
+ // InvalidInstructionCount indicates that a program has zero instructions
+ // or more than MaxInstructions instructions.
+ InvalidInstructionCount
+
+ // InvalidJumpTarget indicates that a program contains a jump whose target
+ // is outside of the program's bounds.
+ InvalidJumpTarget
+
+ // InvalidLoad indicates that a program executed an invalid load of input
+ // data.
+ InvalidLoad
+
+ // InvalidOpcode indicates that a program contains an instruction with an
+ // invalid opcode.
+ InvalidOpcode
+
+ // InvalidRegister indicates that a program contains a load from, or store
+ // to, a non-existent M register (index >= ScratchMemRegisters).
+ InvalidRegister
+)
+
+// Error is an error encountered while compiling or executing a BPF program.
+type Error struct {
+ // Code indicates the kind of error that occurred.
+ Code int
+
+ // PC is the program counter (index into the list of instructions) at which
+ // the error occurred.
+ PC int
+}
+
+func (e Error) codeString() string {
+ switch e.Code {
+ case DivisionByZero:
+ return "division by zero"
+ case InvalidEndOfProgram:
+ return "last instruction must be a return"
+ case InvalidInstructionCount:
+ return "invalid number of instructions"
+ case InvalidJumpTarget:
+ return "jump target out of bounds"
+ case InvalidLoad:
+ return "load out of bounds or violates input alignment requirements"
+ case InvalidOpcode:
+ return "invalid instruction opcode"
+ case InvalidRegister:
+ return "invalid M register"
+ default:
+ return "unknown error"
+ }
+}
+
+// Error implements error.Error.
+func (e Error) Error() string {
+ return fmt.Sprintf("at l%d: %s", e.PC, e.codeString())
+}
+
+// Program is a BPF program that has been validated for consistency.
+//
+// +stateify savable
+type Program struct {
+ instructions []linux.BPFInstruction
+}
+
+// Length returns the number of instructions in the program.
+func (p Program) Length() int {
+ return len(p.instructions)
+}
+
+// Compile performs validation on a sequence of BPF instructions before
+// wrapping them in a Program.
+func Compile(insns []linux.BPFInstruction) (Program, error) {
+ if len(insns) == 0 || len(insns) > MaxInstructions {
+ return Program{}, Error{InvalidInstructionCount, len(insns)}
+ }
+
+ // The last instruction must be a return.
+ if last := insns[len(insns)-1]; last.OpCode != (Ret|K) && last.OpCode != (Ret|A) {
+ return Program{}, Error{InvalidEndOfProgram, len(insns) - 1}
+ }
+
+ // Validate each instruction. Note that we skip a validation Linux does:
+ // Linux additionally verifies that every load from an M register is
+ // preceded, in every path, by a store to the same M register, in order to
+ // avoid having to clear M between programs
+ // (net/core/filter.c:check_load_and_stores). We always start with a zeroed
+ // M array.
+ for pc, i := range insns {
+ if i.OpCode&unusedBitsMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ switch i.OpCode & instructionClassMask {
+ case Ld:
+ mode := i.OpCode & loadModeMask
+ switch i.OpCode & loadSizeMask {
+ case W:
+ if mode != Imm && mode != Abs && mode != Ind && mode != Mem && mode != Len {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ if mode == Mem && i.K >= ScratchMemRegisters {
+ return Program{}, Error{InvalidRegister, pc}
+ }
+ case H, B:
+ if mode != Abs && mode != Ind {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ default:
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case Ldx:
+ mode := i.OpCode & loadModeMask
+ switch i.OpCode & loadSizeMask {
+ case W:
+ if mode != Imm && mode != Mem && mode != Len {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ if mode == Mem && i.K >= ScratchMemRegisters {
+ return Program{}, Error{InvalidRegister, pc}
+ }
+ case B:
+ if mode != Msh {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ default:
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case St, Stx:
+ if i.OpCode&storeUnusedBitsMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ if i.K >= ScratchMemRegisters {
+ return Program{}, Error{InvalidRegister, pc}
+ }
+ case Alu:
+ switch i.OpCode & aluMask {
+ case Add, Sub, Mul, Or, And, Lsh, Rsh, Xor:
+ break
+ case Div, Mod:
+ if src := i.OpCode & srcAluJmpMask; src == K && i.K == 0 {
+ return Program{}, Error{DivisionByZero, pc}
+ }
+ case Neg:
+ // Negation doesn't take a source operand.
+ if i.OpCode&srcAluJmpMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ default:
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case Jmp:
+ switch i.OpCode & jmpMask {
+ case Ja:
+ // Unconditional jump doesn't take a source operand.
+ if i.OpCode&srcAluJmpMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ // Do the comparison in 64 bits to avoid the possibility of
+ // overflow from a very large i.K.
+ if uint64(pc)+uint64(i.K)+1 >= uint64(len(insns)) {
+ return Program{}, Error{InvalidJumpTarget, pc}
+ }
+ case Jeq, Jgt, Jge, Jset:
+ // jt and jf are uint16s, so there's no threat of overflow.
+ if pc+int(i.JumpIfTrue)+1 >= len(insns) {
+ return Program{}, Error{InvalidJumpTarget, pc}
+ }
+ if pc+int(i.JumpIfFalse)+1 >= len(insns) {
+ return Program{}, Error{InvalidJumpTarget, pc}
+ }
+ default:
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case Ret:
+ if i.OpCode&retUnusedBitsMask != 0 {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ if src := i.OpCode & srcRetMask; src != K && src != A {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ case Misc:
+ if misc := i.OpCode & miscMask; misc != Tax && misc != Txa {
+ return Program{}, Error{InvalidOpcode, pc}
+ }
+ }
+ }
+
+ return Program{insns}, nil
+}
+
+// Input represents a source of input data for a BPF program. (BPF
+// documentation sometimes refers to the input data as the "packet" due to its
+// origins as a packet processing DSL.)
+//
+// For all of Input's Load methods:
+//
+// - The second (bool) return value is true if the load succeeded and false
+// otherwise.
+//
+// - Inputs should not assume that the loaded range falls within the input
+// data's length. Inputs should return false if the load falls outside of the
+// input data.
+//
+// - Inputs should not assume that the offset is correctly aligned. Inputs may
+// choose to service or reject loads to unaligned addresses.
+type Input interface {
+ // Load32 reads 32 bits from the input starting at the given byte offset.
+ Load32(off uint32) (uint32, bool)
+
+ // Load16 reads 16 bits from the input starting at the given byte offset.
+ Load16(off uint32) (uint16, bool)
+
+ // Load8 reads 8 bits from the input starting at the given byte offset.
+ Load8(off uint32) (uint8, bool)
+
+ // Length returns the length of the input in bytes.
+ Length() uint32
+}
+
+// machine represents the state of a BPF virtual machine.
+type machine struct {
+ A uint32
+ X uint32
+ M [ScratchMemRegisters]uint32
+}
+
+func conditionalJumpOffset(insn linux.BPFInstruction, cond bool) int {
+ if cond {
+ return int(insn.JumpIfTrue)
+ }
+ return int(insn.JumpIfFalse)
+}
+
+// Exec executes a BPF program over the given input and returns its return
+// value.
+func Exec(p Program, in Input) (uint32, error) {
+ var m machine
+ var pc int
+ for ; pc < len(p.instructions); pc++ {
+ i := p.instructions[pc]
+ switch i.OpCode {
+ case Ld | Imm | W:
+ m.A = i.K
+ case Ld | Abs | W:
+ val, ok := in.Load32(i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = val
+ case Ld | Abs | H:
+ val, ok := in.Load16(i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = uint32(val)
+ case Ld | Abs | B:
+ val, ok := in.Load8(i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = uint32(val)
+ case Ld | Ind | W:
+ val, ok := in.Load32(m.X + i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = val
+ case Ld | Ind | H:
+ val, ok := in.Load16(m.X + i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = uint32(val)
+ case Ld | Ind | B:
+ val, ok := in.Load8(m.X + i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.A = uint32(val)
+ case Ld | Mem | W:
+ m.A = m.M[int(i.K)]
+ case Ld | Len | W:
+ m.A = in.Length()
+ case Ldx | Imm | W:
+ m.X = i.K
+ case Ldx | Mem | W:
+ m.X = m.M[int(i.K)]
+ case Ldx | Len | W:
+ m.X = in.Length()
+ case Ldx | Msh | B:
+ val, ok := in.Load8(i.K)
+ if !ok {
+ return 0, Error{InvalidLoad, pc}
+ }
+ m.X = 4 * uint32(val&0xf)
+ case St:
+ m.M[int(i.K)] = m.A
+ case Stx:
+ m.M[int(i.K)] = m.X
+ case Alu | Add | K:
+ m.A += i.K
+ case Alu | Add | X:
+ m.A += m.X
+ case Alu | Sub | K:
+ m.A -= i.K
+ case Alu | Sub | X:
+ m.A -= m.X
+ case Alu | Mul | K:
+ m.A *= i.K
+ case Alu | Mul | X:
+ m.A *= m.X
+ case Alu | Div | K:
+ // K != 0 already checked by Compile.
+ m.A /= i.K
+ case Alu | Div | X:
+ if m.X == 0 {
+ return 0, Error{DivisionByZero, pc}
+ }
+ m.A /= m.X
+ case Alu | Or | K:
+ m.A |= i.K
+ case Alu | Or | X:
+ m.A |= m.X
+ case Alu | And | K:
+ m.A &= i.K
+ case Alu | And | X:
+ m.A &= m.X
+ case Alu | Lsh | K:
+ m.A <<= i.K
+ case Alu | Lsh | X:
+ m.A <<= m.X
+ case Alu | Rsh | K:
+ m.A >>= i.K
+ case Alu | Rsh | X:
+ m.A >>= m.X
+ case Alu | Neg:
+ m.A = uint32(-int32(m.A))
+ case Alu | Mod | K:
+ // K != 0 already checked by Compile.
+ m.A %= i.K
+ case Alu | Mod | X:
+ if m.X == 0 {
+ return 0, Error{DivisionByZero, pc}
+ }
+ m.A %= m.X
+ case Alu | Xor | K:
+ m.A ^= i.K
+ case Alu | Xor | X:
+ m.A ^= m.X
+ case Jmp | Ja:
+ pc += int(i.K)
+ case Jmp | Jeq | K:
+ pc += conditionalJumpOffset(i, m.A == i.K)
+ case Jmp | Jeq | X:
+ pc += conditionalJumpOffset(i, m.A == m.X)
+ case Jmp | Jgt | K:
+ pc += conditionalJumpOffset(i, m.A > i.K)
+ case Jmp | Jgt | X:
+ pc += conditionalJumpOffset(i, m.A > m.X)
+ case Jmp | Jge | K:
+ pc += conditionalJumpOffset(i, m.A >= i.K)
+ case Jmp | Jge | X:
+ pc += conditionalJumpOffset(i, m.A >= m.X)
+ case Jmp | Jset | K:
+ pc += conditionalJumpOffset(i, (m.A&i.K) != 0)
+ case Jmp | Jset | X:
+ pc += conditionalJumpOffset(i, (m.A&m.X) != 0)
+ case Ret | K:
+ return i.K, nil
+ case Ret | A:
+ return m.A, nil
+ case Misc | Tax:
+ m.A = m.X
+ case Misc | Txa:
+ m.X = m.A
+ default:
+ return 0, Error{InvalidOpcode, pc}
+ }
+ }
+ return 0, Error{InvalidEndOfProgram, pc}
+}
diff --git a/pkg/bpf/program_builder.go b/pkg/bpf/program_builder.go
new file mode 100644
index 000000000..fc9d27203
--- /dev/null
+++ b/pkg/bpf/program_builder.go
@@ -0,0 +1,191 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bpf
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+const (
+ labelTarget = math.MaxUint8
+ labelDirectTarget = math.MaxUint32
+)
+
+// ProgramBuilder assists with building a BPF program with jump
+// labels that are resolved to their proper offsets.
+type ProgramBuilder struct {
+ // Maps label names to label objects.
+ labels map[string]*label
+
+ // Array of BPF instructions that makes up the program.
+ instructions []linux.BPFInstruction
+}
+
+// NewProgramBuilder creates a new ProgramBuilder instance.
+func NewProgramBuilder() *ProgramBuilder {
+ return &ProgramBuilder{labels: map[string]*label{}}
+}
+
+// label contains information to resolve a label to an offset.
+type label struct {
+ // List of locations that reference the label in the program.
+ sources []source
+
+ // Program line when the label is located.
+ target int
+}
+
+type jmpType int
+
+const (
+ jDirect jmpType = iota
+ jTrue
+ jFalse
+)
+
+// source contains information about a single reference to a label.
+type source struct {
+ // Program line where the label reference is present.
+ line int
+
+ // True if label reference is in the 'jump if true' part of the jump.
+ // False if label reference is in the 'jump if false' part of the jump.
+ jt jmpType
+}
+
+// AddStmt adds a new statement to the program.
+func (b *ProgramBuilder) AddStmt(code uint16, k uint32) {
+ b.instructions = append(b.instructions, Stmt(code, k))
+}
+
+// AddJump adds a new jump to the program.
+func (b *ProgramBuilder) AddJump(code uint16, k uint32, jt, jf uint8) {
+ b.instructions = append(b.instructions, Jump(code, k, jt, jf))
+}
+
+// AddDirectJumpLabel adds a new jump to the program where is labelled.
+func (b *ProgramBuilder) AddDirectJumpLabel(labelName string) {
+ b.addLabelSource(labelName, jDirect)
+ b.AddJump(Jmp|Ja, labelDirectTarget, 0, 0)
+}
+
+// AddJumpTrueLabel adds a new jump to the program where 'jump if true' is a label.
+func (b *ProgramBuilder) AddJumpTrueLabel(code uint16, k uint32, jtLabel string, jf uint8) {
+ b.addLabelSource(jtLabel, jTrue)
+ b.AddJump(code, k, labelTarget, jf)
+}
+
+// AddJumpFalseLabel adds a new jump to the program where 'jump if false' is a label.
+func (b *ProgramBuilder) AddJumpFalseLabel(code uint16, k uint32, jt uint8, jfLabel string) {
+ b.addLabelSource(jfLabel, jFalse)
+ b.AddJump(code, k, jt, labelTarget)
+}
+
+// AddJumpLabels adds a new jump to the program where both jump targets are labels.
+func (b *ProgramBuilder) AddJumpLabels(code uint16, k uint32, jtLabel, jfLabel string) {
+ b.addLabelSource(jtLabel, jTrue)
+ b.addLabelSource(jfLabel, jFalse)
+ b.AddJump(code, k, labelTarget, labelTarget)
+}
+
+// AddLabel sets the given label name at the current location. The next instruction is executed
+// when the any code jumps to this label. More than one label can be added to the same location.
+func (b *ProgramBuilder) AddLabel(name string) error {
+ l, ok := b.labels[name]
+ if !ok {
+ // This is done to catch jump backwards cases, but it's not strictly wrong
+ // to have unused labels.
+ return fmt.Errorf("Adding a label that hasn't been used is not allowed: %v", name)
+ }
+ if l.target != -1 {
+ return fmt.Errorf("label %q target already set: %v", name, l.target)
+ }
+ l.target = len(b.instructions)
+ return nil
+}
+
+// Instructions returns an array of BPF instructions representing the program with all labels
+// resolved. Return error in case label resolution failed due to an invalid program.
+//
+// N.B. Partial results will be returned in the error case, which is useful for debugging.
+func (b *ProgramBuilder) Instructions() ([]linux.BPFInstruction, error) {
+ if err := b.resolveLabels(); err != nil {
+ return b.instructions, err
+ }
+ return b.instructions, nil
+}
+
+func (b *ProgramBuilder) addLabelSource(labelName string, t jmpType) {
+ l, ok := b.labels[labelName]
+ if !ok {
+ l = &label{sources: make([]source, 0), target: -1}
+ b.labels[labelName] = l
+ }
+ l.sources = append(l.sources, source{line: len(b.instructions), jt: t})
+}
+
+func (b *ProgramBuilder) resolveLabels() error {
+ for key, v := range b.labels {
+ if v.target == -1 {
+ return fmt.Errorf("label target not set: %v", key)
+ }
+ if v.target >= len(b.instructions) {
+ return fmt.Errorf("target is beyond end of ProgramBuilder")
+ }
+ for _, s := range v.sources {
+ // Finds jump instruction that references the label.
+ inst := b.instructions[s.line]
+ if s.line >= v.target {
+ return fmt.Errorf("cannot jump backwards")
+ }
+ // Calculates the jump offset from current line.
+ offset := v.target - s.line - 1
+ // Sets offset into jump instruction.
+ switch s.jt {
+ case jDirect:
+ if offset > labelDirectTarget {
+ return fmt.Errorf("jump offset to label '%v' is too large: %v, inst: %v, lineno: %v", key, offset, inst, s.line)
+ }
+ if inst.K != labelDirectTarget {
+ return fmt.Errorf("jump target is not a label")
+ }
+ inst.K = uint32(offset)
+ case jTrue:
+ if offset > labelTarget {
+ return fmt.Errorf("jump offset to label '%v' is too large: %v, inst: %v, lineno: %v", key, offset, inst, s.line)
+ }
+ if inst.JumpIfTrue != labelTarget {
+ return fmt.Errorf("jump target is not a label")
+ }
+ inst.JumpIfTrue = uint8(offset)
+ case jFalse:
+ if offset > labelTarget {
+ return fmt.Errorf("jump offset to label '%v' is too large: %v, inst: %v, lineno: %v", key, offset, inst, s.line)
+ }
+ if inst.JumpIfFalse != labelTarget {
+ return fmt.Errorf("jump target is not a label")
+ }
+ inst.JumpIfFalse = uint8(offset)
+ }
+
+ b.instructions[s.line] = inst
+ }
+ }
+ b.labels = map[string]*label{}
+ return nil
+}
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
new file mode 100644
index 000000000..8c14ccbfa
--- /dev/null
+++ b/pkg/compressio/compressio.go
@@ -0,0 +1,743 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package compressio provides parallel compression and decompression, as well
+// as optional SHA-256 hashing.
+//
+// The stream format is defined as follows.
+//
+// /------------------------------------------------------\
+// | chunk size (4-bytes) |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | compressed data |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | ...... |
+// \------------------------------------------------------/
+//
+// where each subsequent hash is calculated from the following items in order
+//
+// compressed data
+// compressed data size
+// previous hash
+//
+// so the stream integrity cannot be compromised by switching and mixing
+// compressed chunks.
+package compressio
+
+import (
+ "bytes"
+ "compress/flate"
+ "crypto/hmac"
+ "crypto/sha256"
+ "errors"
+ "hash"
+ "io"
+ "runtime"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+)
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ return bytes.NewBuffer(nil)
+ },
+}
+
+var chunkPool = sync.Pool{
+ New: func() interface{} {
+ return new(chunk)
+ },
+}
+
+// chunk is a unit of work.
+type chunk struct {
+ // compressed is compressed data.
+ //
+ // This will always be returned to the bufPool directly when work has
+ // finished (in schedule) and therefore must be allocated.
+ compressed *bytes.Buffer
+
+ // uncompressed is the uncompressed data.
+ //
+ // This is not returned to the bufPool automatically, since it may
+ // correspond to a inline slice (provided directly to Read or Write).
+ uncompressed *bytes.Buffer
+
+ // The current hash object. Only used in compress mode.
+ h hash.Hash
+
+ // The hash from previous chunks. Only used in uncompress mode.
+ lastSum []byte
+
+ // The expected hash after current chunk. Only used in uncompress mode.
+ sum []byte
+}
+
+// newChunk allocates a new chunk object (or pulls one from the pool). Buffers
+// will be allocated if nil is provided for compressed or uncompressed.
+func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
+ c := chunkPool.Get().(*chunk)
+ c.lastSum = lastSum
+ c.sum = sum
+ if compressed != nil {
+ c.compressed = compressed
+ } else {
+ c.compressed = bufPool.Get().(*bytes.Buffer)
+ }
+ if uncompressed != nil {
+ c.uncompressed = uncompressed
+ } else {
+ c.uncompressed = bufPool.Get().(*bytes.Buffer)
+ }
+ return c
+}
+
+// result is the result of some work; it includes the original chunk.
+type result struct {
+ *chunk
+ err error
+}
+
+// worker is a compression/decompression worker.
+//
+// The associated worker goroutine reads in uncompressed buffers from input and
+// writes compressed buffers to its output. Alternatively, the worker reads
+// compressed buffers from input and writes uncompressed buffers to its output.
+//
+// The goroutine will exit when input is closed, and the goroutine will close
+// output.
+type worker struct {
+ hashPool *hashPool
+ input chan *chunk
+ output chan result
+}
+
+// work is the main work routine; see worker.
+func (w *worker) work(compress bool, level int) {
+ defer close(w.output)
+
+ var h hash.Hash
+
+ for c := range w.input {
+ if h == nil && w.hashPool != nil {
+ h = w.hashPool.getHash()
+ }
+ if compress {
+ mw := io.Writer(c.compressed)
+ if h != nil {
+ mw = io.MultiWriter(mw, h)
+ }
+
+ // Encode this slice.
+ fw, err := flate.NewWriter(mw, level)
+ if err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+
+ // Encode the input.
+ if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+ if err := fw.Close(); err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+
+ // Write the hash, if enabled.
+ if h != nil {
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ c.h = h
+ h = nil
+ }
+ } else {
+ // Check the hash of the compressed contents.
+ if h != nil {
+ h.Write(c.compressed.Bytes())
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))
+
+ sum := h.Sum(nil)
+ h.Reset()
+ if !hmac.Equal(c.sum, sum) {
+ w.output <- result{c, ErrHashMismatch}
+ continue
+ }
+ }
+
+ // Decode this slice.
+ fr := flate.NewReader(c.compressed)
+
+ // Decode the input.
+ if _, err := io.Copy(c.uncompressed, fr); err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+ }
+
+ // Send the output.
+ w.output <- result{c, nil}
+ }
+}
+
+type hashPool struct {
+ // mu protexts the hash list.
+ mu sync.Mutex
+
+ // key is the key used to create hash objects.
+ key []byte
+
+ // hashes is the hash object free list. Note that this cannot be
+ // globally shared across readers or writers, as it is key-specific.
+ hashes []hash.Hash
+}
+
+// getHash gets a hash object for the pool. It should only be called when the
+// pool key is non-nil.
+func (p *hashPool) getHash() hash.Hash {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if len(p.hashes) == 0 {
+ return hmac.New(sha256.New, p.key)
+ }
+
+ h := p.hashes[len(p.hashes)-1]
+ p.hashes = p.hashes[:len(p.hashes)-1]
+ return h
+}
+
+func (p *hashPool) putHash(h hash.Hash) {
+ h.Reset()
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.hashes = append(p.hashes, h)
+}
+
+// pool is common functionality for reader/writers.
+type pool struct {
+ // workers are the compression/decompression workers.
+ workers []worker
+
+ // chunkSize is the chunk size. This is the first four bytes in the
+ // stream and is shared across both the reader and writer.
+ chunkSize uint32
+
+ // mu protects below; it is generally the responsibility of users to
+ // acquire this mutex before calling any methods on the pool.
+ mu sync.Mutex
+
+ // nextInput is the next worker for input (scheduling).
+ nextInput int
+
+ // nextOutput is the next worker for output (result).
+ nextOutput int
+
+ // buf is the current active buffer; the exact semantics of this buffer
+ // depending on whether this is a reader or a writer.
+ buf *bytes.Buffer
+
+ // lasSum records the hash of the last chunk processed.
+ lastSum []byte
+
+ // hashPool is the hash object pool. It cannot be embedded into pool
+ // itself as worker refers to it and that would stop pool from being
+ // GCed.
+ hashPool *hashPool
+}
+
+// init initializes the worker pool.
+//
+// This should only be called once.
+func (p *pool) init(key []byte, workers int, compress bool, level int) {
+ if key != nil {
+ p.hashPool = &hashPool{key: key}
+ }
+ p.workers = make([]worker, workers)
+ for i := 0; i < len(p.workers); i++ {
+ p.workers[i] = worker{
+ hashPool: p.hashPool,
+ input: make(chan *chunk, 1),
+ output: make(chan result, 1),
+ }
+ go p.workers[i].work(compress, level) // S/R-SAFE: In save path only.
+ }
+ runtime.SetFinalizer(p, (*pool).stop)
+}
+
+// stop stops all workers.
+func (p *pool) stop() {
+ for i := 0; i < len(p.workers); i++ {
+ close(p.workers[i].input)
+ }
+ p.workers = nil
+ p.hashPool = nil
+}
+
+// handleResult calls the callback.
+func handleResult(r result, callback func(*chunk) error) error {
+ defer func() {
+ r.chunk.compressed.Reset()
+ bufPool.Put(r.chunk.compressed)
+ chunkPool.Put(r.chunk)
+ }()
+ if r.err != nil {
+ return r.err
+ }
+ return callback(r.chunk)
+}
+
+// schedule schedules the given buffers.
+//
+// If c is non-nil, then it will return as soon as the chunk is scheduled. If c
+// is nil, then it will return only when no more work is left to do.
+//
+// If no callback function is provided, then the output channel will be
+// ignored. You must be sure that the input is schedulable in this case.
+func (p *pool) schedule(c *chunk, callback func(*chunk) error) error {
+ for {
+ var (
+ inputChan chan *chunk
+ outputChan chan result
+ )
+ if c != nil && len(p.workers) != 0 {
+ inputChan = p.workers[(p.nextInput+1)%len(p.workers)].input
+ }
+ if callback != nil && p.nextOutput != p.nextInput && len(p.workers) != 0 {
+ outputChan = p.workers[(p.nextOutput+1)%len(p.workers)].output
+ }
+ if inputChan == nil && outputChan == nil {
+ return nil
+ }
+
+ select {
+ case inputChan <- c:
+ p.nextInput++
+ return nil
+ case r := <-outputChan:
+ p.nextOutput++
+ if err := handleResult(r, callback); err != nil {
+ return err
+ }
+ }
+ }
+}
+
+// reader chunks reads and decompresses.
+type reader struct {
+ pool
+
+ // in is the source.
+ in io.Reader
+}
+
+// NewReader returns a new compressed reader. If key is non-nil, the data stream
+// is assumed to contain expected hash values, which will be compared against
+// hash values computed from the compressed bytes. See package comments for
+// details.
+func NewReader(in io.Reader, key []byte) (io.Reader, error) {
+ r := &reader{
+ in: in,
+ }
+
+ // Use double buffering for read.
+ r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)
+
+ var err error
+ if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil {
+ return nil, err
+ }
+
+ if r.hashPool != nil {
+ h := r.hashPool.getHash()
+ binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
+ r.lastSum = h.Sum(nil)
+ r.hashPool.putHash(h)
+ sum := make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ return nil, err
+ }
+ if !hmac.Equal(r.lastSum, sum) {
+ return nil, ErrHashMismatch
+ }
+ }
+
+ return r, nil
+}
+
+// errNewBuffer is returned when a new buffer is completed.
+var errNewBuffer = errors.New("buffer ready")
+
+// ErrHashMismatch is returned if the hash does not match.
+var ErrHashMismatch = errors.New("hash mismatch")
+
+// Read implements io.Reader.Read.
+func (r *reader) Read(p []byte) (int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ // Total bytes completed; this is declared up front because it must be
+ // adjustable by the callback below.
+ done := 0
+
+ // Total bytes pending in the asynchronous workers for buffers. This is
+ // used to process the proper regions of the input as inline buffers.
+ var (
+ pendingPre = r.nextInput - r.nextOutput
+ pendingInline = 0
+ )
+
+ // Define our callback for completed work.
+ callback := func(c *chunk) error {
+ // Check for an inline buffer.
+ if pendingPre == 0 && pendingInline > 0 {
+ pendingInline--
+ done += c.uncompressed.Len()
+ return nil
+ }
+
+ // Copy the resulting buffer to our intermediate one, and
+ // return errNewBuffer to ensure that we aren't called a second
+ // time. This error code is handled specially below.
+ //
+ // c.buf will be freed and return to the pool when it is done.
+ if pendingPre > 0 {
+ pendingPre--
+ }
+ r.buf = c.uncompressed
+ return errNewBuffer
+ }
+
+ for done < len(p) {
+ // Do we have buffered data available?
+ if r.buf != nil {
+ n, err := r.buf.Read(p[done:])
+ done += n
+ if err == io.EOF {
+ // This is the uncompressed buffer, it can be
+ // returned to the pool at this point.
+ r.buf.Reset()
+ bufPool.Put(r.buf)
+ r.buf = nil
+ } else if err != nil {
+ // Should never happen.
+ defer r.stop()
+ return done, err
+ }
+ continue
+ }
+
+ // Read the length of the next chunk and reset the
+ // reader. The length is used to limit the reader.
+ //
+ // See writer.flush.
+ l, err := binary.ReadUint32(r.in, binary.BigEndian)
+ if err != nil {
+ // This is generally okay as long as there
+ // are still buffers outstanding. We actually
+ // just wait for completion of those buffers here
+ // and continue our loop.
+ if err := r.schedule(nil, callback); err == nil {
+ // We've actually finished all buffers; this is
+ // the normal EOF exit path.
+ defer r.stop()
+ return done, io.EOF
+ } else if err == errNewBuffer {
+ // A new buffer is now available.
+ continue
+ } else {
+ // Some other error occurred; we cannot
+ // process any further.
+ defer r.stop()
+ return done, err
+ }
+ }
+
+ // Read this chunk and schedule decompression.
+ compressed := bufPool.Get().(*bytes.Buffer)
+ if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil {
+ // Some other error occurred; see above.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return done, err
+ }
+
+ var sum []byte
+ if r.hashPool != nil {
+ sum = make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return done, err
+ }
+ }
+
+ // Are we doing inline decoding?
+ //
+ // Note that we need to check the length here against
+ // bytes.MinRead, since the bytes library will choose to grow
+ // the slice if the available capacity is not at least
+ // bytes.MinRead. This limits inline decoding to chunkSizes
+ // that are at least bytes.MinRead (which is not unreasonable).
+ var c *chunk
+ start := done + ((pendingPre + pendingInline) * int(r.chunkSize))
+ if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead {
+ c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start]))
+ pendingInline++
+ } else {
+ c = newChunk(r.lastSum, sum, compressed, nil)
+ }
+ r.lastSum = sum
+ if err := r.schedule(c, callback); err == errNewBuffer {
+ // A new buffer was completed while we were reading.
+ // That's great, but we need to force schedule the
+ // current buffer so that it does not get lost.
+ //
+ // It is safe to pass nil as an output function here,
+ // because we know that we just freed up a slot above.
+ r.schedule(c, nil)
+ } else if err != nil {
+ // Some other error occurred; see above.
+ defer r.stop()
+ return done, err
+ }
+ }
+
+ // Make sure that everything has been decoded successfully, otherwise
+ // parts of p may not actually have completed.
+ for pendingInline > 0 {
+ if err := r.schedule(nil, func(c *chunk) error {
+ if err := callback(c); err != nil {
+ return err
+ }
+ // The nil case means that an inline buffer has
+ // completed. The callback will have already removed
+ // the inline buffer from the map, so we just return an
+ // error to check the top of the loop again.
+ return errNewBuffer
+ }); err != errNewBuffer {
+ // Some other error occurred; see above.
+ return done, err
+ }
+ }
+
+ // Need to return done here, since it may have been adjusted by the
+ // callback to compensation for partial reads on some inline buffer.
+ return done, nil
+}
+
+// writer chunks and schedules writes.
+type writer struct {
+ pool
+
+ // out is the underlying writer.
+ out io.Writer
+
+ // closed indicates whether the file has been closed.
+ closed bool
+}
+
+// NewWriter returns a new compressed writer. If key is non-nil, hash values are
+// generated and written out for compressed bytes. See package comments for
+// details.
+//
+// The recommended chunkSize is on the order of 1M. Extra memory may be
+// buffered (in the form of read-ahead, or buffered writes), and is limited to
+// O(chunkSize * [1+GOMAXPROCS]).
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) {
+ w := &writer{
+ pool: pool{
+ chunkSize: chunkSize,
+ buf: bufPool.Get().(*bytes.Buffer),
+ },
+ out: out,
+ }
+ w.init(key, 1+runtime.GOMAXPROCS(0), true, level)
+
+ if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil {
+ return nil, err
+ }
+
+ if w.hashPool != nil {
+ h := w.hashPool.getHash()
+ binary.WriteUint32(h, binary.BigEndian, chunkSize)
+ w.lastSum = h.Sum(nil)
+ w.hashPool.putHash(h)
+ if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
+ return nil, err
+ }
+ }
+
+ return w, nil
+}
+
+// flush writes a single buffer.
+func (w *writer) flush(c *chunk) error {
+ // Prefix each chunk with a length; this allows the reader to safely
+ // limit reads while buffering.
+ l := uint32(c.compressed.Len())
+ if err := binary.WriteUint32(w.out, binary.BigEndian, l); err != nil {
+ return err
+ }
+
+ // Write out to the stream.
+ if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil {
+ return err
+ }
+
+ if w.hashPool != nil {
+ io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
+ sum := c.h.Sum(nil)
+ w.hashPool.putHash(c.h)
+ c.h = nil
+ if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
+ return err
+ }
+ w.lastSum = sum
+ }
+
+ return nil
+}
+
+// Write implements io.Writer.Write.
+func (w *writer) Write(p []byte) (int, error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Did we close already?
+ if w.closed {
+ return 0, io.ErrUnexpectedEOF
+ }
+
+ // See above; we need to track in the same way.
+ var (
+ pendingPre = w.nextInput - w.nextOutput
+ pendingInline = 0
+ )
+ callback := func(c *chunk) error {
+ if pendingPre == 0 && pendingInline > 0 {
+ pendingInline--
+ return w.flush(c)
+ }
+ if pendingPre > 0 {
+ pendingPre--
+ }
+ err := w.flush(c)
+ c.uncompressed.Reset()
+ bufPool.Put(c.uncompressed)
+ return err
+ }
+
+ for done := 0; done < len(p); {
+ // Construct an inline buffer if we're doing an inline
+ // encoding; see above regarding the bytes.MinRead constraint.
+ if w.buf.Len() == 0 && len(p) >= done+int(w.chunkSize) && len(p) >= done+bytes.MinRead {
+ bufPool.Put(w.buf) // Return to the pool; never scheduled.
+ w.buf = bytes.NewBuffer(p[done : done+int(w.chunkSize)])
+ done += int(w.chunkSize)
+ pendingInline++
+ }
+
+ // Do we need to flush w.buf? Note that this case should be hit
+ // immediately following the inline case above.
+ left := int(w.chunkSize) - w.buf.Len()
+ if left == 0 {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil {
+ return done, err
+ }
+ // Reset the buffer, since this has now been scheduled
+ // for compression. Note that this may be trampled
+ // immediately by the bufPool.Put(w.buf) above if the
+ // next buffer happens to be inline, but that's okay.
+ w.buf = bufPool.Get().(*bytes.Buffer)
+ continue
+ }
+
+ // Read from p into w.buf.
+ toWrite := len(p) - done
+ if toWrite > left {
+ toWrite = left
+ }
+ n, err := w.buf.Write(p[done : done+toWrite])
+ done += n
+ if err != nil {
+ return done, err
+ }
+ }
+
+ // Make sure that everything has been flushed, we can't return until
+ // all the contents from p have been used.
+ for pendingInline > 0 {
+ if err := w.schedule(nil, func(c *chunk) error {
+ if err := callback(c); err != nil {
+ return err
+ }
+ // The flush was successful, return errNewBuffer here
+ // to break from the loop and check the condition
+ // again.
+ return errNewBuffer
+ }); err != errNewBuffer {
+ return len(p), err
+ }
+ }
+
+ return len(p), nil
+}
+
+// Close implements io.Closer.Close.
+func (w *writer) Close() error {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Did we already close? After the call to Close, we always mark as
+ // closed, regardless of whether the flush is successful.
+ if w.closed {
+ return io.ErrUnexpectedEOF
+ }
+ w.closed = true
+ defer w.stop()
+
+ // Schedule any remaining partial buffer; we pass w.flush directly here
+ // because the final buffer is guaranteed to not be an inline buffer.
+ if w.buf.Len() > 0 {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil {
+ return err
+ }
+ }
+
+ // Flush all scheduled buffers; see above.
+ if err := w.schedule(nil, w.flush); err != nil {
+ return err
+ }
+
+ // Close the underlying writer (if necessary).
+ if closer, ok := w.out.(io.Closer); ok {
+ return closer.Close()
+ }
+ return nil
+}
diff --git a/pkg/compressio/compressio_state_autogen.go b/pkg/compressio/compressio_state_autogen.go
new file mode 100755
index 000000000..cac5ea41c
--- /dev/null
+++ b/pkg/compressio/compressio_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package compressio
+
diff --git a/pkg/control/client/client.go b/pkg/control/client/client.go
new file mode 100644
index 000000000..3fec27846
--- /dev/null
+++ b/pkg/control/client/client.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package client provides a basic control client interface.
+package client
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+ "gvisor.googlesource.com/gvisor/pkg/urpc"
+)
+
+// ConnectTo attempts to connect to the sandbox with the given address.
+func ConnectTo(addr string) (*urpc.Client, error) {
+ // Connect to the server.
+ conn, err := unet.Connect(addr, false)
+ if err != nil {
+ return nil, err
+ }
+
+ // Wrap in our stream codec.
+ return urpc.NewClient(conn), nil
+}
diff --git a/pkg/control/client/client_state_autogen.go b/pkg/control/client/client_state_autogen.go
new file mode 100755
index 000000000..69ea753a9
--- /dev/null
+++ b/pkg/control/client/client_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package client
+
diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go
new file mode 100644
index 000000000..1a15da1a8
--- /dev/null
+++ b/pkg/control/server/server.go
@@ -0,0 +1,160 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+Package server provides a basic control server interface.
+
+Note that no objects are registered by default. Users must provide their own
+implementations of the control interface.
+*/
+package server
+
+import (
+ "os"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+ "gvisor.googlesource.com/gvisor/pkg/urpc"
+)
+
+// curUID is the unix user ID of the user that the control server is running as.
+var curUID = os.Getuid()
+
+// Server is a basic control server.
+type Server struct {
+ // socket is our bound socket.
+ socket *unet.ServerSocket
+
+ // server is our rpc server.
+ server *urpc.Server
+
+ // wg waits for the accept loop to terminate.
+ wg sync.WaitGroup
+}
+
+// New returns a new bound control server.
+func New(socket *unet.ServerSocket) *Server {
+ return &Server{
+ socket: socket,
+ server: urpc.NewServer(),
+ }
+}
+
+// FD returns the file descriptor that the server is running on.
+func (s *Server) FD() int {
+ return s.socket.FD()
+}
+
+// Wait waits for the main server goroutine to exit. This should be
+// called after a call to Serve.
+func (s *Server) Wait() {
+ s.wg.Wait()
+}
+
+// Stop stops the server. Note that this function should only be called once
+// and the server should not be used afterwards.
+func (s *Server) Stop() {
+ s.socket.Close()
+ s.wg.Wait()
+
+ // This will cause existing clients to be terminated safely.
+ s.server.Stop()
+}
+
+// StartServing starts listening for connect and spawns the main service
+// goroutine for handling incoming control requests. StartServing does not
+// block; to wait for the control server to exit, call Wait.
+func (s *Server) StartServing() error {
+ // Actually start listening.
+ if err := s.socket.Listen(); err != nil {
+ return err
+ }
+
+ s.wg.Add(1)
+ go func() { // S/R-SAFE: does not impact state directly.
+ s.serve()
+ s.wg.Done()
+ }()
+
+ return nil
+}
+
+// serve is the body of the main service goroutine. It handles incoming control
+// connections and dispatches requests to registered objects.
+func (s *Server) serve() {
+ for {
+ // Accept clients.
+ conn, err := s.socket.Accept()
+ if err != nil {
+ return
+ }
+
+ ucred, err := conn.GetPeerCred()
+ if err != nil {
+ log.Warningf("Control couldn't get credentials: %s", err.Error())
+ conn.Close()
+ continue
+ }
+
+ // Only allow this user and root.
+ if int(ucred.Uid) != curUID && ucred.Uid != 0 {
+ // Authentication failed.
+ log.Warningf("Control auth failure: other UID = %d, current UID = %d", ucred.Uid, curUID)
+ conn.Close()
+ continue
+ }
+
+ // Handle the connection non-blockingly.
+ s.server.StartHandling(conn)
+ }
+}
+
+// Register registers a specific control interface with the server.
+func (s *Server) Register(obj interface{}) {
+ s.server.Register(obj)
+}
+
+// CreateFromFD creates a new control bound to the given 'fd'. It has no
+// registered interfaces and will not start serving until StartServing is
+// called.
+func CreateFromFD(fd int) (*Server, error) {
+ socket, err := unet.NewServerSocket(fd)
+ if err != nil {
+ return nil, err
+ }
+ return New(socket), nil
+}
+
+// Create creates a new control server with an abstract unix socket
+// with the given address, which must must be unique and a valid
+// abstract socket name.
+func Create(addr string) (*Server, error) {
+ socket, err := unet.Bind(addr, false)
+ if err != nil {
+ return nil, err
+ }
+ return New(socket), nil
+}
+
+// CreateSocket creates a socket that can be used with control server,
+// but doesn't start control server. 'addr' must be a valid and unique
+// abstract socket name. Returns socket's FD, -1 in case of error.
+func CreateSocket(addr string) (int, error) {
+ socket, err := unet.Bind(addr, false)
+ if err != nil {
+ return -1, err
+ }
+ return socket.Release()
+}
diff --git a/pkg/control/server/server_state_autogen.go b/pkg/control/server/server_state_autogen.go
new file mode 100755
index 000000000..f2b4725d3
--- /dev/null
+++ b/pkg/control/server/server_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package server
+
diff --git a/pkg/cpuid/cpu_amd64.s b/pkg/cpuid/cpu_amd64.s
new file mode 100644
index 000000000..ac80d3c8a
--- /dev/null
+++ b/pkg/cpuid/cpu_amd64.s
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// func HostID(rax, rcx uint32) (ret0, ret1, ret2, ret3 uint32)
+TEXT ·HostID(SB),$0-48
+ MOVL ax+0(FP), AX
+ MOVL cx+4(FP), CX
+ CPUID
+ MOVL AX, ret0+8(FP)
+ MOVL BX, ret1+12(FP)
+ MOVL CX, ret2+16(FP)
+ MOVL DX, ret3+20(FP)
+ RET
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
new file mode 100644
index 000000000..3eb2bcd2b
--- /dev/null
+++ b/pkg/cpuid/cpuid.go
@@ -0,0 +1,941 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build i386 amd64
+
+// Package cpuid provides basic functionality for creating and adjusting CPU
+// feature sets.
+//
+// To use FeatureSets, one should start with an existing FeatureSet (either a
+// known platform, or HostFeatureSet()) and then add, remove, and test for
+// features as desired.
+//
+// For example: Test for hardware extended state saving, and if we don't have
+// it, don't expose AVX, which cannot be saved with fxsave.
+//
+// if !HostFeatureSet().HasFeature(X86FeatureXSAVE) {
+// exposedFeatures.Remove(X86FeatureAVX)
+// }
+package cpuid
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+)
+
+// Common references for CPUID leaves and bits:
+//
+// Intel:
+// * Intel SDM Volume 2, Chapter 3.2 "CPUID" (more up-to-date)
+// * Intel Application Note 485 (more detailed)
+//
+// AMD:
+// * AMD64 APM Volume 3, Appendix 3 "Obtaining Processor Information ..."
+
+// Feature is a unique identifier for a particular cpu feature. We just use an
+// int as a feature number on x86.
+//
+// Features are numbered according to "blocks". Each block is 32 bits, and
+// feature bits from the same source (cpuid leaf/level) are in the same block.
+type Feature int
+
+// block is a collection of 32 Feature bits.
+type block int
+
+const blockSize = 32
+
+// Feature bits are numbered according to "blocks". Each block is 32 bits, and
+// feature bits from the same source (cpuid leaf/level) are in the same block.
+func featureID(b block, bit int) Feature {
+ return Feature(32*int(b) + bit)
+}
+
+// Block 0 constants are all of the "basic" feature bits returned by a cpuid in
+// ecx with eax=1.
+const (
+ X86FeatureSSE3 Feature = iota
+ X86FeaturePCLMULDQ
+ X86FeatureDTES64
+ X86FeatureMONITOR
+ X86FeatureDSCPL
+ X86FeatureVMX
+ X86FeatureSMX
+ X86FeatureEST
+ X86FeatureTM2
+ X86FeatureSSSE3 // Not a typo, "supplemental" SSE3.
+ X86FeatureCNXTID
+ X86FeatureSDBG
+ X86FeatureFMA
+ X86FeatureCX16
+ X86FeatureXTPR
+ X86FeaturePDCM
+ _ // ecx bit 16 is reserved.
+ X86FeaturePCID
+ X86FeatureDCA
+ X86FeatureSSE4_1
+ X86FeatureSSE4_2
+ X86FeatureX2APIC
+ X86FeatureMOVBE
+ X86FeaturePOPCNT
+ X86FeatureTSCD
+ X86FeatureAES
+ X86FeatureXSAVE
+ X86FeatureOSXSAVE
+ X86FeatureAVX
+ X86FeatureF16C
+ X86FeatureRDRAND
+ _ // ecx bit 31 is reserved.
+)
+
+// Block 1 constants are all of the "basic" feature bits returned by a cpuid in
+// edx with eax=1.
+const (
+ X86FeatureFPU Feature = 32 + iota
+ X86FeatureVME
+ X86FeatureDE
+ X86FeaturePSE
+ X86FeatureTSC
+ X86FeatureMSR
+ X86FeaturePAE
+ X86FeatureMCE
+ X86FeatureCX8
+ X86FeatureAPIC
+ _ // edx bit 10 is reserved.
+ X86FeatureSEP
+ X86FeatureMTRR
+ X86FeaturePGE
+ X86FeatureMCA
+ X86FeatureCMOV
+ X86FeaturePAT
+ X86FeaturePSE36
+ X86FeaturePSN
+ X86FeatureCLFSH
+ _ // edx bit 20 is reserved.
+ X86FeatureDS
+ X86FeatureACPI
+ X86FeatureMMX
+ X86FeatureFXSR
+ X86FeatureSSE
+ X86FeatureSSE2
+ X86FeatureSS
+ X86FeatureHTT
+ X86FeatureTM
+ X86FeatureIA64
+ X86FeaturePBE
+)
+
+// Block 2 bits are the "structured extended" features returned in ebx for
+// eax=7, ecx=0.
+const (
+ X86FeatureFSGSBase Feature = 2*32 + iota
+ X86FeatureTSC_ADJUST
+ _ // ebx bit 2 is reserved.
+ X86FeatureBMI1
+ X86FeatureHLE
+ X86FeatureAVX2
+ X86FeatureFDP_EXCPTN_ONLY
+ X86FeatureSMEP
+ X86FeatureBMI2
+ X86FeatureERMS
+ X86FeatureINVPCID
+ X86FeatureRTM
+ X86FeatureCQM
+ X86FeatureFPCSDS
+ X86FeatureMPX
+ X86FeatureRDT
+ X86FeatureAVX512F
+ X86FeatureAVX512DQ
+ X86FeatureRDSEED
+ X86FeatureADX
+ X86FeatureSMAP
+ X86FeatureAVX512IFMA
+ X86FeaturePCOMMIT
+ X86FeatureCLFLUSHOPT
+ X86FeatureCLWB
+ X86FeatureIPT // Intel processor trace.
+ X86FeatureAVX512PF
+ X86FeatureAVX512ER
+ X86FeatureAVX512CD
+ X86FeatureSHA
+ X86FeatureAVX512BW
+ X86FeatureAVX512VL
+)
+
+// Block 3 bits are the "extended" features returned in ecx for eax=7, ecx=0.
+const (
+ X86FeaturePREFETCHWT1 Feature = 3*32 + iota
+ X86FeatureAVX512VBMI
+ X86FeatureUMIP
+ X86FeaturePKU
+)
+
+// Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX.
+// The CPUID leaf is available only if 'X86FeatureXSAVE' is present.
+const (
+ X86FeatureXSAVEOPT Feature = 4*32 + iota
+ X86FeatureXSAVEC
+ X86FeatureXGETBV1
+ X86FeatureXSAVES
+ // EAX[31:4] are reserved.
+)
+
+// Block 5 constants are the extended feature bits in
+// CPUID.(EAX=0x80000001):ECX.
+const (
+ X86FeatureLAHF64 Feature = 5*32 + iota
+ X86FeatureCMP_LEGACY
+ X86FeatureSVM
+ X86FeatureEXTAPIC
+ X86FeatureCR8_LEGACY
+ X86FeatureLZCNT
+ X86FeatureSSE4A
+ X86FeatureMISALIGNSSE
+ X86FeaturePREFETCHW
+ X86FeatureOSVW
+ X86FeatureIBS
+ X86FeatureXOP
+ X86FeatureSKINIT
+ X86FeatureWDT
+ _ // ecx bit 14 is reserved.
+ X86FeatureLWP
+ X86FeatureFMA4
+ X86FeatureTCE
+ _ // ecx bit 18 is reserved.
+ _ // ecx bit 19 is reserved.
+ _ // ecx bit 20 is reserved.
+ X86FeatureTBM
+ X86FeatureTOPOLOGY
+ X86FeaturePERFCTR_CORE
+ X86FeaturePERFCTR_NB
+ _ // ecx bit 25 is reserved.
+ X86FeatureBPEXT
+ X86FeaturePERFCTR_TSC
+ X86FeaturePERFCTR_LLC
+ X86FeatureMWAITX
+ // ECX[31:30] are reserved.
+)
+
+// Block 6 constants are the extended feature bits in
+// CPUID.(EAX=0x80000001):EDX.
+//
+// These are sparse, and so the bit positions are assigned manually.
+const (
+ // On AMD, EDX[24:23] | EDX[17:12] | EDX[9:0] are duplicate features
+ // also defined in block 1 (in identical bit positions). Those features
+ // are not listed here.
+ block6DuplicateMask = 0x183f3ff
+
+ X86FeatureSYSCALL Feature = 6*32 + 11
+ X86FeatureNX Feature = 6*32 + 20
+ X86FeatureMMXEXT Feature = 6*32 + 22
+ X86FeatureFXSR_OPT Feature = 6*32 + 25
+ X86FeatureGBPAGES Feature = 6*32 + 26
+ X86FeatureRDTSCP Feature = 6*32 + 27
+ X86FeatureLM Feature = 6*32 + 29
+ X86Feature3DNOWEXT Feature = 6*32 + 30
+ X86Feature3DNOW Feature = 6*32 + 31
+)
+
+// linuxBlockOrder defines the order in which linux organizes the feature
+// blocks. Linux also tracks feature bits in 32-bit blocks, but in an order
+// which doesn't match well here, so for the /proc/cpuinfo generation we simply
+// re-map the blocks to Linux's ordering and then go through the bits in each
+// block.
+var linuxBlockOrder = []block{1, 6, 0, 5, 2, 4, 3}
+
+// To make emulation of /proc/cpuinfo easy, these names match the names of the
+// basic features in Linux defined in arch/x86/kernel/cpu/capflags.c.
+var x86FeatureStrings = map[Feature]string{
+ // Block 0.
+ X86FeatureSSE3: "pni",
+ X86FeaturePCLMULDQ: "pclmulqdq",
+ X86FeatureDTES64: "dtes64",
+ X86FeatureMONITOR: "monitor",
+ X86FeatureDSCPL: "ds_cpl",
+ X86FeatureVMX: "vmx",
+ X86FeatureSMX: "smx",
+ X86FeatureEST: "est",
+ X86FeatureTM2: "tm2",
+ X86FeatureSSSE3: "ssse3",
+ X86FeatureCNXTID: "cid",
+ X86FeatureSDBG: "sdbg",
+ X86FeatureFMA: "fma",
+ X86FeatureCX16: "cx16",
+ X86FeatureXTPR: "xtpr",
+ X86FeaturePDCM: "pdcm",
+ X86FeaturePCID: "pcid",
+ X86FeatureDCA: "dca",
+ X86FeatureSSE4_1: "sse4_1",
+ X86FeatureSSE4_2: "sse4_2",
+ X86FeatureX2APIC: "x2apic",
+ X86FeatureMOVBE: "movbe",
+ X86FeaturePOPCNT: "popcnt",
+ X86FeatureTSCD: "tsc_deadline_timer",
+ X86FeatureAES: "aes",
+ X86FeatureXSAVE: "xsave",
+ X86FeatureAVX: "avx",
+ X86FeatureF16C: "f16c",
+ X86FeatureRDRAND: "rdrand",
+
+ // Block 1.
+ X86FeatureFPU: "fpu",
+ X86FeatureVME: "vme",
+ X86FeatureDE: "de",
+ X86FeaturePSE: "pse",
+ X86FeatureTSC: "tsc",
+ X86FeatureMSR: "msr",
+ X86FeaturePAE: "pae",
+ X86FeatureMCE: "mce",
+ X86FeatureCX8: "cx8",
+ X86FeatureAPIC: "apic",
+ X86FeatureSEP: "sep",
+ X86FeatureMTRR: "mtrr",
+ X86FeaturePGE: "pge",
+ X86FeatureMCA: "mca",
+ X86FeatureCMOV: "cmov",
+ X86FeaturePAT: "pat",
+ X86FeaturePSE36: "pse36",
+ X86FeaturePSN: "pn",
+ X86FeatureCLFSH: "clflush",
+ X86FeatureDS: "dts",
+ X86FeatureACPI: "acpi",
+ X86FeatureMMX: "mmx",
+ X86FeatureFXSR: "fxsr",
+ X86FeatureSSE: "sse",
+ X86FeatureSSE2: "sse2",
+ X86FeatureSS: "ss",
+ X86FeatureHTT: "ht",
+ X86FeatureTM: "tm",
+ X86FeatureIA64: "ia64",
+ X86FeaturePBE: "pbe",
+
+ // Block 2.
+ X86FeatureFSGSBase: "fsgsbase",
+ X86FeatureTSC_ADJUST: "tsc_adjust",
+ X86FeatureBMI1: "bmi1",
+ X86FeatureHLE: "hle",
+ X86FeatureAVX2: "avx2",
+ X86FeatureSMEP: "smep",
+ X86FeatureBMI2: "bmi2",
+ X86FeatureERMS: "erms",
+ X86FeatureINVPCID: "invpcid",
+ X86FeatureRTM: "rtm",
+ X86FeatureCQM: "cqm",
+ X86FeatureMPX: "mpx",
+ X86FeatureRDT: "rdt_a",
+ X86FeatureAVX512F: "avx512f",
+ X86FeatureAVX512DQ: "avx512dq",
+ X86FeatureRDSEED: "rdseed",
+ X86FeatureADX: "adx",
+ X86FeatureSMAP: "smap",
+ X86FeatureCLWB: "clwb",
+ X86FeatureAVX512PF: "avx512pf",
+ X86FeatureAVX512ER: "avx512er",
+ X86FeatureAVX512CD: "avx512cd",
+ X86FeatureSHA: "sha_ni",
+ X86FeatureAVX512BW: "avx512bw",
+ X86FeatureAVX512VL: "avx512vl",
+
+ // Block 3.
+ X86FeatureAVX512VBMI: "avx512vbmi",
+ X86FeatureUMIP: "umip",
+ X86FeaturePKU: "pku",
+
+ // Block 4.
+ X86FeatureXSAVEOPT: "xsaveopt",
+ X86FeatureXSAVEC: "xsavec",
+ X86FeatureXGETBV1: "xgetbv1",
+ X86FeatureXSAVES: "xsaves",
+
+ // Block 5.
+ X86FeatureLAHF64: "lahf_lm", // LAHF/SAHF in long mode
+ X86FeatureCMP_LEGACY: "cmp_legacy",
+ X86FeatureSVM: "svm",
+ X86FeatureEXTAPIC: "extapic",
+ X86FeatureCR8_LEGACY: "cr8_legacy",
+ X86FeatureLZCNT: "abm", // Advanced bit manipulation
+ X86FeatureSSE4A: "sse4a",
+ X86FeatureMISALIGNSSE: "misalignsse",
+ X86FeaturePREFETCHW: "3dnowprefetch",
+ X86FeatureOSVW: "osvw",
+ X86FeatureIBS: "ibs",
+ X86FeatureXOP: "xop",
+ X86FeatureSKINIT: "skinit",
+ X86FeatureWDT: "wdt",
+ X86FeatureLWP: "lwp",
+ X86FeatureFMA4: "fma4",
+ X86FeatureTCE: "tce",
+ X86FeatureTBM: "tbm",
+ X86FeatureTOPOLOGY: "topoext",
+ X86FeaturePERFCTR_CORE: "perfctr_core",
+ X86FeaturePERFCTR_NB: "perfctr_nb",
+ X86FeatureBPEXT: "bpext",
+ X86FeaturePERFCTR_TSC: "ptsc",
+ X86FeaturePERFCTR_LLC: "perfctr_llc",
+ X86FeatureMWAITX: "mwaitx",
+
+ // Block 6.
+ X86FeatureSYSCALL: "syscall",
+ X86FeatureNX: "nx",
+ X86FeatureMMXEXT: "mmxext",
+ X86FeatureFXSR_OPT: "fxsr_opt",
+ X86FeatureGBPAGES: "pdpe1gb",
+ X86FeatureRDTSCP: "rdtscp",
+ X86FeatureLM: "lm",
+ X86Feature3DNOWEXT: "3dnowext",
+ X86Feature3DNOW: "3dnow",
+}
+
+// These flags are parse only---they can be used for setting / unsetting the
+// flags, but will not get printed out in /proc/cpuinfo.
+var x86FeatureParseOnlyStrings = map[Feature]string{
+ // Block 0.
+ X86FeatureOSXSAVE: "osxsave",
+
+ // Block 2.
+ X86FeatureFDP_EXCPTN_ONLY: "fdp_excptn_only",
+ X86FeatureFPCSDS: "fpcsds",
+ X86FeatureIPT: "pt",
+ X86FeatureCLFLUSHOPT: "clfushopt",
+
+ // Block 3.
+ X86FeaturePREFETCHWT1: "prefetchwt1",
+}
+
+// Just a way to wrap cpuid function numbers.
+type cpuidFunction uint32
+
+// The constants below are the lower or "standard" cpuid functions, ordered as
+// defined by the hardware.
+const (
+ vendorID cpuidFunction = iota // Returns vendor ID and largest standard function.
+ featureInfo // Returns basic feature bits and processor signature.
+ intelCacheDescriptors // Returns list of cache descriptors. Intel only.
+ intelSerialNumber // Returns processor serial number (obsolete on new hardware). Intel only.
+ intelDeterministicCacheParams // Returns deterministic cache information. Intel only.
+ monitorMwaitParams // Returns information about monitor/mwait instructions.
+ powerParams // Returns information about power management and thermal sensors.
+ extendedFeatureInfo // Returns extended feature bits.
+ _ // Function 0x8 is reserved.
+ intelDCAParams // Returns direct cache access information. Intel only.
+ intelPMCInfo // Returns information about performance monitoring features. Intel only.
+ intelX2APICInfo // Returns core/logical processor topology. Intel only.
+ _ // Function 0xc is reserved.
+ xSaveInfo // Returns information about extended state management.
+)
+
+// The "extended" functions start at 0x80000000.
+const (
+ extendedFunctionInfo cpuidFunction = 0x80000000 + iota // Returns highest available extended function in eax.
+ extendedFeatures // Returns some extended feature bits in edx and ecx.
+)
+
+// These are the extended floating point state features. They are used to
+// enumerate floating point features in XCR0, XSTATE_BV, etc.
+const (
+ XSAVEFeatureX87 = 1 << 0
+ XSAVEFeatureSSE = 1 << 1
+ XSAVEFeatureAVX = 1 << 2
+ XSAVEFeatureBNDREGS = 1 << 3
+ XSAVEFeatureBNDCSR = 1 << 4
+ XSAVEFeatureAVX512op = 1 << 5
+ XSAVEFeatureAVX512zmm0 = 1 << 6
+ XSAVEFeatureAVX512zmm16 = 1 << 7
+ XSAVEFeaturePKRU = 1 << 9
+)
+
+var cpuFreqMHz float64
+
+// x86FeaturesFromString includes features from x86FeatureStrings and
+// x86FeatureParseOnlyStrings.
+var x86FeaturesFromString = make(map[string]Feature)
+
+// FeatureFromString returns the Feature associated with the given feature
+// string plus a bool to indicate if it could find the feature.
+func FeatureFromString(s string) (Feature, bool) {
+ f, b := x86FeaturesFromString[s]
+ return f, b
+}
+
+// String implements fmt.Stringer.
+func (f Feature) String() string {
+ if s := f.flagString(false); s != "" {
+ return s
+ }
+
+ block := int(f) / 32
+ bit := int(f) % 32
+ return fmt.Sprintf("<cpuflag %d; block %d bit %d>", f, block, bit)
+}
+
+func (f Feature) flagString(cpuinfoOnly bool) string {
+ if s, ok := x86FeatureStrings[f]; ok {
+ return s
+ }
+ if !cpuinfoOnly {
+ return x86FeatureParseOnlyStrings[f]
+ }
+ return ""
+}
+
+// FeatureSet is a set of Features for a cpu.
+//
+// +stateify savable
+type FeatureSet struct {
+ // Set is the set of features that are enabled in this FeatureSet.
+ Set map[Feature]bool
+
+ // VendorID is the 12-char string returned in ebx:edx:ecx for eax=0.
+ VendorID string
+
+ // ExtendedFamily is part of the processor signature.
+ ExtendedFamily uint8
+
+ // ExtendedModel is part of the processor signature.
+ ExtendedModel uint8
+
+ // ProcessorType is part of the processor signature.
+ ProcessorType uint8
+
+ // Family is part of the processor signature.
+ Family uint8
+
+ // Model is part of the processor signature.
+ Model uint8
+
+ // SteppingID is part of the processor signature.
+ SteppingID uint8
+}
+
+// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
+// equivalent to the "flags" field in /proc/cpuinfo.
+func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
+ var s []string
+ for _, b := range linuxBlockOrder {
+ for i := 0; i < blockSize; i++ {
+ if f := featureID(b, i); fs.Set[f] {
+ if fstr := f.flagString(cpuinfoOnly); fstr != "" {
+ s = append(s, fstr)
+ }
+ }
+ }
+ }
+ return strings.Join(s, " ")
+}
+
+// CPUInfo is to generate a section of one cpu in /proc/cpuinfo. This is a
+// minimal /proc/cpuinfo, it is missing some fields like "microcode" that are
+// not always printed in Linux. The bogomips field is simply made up.
+func (fs FeatureSet) CPUInfo(cpu uint) string {
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "processor\t: %d\n", cpu)
+ fmt.Fprintf(&b, "vendor_id\t: %s\n", fs.VendorID)
+ fmt.Fprintf(&b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
+ fmt.Fprintf(&b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
+ fmt.Fprintf(&b, "model name\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(&b, "stepping\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(&b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
+ fmt.Fprintln(&b, "fpu\t\t: yes")
+ fmt.Fprintln(&b, "fpu_exception\t: yes")
+ fmt.Fprintf(&b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
+ fmt.Fprintln(&b, "wp\t\t: yes")
+ fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true))
+ fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
+ fmt.Fprintf(&b, "clflush size\t: %d\n", 64)
+ fmt.Fprintf(&b, "cache_alignment\t: %d\n", 64)
+ fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
+ fmt.Fprintln(&b, "power management:") // This is always here, but can be blank.
+ fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline.
+ return b.String()
+}
+
+// AMD returns true if fs describes an AMD CPU.
+func (fs *FeatureSet) AMD() bool {
+ return fs.VendorID == "AuthenticAMD"
+}
+
+// Intel returns true if fs describes an Intel CPU.
+func (fs *FeatureSet) Intel() bool {
+ return fs.VendorID == "GenuineIntel"
+}
+
+// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
+// subset of the host feature set.
+type ErrIncompatible struct {
+ message string
+}
+
+// Error implements error.
+func (e ErrIncompatible) Error() string {
+ return e.message
+}
+
+// CheckHostCompatible returns nil if fs is a subset of the host feature set.
+func (fs *FeatureSet) CheckHostCompatible() error {
+ hfs := HostFeatureSet()
+ if diff := fs.Subtract(hfs); diff != nil {
+ return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
+ }
+ return nil
+}
+
+// Helper to convert 3 regs into 12-byte vendor ID.
+func vendorIDFromRegs(bx, cx, dx uint32) string {
+ bytes := make([]byte, 0, 12)
+ for i := uint(0); i < 4; i++ {
+ b := byte(bx >> (i * 8))
+ bytes = append(bytes, b)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ b := byte(dx >> (i * 8))
+ bytes = append(bytes, b)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ b := byte(cx >> (i * 8))
+ bytes = append(bytes, b)
+ }
+ return string(bytes)
+}
+
+// ExtendedStateSize returns the number of bytes needed to save the "extended
+// state" for this processor and the boundary it must be aligned to. Extended
+// state includes floating point registers, and other cpu state that's not
+// associated with the normal task context.
+//
+// Note: We can save some space here with an optimiazation where we use a
+// smaller chunk of memory depending on features that are actually enabled.
+// Currently we just use the largest possible size for simplicity (which is
+// about 2.5K worst case, with avx512).
+func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
+ if fs.UseXsave() {
+ // Leaf 0 of xsaveinfo function returns the size for currently
+ // enabled xsave features in ebx, the maximum size if all valid
+ // features are saved with xsave in ecx, and valid XCR0 bits in
+ // edx:eax.
+ _, _, maxSize, _ := HostID(uint32(xSaveInfo), 0)
+ return uint(maxSize), 64
+ }
+
+ // If we don't support xsave, we fall back to fxsave, which requires
+ // 512 bytes aligned to 16 bytes.
+ return 512, 16
+}
+
+// ValidXCR0Mask returns the bits that may be set to 1 in control register
+// XCR0.
+func (fs *FeatureSet) ValidXCR0Mask() uint64 {
+ if !fs.UseXsave() {
+ return 0
+ }
+ eax, _, _, edx := HostID(uint32(xSaveInfo), 0)
+ return uint64(edx)<<32 | uint64(eax)
+}
+
+// vendorIDRegs returns the 3 register values used to construct the 12-byte
+// vendor ID string for eax=0.
+func (fs *FeatureSet) vendorIDRegs() (bx, dx, cx uint32) {
+ for i := uint(0); i < 4; i++ {
+ bx |= uint32(fs.VendorID[i]) << (i * 8)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ dx |= uint32(fs.VendorID[i+4]) << (i * 8)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ cx |= uint32(fs.VendorID[i+8]) << (i * 8)
+ }
+ return
+}
+
+// signature returns the signature dword that's returned in eax when eax=1.
+func (fs *FeatureSet) signature() uint32 {
+ var s uint32
+ s |= uint32(fs.SteppingID & 0xf)
+ s |= uint32(fs.Model&0xf) << 4
+ s |= uint32(fs.Family&0xf) << 8
+ s |= uint32(fs.ProcessorType&0x3) << 12
+ s |= uint32(fs.ExtendedModel&0xf) << 16
+ s |= uint32(fs.ExtendedFamily&0xff) << 20
+ return s
+}
+
+// Helper to deconstruct signature dword.
+func signatureSplit(v uint32) (ef, em, pt, f, m, sid uint8) {
+ sid = uint8(v & 0xf)
+ m = uint8(v>>4) & 0xf
+ f = uint8(v>>8) & 0xf
+ pt = uint8(v>>12) & 0x3
+ em = uint8(v>>16) & 0xf
+ ef = uint8(v >> 20)
+ return
+}
+
+// Helper to convert blockwise feature bit masks into a set of features. Masks
+// must be provided in order for each block, without skipping them. If a block
+// does not matter for this feature set, 0 is specified.
+func setFromBlockMasks(blocks ...uint32) map[Feature]bool {
+ s := make(map[Feature]bool)
+ for b, blockMask := range blocks {
+ for i := 0; i < blockSize; i++ {
+ if blockMask&1 != 0 {
+ s[featureID(block(b), i)] = true
+ }
+ blockMask >>= 1
+ }
+ }
+ return s
+}
+
+// blockMask returns the 32-bit mask associated with a block of features.
+func (fs *FeatureSet) blockMask(b block) uint32 {
+ var mask uint32
+ for i := 0; i < blockSize; i++ {
+ if fs.Set[featureID(b, i)] {
+ mask |= 1 << uint(i)
+ }
+ }
+ return mask
+}
+
+// Remove removes a Feature from a FeatureSet. It ignores features
+// that are not in the FeatureSet.
+func (fs *FeatureSet) Remove(feature Feature) {
+ delete(fs.Set, feature)
+}
+
+// Add adds a Feature to a FeatureSet. It ignores duplicate features.
+func (fs *FeatureSet) Add(feature Feature) {
+ fs.Set[feature] = true
+}
+
+// HasFeature tests whether or not a feature is in the given feature set.
+func (fs *FeatureSet) HasFeature(feature Feature) bool {
+ return fs.Set[feature]
+}
+
+// IsSubset returns true if the FeatureSet is a subset of the FeatureSet passed in.
+// This is useful if you want to see if a FeatureSet is compatible with another
+// FeatureSet, since you can only run with a given FeatureSet if it's a subset of
+// the host's.
+func (fs *FeatureSet) IsSubset(other *FeatureSet) bool {
+ return fs.Subtract(other) == nil
+}
+
+// Subtract returns the features present in fs that are not present in other.
+// If all features in fs are present in other, Subtract returns nil.
+func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
+ for f := range fs.Set {
+ if !other.Set[f] {
+ if diff == nil {
+ diff = make(map[Feature]bool)
+ }
+ diff[f] = true
+ }
+ }
+
+ return
+}
+
+// TakeFeatureIntersection will set the features in `fs` to the intersection of
+// the features in `fs` and `other` (effectively clearing any feature bits on
+// `fs` that are not also set in `other`).
+func (fs *FeatureSet) TakeFeatureIntersection(other *FeatureSet) {
+ for f := range fs.Set {
+ if !other.Set[f] {
+ delete(fs.Set, f)
+ }
+ }
+}
+
+// EmulateID emulates a cpuid instruction based on the feature set.
+func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
+ switch cpuidFunction(origAx) {
+ case vendorID:
+ ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
+ bx, dx, cx = fs.vendorIDRegs()
+ case featureInfo:
+ // clflush line size (ebx bits[15:8]) hardcoded as 8. This
+ // means cache lines of size 64 bytes.
+ bx = 8 << 8
+ cx = fs.blockMask(block(0))
+ dx = fs.blockMask(block(1))
+ ax = fs.signature()
+ case intelCacheDescriptors:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // "The least-significant byte in register EAX (register AL)
+ // will always return 01H. Software should ignore this value
+ // and not interpret it as an informational descriptor." - SDM
+ //
+ // We do not support exposing cache information, but we do set
+ // this fixed field because some language runtimes (dlang) get
+ // confused by ax = 0 and will loop infinitely.
+ ax = 1
+ case xSaveInfo:
+ if !fs.UseXsave() {
+ return 0, 0, 0, 0
+ }
+ return HostID(uint32(xSaveInfo), origCx)
+ case extendedFeatureInfo:
+ if origCx != 0 {
+ break // Only leaf 0 is supported.
+ }
+ bx = fs.blockMask(block(2))
+ cx = fs.blockMask(block(3))
+ case extendedFunctionInfo:
+ // We only support showing the extended features.
+ ax = uint32(extendedFeatures)
+ cx = 0
+ case extendedFeatures:
+ cx = fs.blockMask(block(5))
+ dx = fs.blockMask(block(6))
+ if fs.AMD() {
+ // AMD duplicates some block 1 features in block 6.
+ dx |= fs.blockMask(block(1)) & block6DuplicateMask
+ }
+ }
+
+ return
+}
+
+// UseXsave returns the choice of fp state saving instruction.
+func (fs *FeatureSet) UseXsave() bool {
+ return fs.HasFeature(X86FeatureXSAVE) && fs.HasFeature(X86FeatureOSXSAVE)
+}
+
+// UseXsaveopt returns true if 'fs' supports the "xsaveopt" instruction.
+func (fs *FeatureSet) UseXsaveopt() bool {
+ return fs.UseXsave() && fs.HasFeature(X86FeatureXSAVEOPT)
+}
+
+// HostID executes a native CPUID instruction.
+func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32)
+
+// HostFeatureSet uses cpuid to get host values and construct a feature set
+// that matches that of the host machine. Note that there are several places
+// where there appear to be some unnecessary assignments between register names
+// (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show
+// where the different feature blocks come from, to make the code easier to
+// inspect and read.
+func HostFeatureSet() *FeatureSet {
+ // eax=0 gets max supported feature and vendor ID.
+ _, bx, cx, dx := HostID(0, 0)
+ vendorID := vendorIDFromRegs(bx, cx, dx)
+
+ // eax=1 gets basic features in ecx:edx.
+ ax, _, cx, dx := HostID(1, 0)
+ featureBlock0 := cx
+ featureBlock1 := dx
+ ef, em, pt, f, m, sid := signatureSplit(ax)
+
+ // eax=7, ecx=0 gets extended features in ecx:ebx.
+ _, bx, cx, _ = HostID(7, 0)
+ featureBlock2 := bx
+ featureBlock3 := cx
+
+ // Leaf 0xd is supported only if CPUID.1:ECX.XSAVE[bit 26] is set.
+ var featureBlock4 uint32
+ if (featureBlock0 & (1 << 26)) != 0 {
+ featureBlock4, _, _, _ = HostID(uint32(xSaveInfo), 1)
+ }
+
+ // eax=0x80000000 gets supported extended levels. We use this to
+ // determine if there are any non-zero block 4 or block 6 bits to find.
+ var featureBlock5, featureBlock6 uint32
+ if ax, _, _, _ := HostID(uint32(extendedFunctionInfo), 0); ax >= uint32(extendedFeatures) {
+ // eax=0x80000001 gets AMD added feature bits.
+ _, _, cx, dx = HostID(uint32(extendedFeatures), 0)
+ featureBlock5 = cx
+ // Ignore features duplicated from block 1 on AMD. These bits
+ // are reserved on Intel.
+ featureBlock6 = dx &^ block6DuplicateMask
+ }
+
+ set := setFromBlockMasks(featureBlock0, featureBlock1, featureBlock2, featureBlock3, featureBlock4, featureBlock5, featureBlock6)
+ return &FeatureSet{
+ Set: set,
+ VendorID: vendorID,
+ ExtendedFamily: ef,
+ ExtendedModel: em,
+ ProcessorType: pt,
+ Family: f,
+ Model: m,
+ SteppingID: sid,
+ }
+}
+
+// Reads max cpu frequency from host /proc/cpuinfo. Must run before
+// whitelisting. This value is used to create the fake /proc/cpuinfo from a
+// FeatureSet.
+func initCPUFreq() {
+ cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
+ if err != nil {
+ // Leave it as 0... The standalone VDSO bails out in the same
+ // way.
+ log.Warningf("Could not read /proc/cpuinfo: %v", err)
+ return
+ }
+ cpuinfo := string(cpuinfob)
+
+ // We get the value straight from host /proc/cpuinfo. On machines with
+ // frequency scaling enabled, this will only get the current value
+ // which will likely be inaccurate. This is fine on machines with
+ // frequency scaling disabled.
+ for _, line := range strings.Split(cpuinfo, "\n") {
+ if strings.Contains(line, "cpu MHz") {
+ splitMHz := strings.Split(line, ":")
+ if len(splitMHz) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed cpu MHz line")
+ return
+ }
+
+ // If there was a problem, leave cpuFreqMHz as 0.
+ var err error
+ cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
+ if err != nil {
+ log.Warningf("Could not parse cpu MHz value %v: %v", splitMHz[1], err)
+ cpuFreqMHz = 0
+ return
+ }
+ return
+ }
+ }
+ log.Warningf("Could not parse /proc/cpuinfo, it is empty or does not contain cpu MHz")
+}
+
+func initFeaturesFromString() {
+ for f, s := range x86FeatureStrings {
+ x86FeaturesFromString[s] = f
+ }
+ for f, s := range x86FeatureParseOnlyStrings {
+ x86FeaturesFromString[s] = f
+ }
+}
+
+func init() {
+ // initCpuFreq must be run before whitelists are enabled.
+ initCPUFreq()
+ initFeaturesFromString()
+}
diff --git a/pkg/cpuid/cpuid_state_autogen.go b/pkg/cpuid/cpuid_state_autogen.go
new file mode 100755
index 000000000..59807a916
--- /dev/null
+++ b/pkg/cpuid/cpuid_state_autogen.go
@@ -0,0 +1,36 @@
+// automatically generated by stateify.
+
+package cpuid
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *FeatureSet) beforeSave() {}
+func (x *FeatureSet) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Set", &x.Set)
+ m.Save("VendorID", &x.VendorID)
+ m.Save("ExtendedFamily", &x.ExtendedFamily)
+ m.Save("ExtendedModel", &x.ExtendedModel)
+ m.Save("ProcessorType", &x.ProcessorType)
+ m.Save("Family", &x.Family)
+ m.Save("Model", &x.Model)
+ m.Save("SteppingID", &x.SteppingID)
+}
+
+func (x *FeatureSet) afterLoad() {}
+func (x *FeatureSet) load(m state.Map) {
+ m.Load("Set", &x.Set)
+ m.Load("VendorID", &x.VendorID)
+ m.Load("ExtendedFamily", &x.ExtendedFamily)
+ m.Load("ExtendedModel", &x.ExtendedModel)
+ m.Load("ProcessorType", &x.ProcessorType)
+ m.Load("Family", &x.Family)
+ m.Load("Model", &x.Model)
+ m.Load("SteppingID", &x.SteppingID)
+}
+
+func init() {
+ state.Register("cpuid.FeatureSet", (*FeatureSet)(nil), state.Fns{Save: (*FeatureSet).save, Load: (*FeatureSet).load})
+}
diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go
new file mode 100644
index 000000000..4c8ae573b
--- /dev/null
+++ b/pkg/eventchannel/event.go
@@ -0,0 +1,165 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package eventchannel contains functionality for sending any protobuf message
+// on a socketpair.
+//
+// The wire format is a uvarint length followed by a binary protobuf.Any
+// message.
+package eventchannel
+
+import (
+ "encoding/binary"
+ "fmt"
+ "sync"
+ "syscall"
+
+ "github.com/golang/protobuf/proto"
+ "github.com/golang/protobuf/ptypes"
+ pb "gvisor.googlesource.com/gvisor/pkg/eventchannel/eventchannel_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+// Emitter emits a proto message.
+type Emitter interface {
+ // Emit writes a single eventchannel message to an emitter. Emit should
+ // return hangup = true to indicate an emitter has "hung up" and no further
+ // messages should be directed to it.
+ Emit(msg proto.Message) (hangup bool, err error)
+
+ // Close closes this emitter. Emit cannot be used after Close is called.
+ Close() error
+}
+
+var (
+ mu sync.Mutex
+ emitters = make(map[Emitter]struct{})
+)
+
+// Emit emits a message using all added emitters.
+func Emit(msg proto.Message) error {
+ mu.Lock()
+ defer mu.Unlock()
+
+ var err error
+ for e := range emitters {
+ hangup, eerr := e.Emit(msg)
+ if eerr != nil {
+ if err == nil {
+ err = fmt.Errorf("error emitting %v: on %v: %v", msg, e, eerr)
+ } else {
+ err = fmt.Errorf("%v; on %v: %v", err, e, eerr)
+ }
+
+ // Log as well, since most callers ignore the error.
+ log.Warningf("Error emitting %v on %v: %v", msg, e, eerr)
+ }
+ if hangup {
+ log.Infof("Hangup on eventchannel emitter %v.", e)
+ delete(emitters, e)
+ }
+ }
+
+ return err
+}
+
+// AddEmitter adds a new emitter.
+func AddEmitter(e Emitter) {
+ mu.Lock()
+ defer mu.Unlock()
+ emitters[e] = struct{}{}
+}
+
+func marshal(msg proto.Message) ([]byte, error) {
+ anypb, err := ptypes.MarshalAny(msg)
+ if err != nil {
+ return nil, err
+ }
+
+ // Wire format is uvarint message length followed by binary proto.
+ bufMsg, err := proto.Marshal(anypb)
+ if err != nil {
+ return nil, err
+ }
+ p := make([]byte, binary.MaxVarintLen64)
+ n := binary.PutUvarint(p, uint64(len(bufMsg)))
+ return append(p[:n], bufMsg...), nil
+}
+
+// socketEmitter emits proto messages on a socket.
+type socketEmitter struct {
+ socket *unet.Socket
+}
+
+// SocketEmitter creates a new event channel based on the given fd.
+//
+// SocketEmitter takes ownership of fd.
+func SocketEmitter(fd int) (Emitter, error) {
+ s, err := unet.NewSocket(fd)
+ if err != nil {
+ return nil, err
+ }
+
+ return &socketEmitter{
+ socket: s,
+ }, nil
+}
+
+// Emit implements Emitter.Emit.
+func (s *socketEmitter) Emit(msg proto.Message) (bool, error) {
+ p, err := marshal(msg)
+ if err != nil {
+ return false, err
+ }
+ for done := 0; done < len(p); {
+ n, err := s.socket.Write(p[done:])
+ if err != nil {
+ return (err == syscall.EPIPE), err
+ }
+ done += n
+ }
+ return false, nil
+}
+
+// Close implements Emitter.Emit.
+func (s *socketEmitter) Close() error {
+ return s.socket.Close()
+}
+
+// debugEmitter wraps an emitter to emit stringified event messages. This is
+// useful for debugging -- when the messages are intended for humans.
+type debugEmitter struct {
+ inner Emitter
+}
+
+// DebugEmitterFrom creates a new event channel emitter by wraping an existing
+// raw emitter.
+func DebugEmitterFrom(inner Emitter) Emitter {
+ return &debugEmitter{
+ inner: inner,
+ }
+}
+
+func (d *debugEmitter) Emit(msg proto.Message) (bool, error) {
+ ev := &pb.DebugEvent{
+ Name: proto.MessageName(msg),
+ Text: proto.MarshalTextString(msg),
+ }
+ return d.inner.Emit(ev)
+}
+
+func (d *debugEmitter) Close() error {
+ return d.inner.Close()
+}
diff --git a/pkg/eventchannel/eventchannel_go_proto/event.pb.go b/pkg/eventchannel/eventchannel_go_proto/event.pb.go
new file mode 100755
index 000000000..bb71ed3e6
--- /dev/null
+++ b/pkg/eventchannel/eventchannel_go_proto/event.pb.go
@@ -0,0 +1,85 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: pkg/eventchannel/event.proto
+
+package gvisor
+
+import (
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type DebugEvent struct {
+ Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
+ Text string `protobuf:"bytes,2,opt,name=text,proto3" json:"text,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *DebugEvent) Reset() { *m = DebugEvent{} }
+func (m *DebugEvent) String() string { return proto.CompactTextString(m) }
+func (*DebugEvent) ProtoMessage() {}
+func (*DebugEvent) Descriptor() ([]byte, []int) {
+ return fileDescriptor_fcfbd51abd9de962, []int{0}
+}
+
+func (m *DebugEvent) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_DebugEvent.Unmarshal(m, b)
+}
+func (m *DebugEvent) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_DebugEvent.Marshal(b, m, deterministic)
+}
+func (m *DebugEvent) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_DebugEvent.Merge(m, src)
+}
+func (m *DebugEvent) XXX_Size() int {
+ return xxx_messageInfo_DebugEvent.Size(m)
+}
+func (m *DebugEvent) XXX_DiscardUnknown() {
+ xxx_messageInfo_DebugEvent.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_DebugEvent proto.InternalMessageInfo
+
+func (m *DebugEvent) GetName() string {
+ if m != nil {
+ return m.Name
+ }
+ return ""
+}
+
+func (m *DebugEvent) GetText() string {
+ if m != nil {
+ return m.Text
+ }
+ return ""
+}
+
+func init() {
+ proto.RegisterType((*DebugEvent)(nil), "gvisor.DebugEvent")
+}
+
+func init() { proto.RegisterFile("pkg/eventchannel/event.proto", fileDescriptor_fcfbd51abd9de962) }
+
+var fileDescriptor_fcfbd51abd9de962 = []byte{
+ // 103 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x29, 0xc8, 0x4e, 0xd7,
+ 0x4f, 0x2d, 0x4b, 0xcd, 0x2b, 0x49, 0xce, 0x48, 0xcc, 0xcb, 0x4b, 0xcd, 0x81, 0x70, 0xf4, 0x0a,
+ 0x8a, 0xf2, 0x4b, 0xf2, 0x85, 0xd8, 0xd2, 0xcb, 0x32, 0x8b, 0xf3, 0x8b, 0x94, 0x4c, 0xb8, 0xb8,
+ 0x5c, 0x52, 0x93, 0x4a, 0xd3, 0x5d, 0x41, 0x72, 0x42, 0x42, 0x5c, 0x2c, 0x79, 0x89, 0xb9, 0xa9,
+ 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x60, 0x36, 0x48, 0xac, 0x24, 0xb5, 0xa2, 0x44, 0x82,
+ 0x09, 0x22, 0x06, 0x62, 0x27, 0xb1, 0x81, 0x0d, 0x31, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x17,
+ 0xee, 0x7f, 0xef, 0x64, 0x00, 0x00, 0x00,
+}
diff --git a/pkg/eventchannel/eventchannel_state_autogen.go b/pkg/eventchannel/eventchannel_state_autogen.go
new file mode 100755
index 000000000..cfd3a5e43
--- /dev/null
+++ b/pkg/eventchannel/eventchannel_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package eventchannel
+
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
new file mode 100644
index 000000000..2785243a2
--- /dev/null
+++ b/pkg/fd/fd.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fd provides types for working with file descriptors.
+package fd
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+)
+
+// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
+// does not take ownership of fd.
+type ReadWriter struct {
+ // fd is accessed atomically so FD.Close/Release can swap it.
+ fd int64
+}
+
+var _ io.ReadWriter = (*ReadWriter)(nil)
+var _ io.ReaderAt = (*ReadWriter)(nil)
+var _ io.WriterAt = (*ReadWriter)(nil)
+
+// NewReadWriter creates a ReadWriter for fd.
+func NewReadWriter(fd int) *ReadWriter {
+ return &ReadWriter{int64(fd)}
+}
+
+func fixCount(n int, err error) (int, error) {
+ if n < 0 {
+ n = 0
+ }
+ return n, err
+}
+
+// Read implements io.Reader.
+func (r *ReadWriter) Read(b []byte) (int, error) {
+ c, err := fixCount(syscall.Read(int(atomic.LoadInt64(&r.fd)), b))
+ if c == 0 && len(b) > 0 && err == nil {
+ return 0, io.EOF
+ }
+ return c, err
+}
+
+// ReadAt implements io.ReaderAt.
+//
+// ReadAt always returns a non-nil error when c < len(b).
+func (r *ReadWriter) ReadAt(b []byte, off int64) (c int, err error) {
+ for len(b) > 0 {
+ var m int
+ m, err = fixCount(syscall.Pread(int(atomic.LoadInt64(&r.fd)), b, off))
+ if m == 0 && err == nil {
+ return c, io.EOF
+ }
+ if err != nil {
+ return c, err
+ }
+ c += m
+ b = b[m:]
+ off += int64(m)
+ }
+ return
+}
+
+// Write implements io.Writer.
+func (r *ReadWriter) Write(b []byte) (int, error) {
+ var err error
+ var n, remaining int
+ for remaining = len(b); remaining > 0; {
+ woff := len(b) - remaining
+ n, err = syscall.Write(int(atomic.LoadInt64(&r.fd)), b[woff:])
+
+ if n > 0 {
+ // syscall.Write wrote some bytes. This is the common case.
+ remaining -= n
+ } else {
+ if err == nil {
+ // syscall.Write did not write anything nor did it return an error.
+ //
+ // There is no way to guarantee that a subsequent syscall.Write will
+ // make forward progress so just panic.
+ panic(fmt.Sprintf("syscall.Write returned %d with no error", n))
+ }
+
+ if err != syscall.EINTR {
+ // If the write failed for anything other than a signal, bail out.
+ break
+ }
+ }
+ }
+
+ return len(b) - remaining, err
+}
+
+// WriteAt implements io.WriterAt.
+func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) {
+ for len(b) > 0 {
+ var m int
+ m, err = fixCount(syscall.Pwrite(int(atomic.LoadInt64(&r.fd)), b, off))
+ if err != nil {
+ break
+ }
+ c += m
+ b = b[m:]
+ off += int64(m)
+ }
+ return
+}
+
+// FD owns a host file descriptor.
+//
+// It is similar to os.File, with a few important distinctions:
+//
+// FD provies a Release() method which relinquishes ownership. Like os.File,
+// FD adds a finalizer to close the backing FD. However, the finalizer cannot
+// be removed from os.File, forever pinning the lifetime of an FD to its
+// os.File.
+//
+// FD supports both blocking and non-blocking operation. os.File only
+// supports blocking operation.
+type FD struct {
+ ReadWriter
+}
+
+// New creates a new FD.
+//
+// New takes ownership of fd.
+func New(fd int) *FD {
+ if fd < 0 {
+ return &FD{ReadWriter{-1}}
+ }
+ f := &FD{ReadWriter{int64(fd)}}
+ runtime.SetFinalizer(f, (*FD).Close)
+ return f
+}
+
+// NewFromFile creates a new FD from an os.File.
+//
+// NewFromFile does not transfer ownership of the file descriptor (it will be
+// duplicated, so both the os.File and FD will eventually need to be closed
+// and some (but not all) changes made to the FD will be applied to the
+// os.File as well).
+//
+// The returned FD is always blocking (Go 1.9+).
+func NewFromFile(file *os.File) (*FD, error) {
+ fd, err := syscall.Dup(int(file.Fd()))
+ // Technically, the runtime may call the finalizer on file as soon as
+ // Fd() returns.
+ runtime.KeepAlive(file)
+ if err != nil {
+ return &FD{ReadWriter{-1}}, err
+ }
+ return New(fd), nil
+}
+
+// Open is equivallent to open(2).
+func Open(path string, openmode int, perm uint32) (*FD, error) {
+ f, err := syscall.Open(path, openmode|syscall.O_LARGEFILE, perm)
+ if err != nil {
+ return nil, err
+ }
+ return New(f), nil
+}
+
+// OpenAt is equivallent to openat(2).
+func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
+ f, err := syscall.Openat(dir.FD(), path, flags, mode)
+ if err != nil {
+ return nil, err
+ }
+ return New(f), nil
+}
+
+// Close closes the file descriptor contained in the FD.
+//
+// Close is safe to call multiple times, but will return an error after the
+// first call.
+//
+// Concurrently calling Close and any other method is undefined.
+func (f *FD) Close() error {
+ runtime.SetFinalizer(f, nil)
+ return syscall.Close(int(atomic.SwapInt64(&f.fd, -1)))
+}
+
+// Release relinquishes ownership of the contained file descriptor.
+//
+// Concurrently calling Release and any other method is undefined.
+func (f *FD) Release() int {
+ runtime.SetFinalizer(f, nil)
+ return int(atomic.SwapInt64(&f.fd, -1))
+}
+
+// FD returns the file descriptor owned by FD. FD retains ownership.
+func (f *FD) FD() int {
+ return int(atomic.LoadInt64(&f.fd))
+}
+
+// File converts the FD to an os.File.
+//
+// FD does not transfer ownership of the file descriptor (it will be
+// duplicated, so both the FD and os.File will eventually need to be closed
+// and some (but not all) changes made to the os.File will be applied to the
+// FD as well).
+//
+// This operation is somewhat expensive, so care should be taken to minimize
+// its use.
+func (f *FD) File() (*os.File, error) {
+ fd, err := syscall.Dup(int(atomic.LoadInt64(&f.fd)))
+ if err != nil {
+ return nil, err
+ }
+ return os.NewFile(uintptr(fd), ""), nil
+}
+
+// ReleaseToFile returns an os.File that takes ownership of the FD.
+//
+// name is passed to os.NewFile.
+func (f *FD) ReleaseToFile(name string) *os.File {
+ return os.NewFile(uintptr(f.Release()), name)
+}
diff --git a/pkg/fd/fd_state_autogen.go b/pkg/fd/fd_state_autogen.go
new file mode 100755
index 000000000..0320140b0
--- /dev/null
+++ b/pkg/fd/fd_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package fd
+
diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go
new file mode 100644
index 000000000..f0b028b0b
--- /dev/null
+++ b/pkg/fdnotifier/fdnotifier.go
@@ -0,0 +1,202 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package fdnotifier contains an adapter that translates IO events (e.g., a
+// file became readable/writable) from native FDs to the notifications in the
+// waiter package. It uses epoll in edge-triggered mode to receive notifications
+// for registered FDs.
+package fdnotifier
+
+import (
+ "fmt"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+type fdInfo struct {
+ queue *waiter.Queue
+ waiting bool
+}
+
+// notifier holds all the state necessary to issue notifications when IO events
+// occur in the observed FDs.
+type notifier struct {
+ // epFD is the epoll file descriptor used to register for io
+ // notifications.
+ epFD int
+
+ // mu protects fdMap.
+ mu sync.Mutex
+
+ // fdMap maps file descriptors to their notification queues and waiting
+ // status.
+ fdMap map[int32]*fdInfo
+}
+
+// newNotifier creates a new notifier object.
+func newNotifier() (*notifier, error) {
+ epfd, err := syscall.EpollCreate1(0)
+ if err != nil {
+ return nil, err
+ }
+
+ w := &notifier{
+ epFD: epfd,
+ fdMap: make(map[int32]*fdInfo),
+ }
+
+ go w.waitAndNotify() // S/R-SAFE: no waiter exists during save / load.
+
+ return w, nil
+}
+
+// waitFD waits on mask for fd. The fdMap mutex must be hold.
+func (n *notifier) waitFD(fd int32, fi *fdInfo, mask waiter.EventMask) error {
+ if !fi.waiting && mask == 0 {
+ return nil
+ }
+
+ e := syscall.EpollEvent{
+ Events: mask.ToLinux() | -syscall.EPOLLET,
+ Fd: fd,
+ }
+
+ switch {
+ case !fi.waiting && mask != 0:
+ if err := syscall.EpollCtl(n.epFD, syscall.EPOLL_CTL_ADD, int(fd), &e); err != nil {
+ return err
+ }
+ fi.waiting = true
+ case fi.waiting && mask == 0:
+ syscall.EpollCtl(n.epFD, syscall.EPOLL_CTL_DEL, int(fd), nil)
+ fi.waiting = false
+ case fi.waiting && mask != 0:
+ if err := syscall.EpollCtl(n.epFD, syscall.EPOLL_CTL_MOD, int(fd), &e); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// addFD adds an FD to the list of FDs observed by n.
+func (n *notifier) addFD(fd int32, queue *waiter.Queue) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Panic if we're already notifying on this FD.
+ if _, ok := n.fdMap[fd]; ok {
+ panic(fmt.Sprintf("File descriptor %v added twice", fd))
+ }
+
+ // We have nothing to wait for at the moment. Just add it to the map.
+ n.fdMap[fd] = &fdInfo{queue: queue}
+}
+
+// updateFD updates the set of events the fd needs to be notified on.
+func (n *notifier) updateFD(fd int32) error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if fi, ok := n.fdMap[fd]; ok {
+ return n.waitFD(fd, fi, fi.queue.Events())
+ }
+
+ return nil
+}
+
+// RemoveFD removes an FD from the list of FDs observed by n.
+func (n *notifier) removeFD(fd int32) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Remove from map, then from epoll object.
+ n.waitFD(fd, n.fdMap[fd], 0)
+ delete(n.fdMap, fd)
+}
+
+// hasFD returns true if the fd is in the list of observed FDs.
+func (n *notifier) hasFD(fd int32) bool {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ _, ok := n.fdMap[fd]
+ return ok
+}
+
+// waitAndNotify run is its own goroutine and loops waiting for io event
+// notifications from the epoll object. Once notifications arrive, they are
+// dispatched to the registered queue.
+func (n *notifier) waitAndNotify() error {
+ e := make([]syscall.EpollEvent, 100)
+ for {
+ v, err := epollWait(n.epFD, e, -1)
+ if err == syscall.EINTR {
+ continue
+ }
+
+ if err != nil {
+ return err
+ }
+
+ n.mu.Lock()
+ for i := 0; i < v; i++ {
+ if fi, ok := n.fdMap[e[i].Fd]; ok {
+ fi.queue.Notify(waiter.EventMaskFromLinux(e[i].Events))
+ }
+ }
+ n.mu.Unlock()
+ }
+}
+
+var shared struct {
+ notifier *notifier
+ once sync.Once
+ initErr error
+}
+
+// AddFD adds an FD to the list of observed FDs.
+func AddFD(fd int32, queue *waiter.Queue) error {
+ shared.once.Do(func() {
+ shared.notifier, shared.initErr = newNotifier()
+ })
+
+ if shared.initErr != nil {
+ return shared.initErr
+ }
+
+ shared.notifier.addFD(fd, queue)
+ return nil
+}
+
+// UpdateFD updates the set of events the fd needs to be notified on.
+func UpdateFD(fd int32) error {
+ return shared.notifier.updateFD(fd)
+}
+
+// RemoveFD removes an FD from the list of observed FDs.
+func RemoveFD(fd int32) {
+ shared.notifier.removeFD(fd)
+}
+
+// HasFD returns true if the FD is in the list of observed FDs.
+//
+// This should only be used by tests to assert that FDs are correctly registered.
+func HasFD(fd int32) bool {
+ return shared.notifier.hasFD(fd)
+}
diff --git a/pkg/fdnotifier/fdnotifier_state_autogen.go b/pkg/fdnotifier/fdnotifier_state_autogen.go
new file mode 100755
index 000000000..6f6076b7b
--- /dev/null
+++ b/pkg/fdnotifier/fdnotifier_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package fdnotifier
+
diff --git a/pkg/fdnotifier/poll_unsafe.go b/pkg/fdnotifier/poll_unsafe.go
new file mode 100644
index 000000000..bc5e0ac44
--- /dev/null
+++ b/pkg/fdnotifier/poll_unsafe.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package fdnotifier
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// NonBlockingPoll polls the given FD in non-blocking fashion. It is used just
+// to query the FD's current state.
+func NonBlockingPoll(fd int32, mask waiter.EventMask) waiter.EventMask {
+ e := struct {
+ fd int32
+ events int16
+ revents int16
+ }{
+ fd: fd,
+ events: int16(mask.ToLinux()),
+ }
+
+ for {
+ n, _, err := syscall.RawSyscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(&e)), 1, 0)
+ // Interrupted by signal, try again.
+ if err == syscall.EINTR {
+ continue
+ }
+ // If an error occur we'll conservatively say the FD is ready for
+ // whatever is being checked.
+ if err != 0 {
+ return mask
+ }
+
+ // If no FDs were returned, it wasn't ready for anything.
+ if n == 0 {
+ return 0
+ }
+
+ // Otherwise we got the ready events in the revents field.
+ return waiter.EventMaskFromLinux(uint32(e.revents))
+ }
+}
+
+// epollWait performs a blocking wait on epfd.
+//
+// Preconditions:
+// * len(events) > 0
+func epollWait(epfd int, events []syscall.EpollEvent, msec int) (int, error) {
+ if len(events) == 0 {
+ panic("Empty events passed to EpollWait")
+ }
+
+ // We actually use epoll_pwait with NULL sigmask instead of epoll_wait
+ // since that is what the Go >= 1.11 runtime prefers.
+ r, _, e := syscall.Syscall6(syscall.SYS_EPOLL_PWAIT, uintptr(epfd), uintptr(unsafe.Pointer(&events[0])), uintptr(len(events)), uintptr(msec), 0, 0)
+ if e != 0 {
+ return 0, e
+ }
+ return int(r), nil
+}
diff --git a/pkg/gate/gate.go b/pkg/gate/gate.go
new file mode 100644
index 000000000..bda6aae09
--- /dev/null
+++ b/pkg/gate/gate.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gate provides a usage Gate synchronization primitive.
+package gate
+
+import (
+ "sync/atomic"
+)
+
+const (
+ // gateClosed is the bit set in the gate's user count to indicate that
+ // it has been closed. It is the MSB of the 32-bit field; the other 31
+ // bits carry the actual count.
+ gateClosed = 0x80000000
+)
+
+// Gate is a synchronization primitive that allows concurrent goroutines to
+// "enter" it as long as it hasn't been closed yet. Once it's been closed,
+// goroutines cannot enter it anymore, but are allowed to leave, and the closer
+// will be informed when all goroutines have left.
+//
+// Many goroutines are allowed to enter the gate concurrently, but only one is
+// allowed to close it.
+//
+// This is similar to a r/w critical section, except that goroutines "entering"
+// never block: they either enter immediately or fail to enter. The closer will
+// block waiting for all goroutines currently inside the gate to leave.
+//
+// This function is implemented efficiently. On x86, only one interlocked
+// operation is performed on enter, and one on leave.
+//
+// This is useful, for example, in cases when a goroutine is trying to clean up
+// an object for which multiple goroutines have pointers. In such a case, users
+// would be required to enter and leave the gates, and the cleaner would wait
+// until all users are gone (and no new ones are allowed) before proceeding.
+//
+// Users:
+//
+// if !g.Enter() {
+// // Gate is closed, we can't use the object.
+// return
+// }
+//
+// // Do something with object.
+// [...]
+//
+// g.Leave()
+//
+// Closer:
+//
+// // Prevent new users from using the object, and wait for the existing
+// // ones to complete.
+// g.Close()
+//
+// // Clean up the object.
+// [...]
+//
+type Gate struct {
+ userCount uint32
+ done chan struct{}
+}
+
+// Enter tries to enter the gate. It will succeed if it hasn't been closed yet,
+// in which case the caller must eventually call Leave().
+//
+// This function is thread-safe.
+func (g *Gate) Enter() bool {
+ if g == nil {
+ return false
+ }
+
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&gateClosed != 0 {
+ return false
+ }
+
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v+1) {
+ return true
+ }
+ }
+}
+
+// Leave leaves the gate. This must only be called after a successful call to
+// Enter(). If the gate has been closed and this is the last one inside the
+// gate, it will notify the closer that the gate is done.
+//
+// This function is thread-safe.
+func (g *Gate) Leave() {
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&^gateClosed == 0 {
+ panic("leaving a gate with zero usage count")
+ }
+
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v-1) {
+ if v == gateClosed+1 {
+ close(g.done)
+ }
+ return
+ }
+ }
+}
+
+// Close closes the gate for entering, and waits until all goroutines [that are
+// currently inside the gate] leave before returning.
+//
+// Only one goroutine can call this function.
+func (g *Gate) Close() {
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&^gateClosed != 0 && g.done == nil {
+ g.done = make(chan struct{})
+ }
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v|gateClosed) {
+ if v&^gateClosed != 0 {
+ <-g.done
+ }
+ return
+ }
+ }
+}
diff --git a/pkg/gate/gate_state_autogen.go b/pkg/gate/gate_state_autogen.go
new file mode 100755
index 000000000..a81fca776
--- /dev/null
+++ b/pkg/gate/gate_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package gate
+
diff --git a/pkg/ilist/ilist_state_autogen.go b/pkg/ilist/ilist_state_autogen.go
new file mode 100755
index 000000000..18a239fcf
--- /dev/null
+++ b/pkg/ilist/ilist_state_autogen.go
@@ -0,0 +1,38 @@
+// automatically generated by stateify.
+
+package ilist
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *List) beforeSave() {}
+func (x *List) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *List) afterLoad() {}
+func (x *List) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *Entry) beforeSave() {}
+func (x *Entry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *Entry) afterLoad() {}
+func (x *Entry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("ilist.List", (*List)(nil), state.Fns{Save: (*List).save, Load: (*List).load})
+ state.Register("ilist.Entry", (*Entry)(nil), state.Fns{Save: (*Entry).save, Load: (*Entry).load})
+}
diff --git a/pkg/ilist/interface_list.go b/pkg/ilist/interface_list.go
new file mode 100755
index 000000000..940c2d3f6
--- /dev/null
+++ b/pkg/ilist/interface_list.go
@@ -0,0 +1,192 @@
+package ilist
+
+// Linker is the interface that objects must implement if they want to be added
+// to and/or removed from List objects.
+//
+// N.B. When substituted in a template instantiation, Linker doesn't need to
+// be an interface, and in most cases won't be.
+type Linker interface {
+ Next() Element
+ Prev() Element
+ SetNext(Element)
+ SetPrev(Element)
+}
+
+// Element the item that is used at the API level.
+//
+// N.B. Like Linker, this is unlikely to be an interface in most cases.
+type Element interface {
+ Linker
+}
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type ElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (ElementMapper) linkerFor(elem Element) Linker { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type List struct {
+ head Element
+ tail Element
+}
+
+// Reset resets list l to the empty state.
+func (l *List) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *List) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *List) Front() Element {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *List) Back() Element {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *List) PushFront(e Element) {
+ ElementMapper{}.linkerFor(e).SetNext(l.head)
+ ElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ ElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *List) PushBack(e Element) {
+ ElementMapper{}.linkerFor(e).SetNext(nil)
+ ElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ ElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *List) PushBackList(m *List) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ ElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ ElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *List) InsertAfter(b, e Element) {
+ a := ElementMapper{}.linkerFor(b).Next()
+ ElementMapper{}.linkerFor(e).SetNext(a)
+ ElementMapper{}.linkerFor(e).SetPrev(b)
+ ElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ ElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *List) InsertBefore(a, e Element) {
+ b := ElementMapper{}.linkerFor(a).Prev()
+ ElementMapper{}.linkerFor(e).SetNext(a)
+ ElementMapper{}.linkerFor(e).SetPrev(b)
+ ElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ ElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *List) Remove(e Element) {
+ prev := ElementMapper{}.linkerFor(e).Prev()
+ next := ElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ ElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ ElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type Entry struct {
+ next Element
+ prev Element
+}
+
+// Next returns the entry that follows e in the list.
+func (e *Entry) Next() Element {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *Entry) Prev() Element {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *Entry) SetNext(elem Element) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *Entry) SetPrev(elem Element) {
+ e.prev = elem
+}
diff --git a/pkg/linewriter/linewriter.go b/pkg/linewriter/linewriter.go
new file mode 100644
index 000000000..cd6e4e2ce
--- /dev/null
+++ b/pkg/linewriter/linewriter.go
@@ -0,0 +1,78 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package linewriter provides an io.Writer which calls an emitter on each line.
+package linewriter
+
+import (
+ "bytes"
+ "sync"
+)
+
+// Writer is an io.Writer which buffers input, flushing
+// individual lines through an emitter function.
+type Writer struct {
+ // the mutex locks buf.
+ sync.Mutex
+
+ // buf holds the data we haven't emitted yet.
+ buf bytes.Buffer
+
+ // emit is used to flush individual lines.
+ emit func(p []byte)
+}
+
+// NewWriter creates a Writer which emits using emitter.
+// The emitter must not retain p. It may change after emitter returns.
+func NewWriter(emitter func(p []byte)) *Writer {
+ return &Writer{emit: emitter}
+}
+
+// Write implements io.Writer.Write.
+// It calls emit on each line of input, not including the newline.
+// Write may be called concurrently.
+func (w *Writer) Write(p []byte) (int, error) {
+ w.Lock()
+ defer w.Unlock()
+
+ total := 0
+ for len(p) > 0 {
+ emit := true
+ i := bytes.IndexByte(p, '\n')
+ if i < 0 {
+ // No newline, we will buffer everything.
+ i = len(p)
+ emit = false
+ }
+
+ n, err := w.buf.Write(p[:i])
+ if err != nil {
+ return total, err
+ }
+ total += n
+
+ p = p[i:]
+
+ if emit {
+ // Skip the newline, but still count it.
+ p = p[1:]
+ total++
+
+ w.emit(w.buf.Bytes())
+ w.buf.Reset()
+ }
+ }
+
+ return total, nil
+}
diff --git a/pkg/linewriter/linewriter_state_autogen.go b/pkg/linewriter/linewriter_state_autogen.go
new file mode 100755
index 000000000..194088d76
--- /dev/null
+++ b/pkg/linewriter/linewriter_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package linewriter
+
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
new file mode 100644
index 000000000..5732785b4
--- /dev/null
+++ b/pkg/log/glog.go
@@ -0,0 +1,163 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package log
+
+import (
+ "os"
+ "time"
+)
+
+// GoogleEmitter is a wrapper that emits logs in a format compatible with
+// package github.com/golang/glog.
+type GoogleEmitter struct {
+ // Emitter is the underlying emitter.
+ Emitter
+}
+
+// buffer is a simple inline buffer to avoid churn. The data slice is generally
+// kept to the local byte array, and we avoid having to allocate it on the heap.
+type buffer struct {
+ local [256]byte
+ data []byte
+}
+
+func (b *buffer) start() {
+ b.data = b.local[:0]
+}
+
+func (b *buffer) String() string {
+ return unsafeString(b.data)
+}
+
+func (b *buffer) write(c byte) {
+ b.data = append(b.data, c)
+}
+
+func (b *buffer) writeAll(d []byte) {
+ b.data = append(b.data, d...)
+}
+
+func (b *buffer) writeOneDigit(d byte) {
+ b.write('0' + d)
+}
+
+func (b *buffer) writeTwoDigits(v int) {
+ v = v % 100
+ b.writeOneDigit(byte(v / 10))
+ b.writeOneDigit(byte(v % 10))
+}
+
+func (b *buffer) writeSixDigits(v int) {
+ v = v % 1000000
+ b.writeOneDigit(byte(v / 100000))
+ b.writeOneDigit(byte((v % 100000) / 10000))
+ b.writeOneDigit(byte((v % 10000) / 1000))
+ b.writeOneDigit(byte((v % 1000) / 100))
+ b.writeOneDigit(byte((v % 100) / 10))
+ b.writeOneDigit(byte(v % 10))
+}
+
+func calculateBytes(v int, pad int) []byte {
+ var d []byte
+ r := 1
+
+ for n := 10; v >= r; n = n * 10 {
+ d = append(d, '0'+byte((v%n)/r))
+ r = n
+ }
+
+ for i := len(d); i < pad; i++ {
+ d = append(d, ' ')
+ }
+
+ for i := 0; i < len(d)/2; i++ {
+ d[i], d[len(d)-(i+1)] = d[len(d)-(i+1)], d[i]
+ }
+ return d
+}
+
+// pid is used for the threadid component of the header.
+//
+// The glog package logger uses 7 spaces of padding. See
+// glob.loggingT.formatHeader.
+var pid = calculateBytes(os.Getpid(), 7)
+
+// caller is faked out as the caller. See FIXME below.
+var caller = []byte("x:0")
+
+// Emit emits the message, google-style.
+func (g GoogleEmitter) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+ var b buffer
+ b.start()
+
+ // Log lines have this form:
+ // Lmmdd hh:mm:ss.uuuuuu threadid file:line] msg...
+ //
+ // where the fields are defined as follows:
+ // L A single character, representing the log level (eg 'I' for INFO)
+ // mm The month (zero padded; ie May is '05')
+ // dd The day (zero padded)
+ // hh:mm:ss.uuuuuu Time in hours, minutes and fractional seconds
+ // threadid The space-padded thread ID as returned by GetTID()
+ // file The file name
+ // line The line number
+ // msg The user-supplied message
+
+ // Log level.
+ switch level {
+ case Debug:
+ b.write('D')
+ case Info:
+ b.write('I')
+ case Warning:
+ b.write('W')
+ }
+
+ // Timestamp.
+ _, month, day := timestamp.Date()
+ hour, minute, second := timestamp.Clock()
+ b.writeTwoDigits(int(month))
+ b.writeTwoDigits(int(day))
+ b.write(' ')
+ b.writeTwoDigits(int(hour))
+ b.write(':')
+ b.writeTwoDigits(int(minute))
+ b.write(':')
+ b.writeTwoDigits(int(second))
+ b.write('.')
+ b.writeSixDigits(int(timestamp.Nanosecond() / 1000))
+ b.write(' ')
+
+ // The pid.
+ b.writeAll(pid)
+ b.write(' ')
+
+ // FIXME(b/73383460): The caller, fabricated. This really sucks, but it
+ // is unacceptable to put runtime.Callers() in the hot path.
+ b.writeAll(caller)
+ b.write(']')
+ b.write(' ')
+
+ // User-provided format string, copied.
+ for i := 0; i < len(format); i++ {
+ b.write(format[i])
+ }
+
+ // End with a newline.
+ b.write('\n')
+
+ // Pass to the underlying routine.
+ g.Emitter.Emit(level, timestamp, b.String(), args...)
+}
diff --git a/pkg/log/glog_unsafe.go b/pkg/log/glog_unsafe.go
new file mode 100644
index 000000000..ea17ae349
--- /dev/null
+++ b/pkg/log/glog_unsafe.go
@@ -0,0 +1,32 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package log
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafeString returns a string that points to the given byte array.
+// The byte array must be preserved until the string is disposed.
+func unsafeString(data []byte) (s string) {
+ if len(data) == 0 {
+ return
+ }
+
+ (*reflect.StringHeader)(unsafe.Pointer(&s)).Data = uintptr(unsafe.Pointer(&data[0]))
+ (*reflect.StringHeader)(unsafe.Pointer(&s)).Len = len(data)
+ return
+}
diff --git a/pkg/log/json.go b/pkg/log/json.go
new file mode 100644
index 000000000..a278c8fc8
--- /dev/null
+++ b/pkg/log/json.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package log
+
+import (
+ "encoding/json"
+ "fmt"
+ "time"
+)
+
+type jsonLog struct {
+ Msg string `json:"msg"`
+ Level Level `json:"level"`
+ Time time.Time `json:"time"`
+}
+
+// MarshalJSON implements json.Marshaler.MarashalJSON.
+func (lv Level) MarshalJSON() ([]byte, error) {
+ switch lv {
+ case Warning:
+ return []byte(`"warning"`), nil
+ case Info:
+ return []byte(`"info"`), nil
+ case Debug:
+ return []byte(`"debug"`), nil
+ default:
+ return nil, fmt.Errorf("unknown level %v", lv)
+ }
+}
+
+// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON. It can unmarshal
+// from both string names and integers.
+func (lv *Level) UnmarshalJSON(b []byte) error {
+ switch s := string(b); s {
+ case "0", `"warning"`:
+ *lv = Warning
+ case "1", `"info"`:
+ *lv = Info
+ case "2", `"debug"`:
+ *lv = Debug
+ default:
+ return fmt.Errorf("unknown level %q", s)
+ }
+ return nil
+}
+
+// JSONEmitter logs messages in json format.
+type JSONEmitter struct {
+ Writer
+}
+
+// Emit implements Emitter.Emit.
+func (e JSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+ j := jsonLog{
+ Msg: fmt.Sprintf(format, v...),
+ Level: level,
+ Time: timestamp,
+ }
+ b, err := json.Marshal(j)
+ if err != nil {
+ panic(err)
+ }
+ e.Writer.Write(b)
+}
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
new file mode 100644
index 000000000..c2c019915
--- /dev/null
+++ b/pkg/log/json_k8s.go
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package log
+
+import (
+ "encoding/json"
+ "fmt"
+ "time"
+)
+
+type k8sJSONLog struct {
+ Log string `json:"log"`
+ Level Level `json:"level"`
+ Time time.Time `json:"time"`
+}
+
+// K8sJSONEmitter logs messages in json format that is compatible with
+// Kubernetes fluent configuration.
+type K8sJSONEmitter struct {
+ Writer
+}
+
+// Emit implements Emitter.Emit.
+func (e K8sJSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+ j := k8sJSONLog{
+ Log: fmt.Sprintf(format, v...),
+ Level: level,
+ Time: timestamp,
+ }
+ b, err := json.Marshal(j)
+ if err != nil {
+ panic(err)
+ }
+ e.Writer.Write(b)
+}
diff --git a/pkg/log/log.go b/pkg/log/log.go
new file mode 100644
index 000000000..7d563241e
--- /dev/null
+++ b/pkg/log/log.go
@@ -0,0 +1,323 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package log implements a library for logging.
+//
+// This is separate from the standard logging package because logging may be a
+// high-impact activity, and therefore we wanted to provide as much flexibility
+// as possible in the underlying implementation.
+package log
+
+import (
+ "fmt"
+ "io"
+ stdlog "log"
+ "os"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/linewriter"
+)
+
+// Level is the log level.
+type Level uint32
+
+// The following levels are fixed, and can never be changed. Since some control
+// RPCs allow for changing the level as an integer, it is only possible to add
+// additional levels, and the existing one cannot be removed.
+const (
+ // Warning indicates that output should always be emitted.
+ Warning Level = iota
+
+ // Info indicates that output should normally be emitted.
+ Info
+
+ // Debug indicates that output should not normally be emitted.
+ Debug
+)
+
+// Emitter is the final destination for logs.
+type Emitter interface {
+ // Emit emits the given log statement. This allows for control over the
+ // timestamp used for logging.
+ Emit(level Level, timestamp time.Time, format string, v ...interface{})
+}
+
+// Writer writes the output to the given writer.
+type Writer struct {
+ // Next is where output is written.
+ Next io.Writer
+
+ // mu protects fields below.
+ mu sync.Mutex
+
+ // errors counts failures to write log messages so it can be reported
+ // when writer start to work again. Needs to be accessed using atomics
+ // to make race detector happy because it's read outside the mutex.
+ errors int32
+}
+
+// Write writes out the given bytes, handling non-blocking sockets.
+func (l *Writer) Write(data []byte) (int, error) {
+ n := 0
+
+ for n < len(data) {
+ w, err := l.Next.Write(data[n:])
+ n += w
+
+ // Is it a non-blocking socket?
+ if pathErr, ok := err.(*os.PathError); ok && pathErr.Err == syscall.EAGAIN {
+ runtime.Gosched()
+ continue
+ }
+
+ // Some other error?
+ if err != nil {
+ l.mu.Lock()
+ atomic.AddInt32(&l.errors, 1)
+ l.mu.Unlock()
+ return n, err
+ }
+ }
+
+ // Do we need to end with a '\n'?
+ if len(data) == 0 || data[len(data)-1] != '\n' {
+ l.Write([]byte{'\n'})
+ }
+
+ // Dirty read in case there were errors (rare).
+ if atomic.LoadInt32(&l.errors) > 0 {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ // Recheck condition under lock.
+ if e := atomic.LoadInt32(&l.errors); e > 0 {
+ msg := fmt.Sprintf("\n*** Dropped %d log messages ***\n", e)
+ if _, err := l.Next.Write([]byte(msg)); err == nil {
+ atomic.StoreInt32(&l.errors, 0)
+ }
+ }
+ }
+
+ return n, nil
+}
+
+// Emit emits the message.
+func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+ fmt.Fprintf(l, format, args...)
+}
+
+// MultiEmitter is an emitter that emits to multiple Emitters.
+type MultiEmitter []Emitter
+
+// Emit emits to all emitters.
+func (m MultiEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+ for _, e := range m {
+ e.Emit(level, timestamp, format, v...)
+ }
+}
+
+// TestLogger is implemented by testing.T and testing.B.
+type TestLogger interface {
+ Logf(format string, v ...interface{})
+}
+
+// TestEmitter may be used for wrapping tests.
+type TestEmitter struct {
+ TestLogger
+}
+
+// Emit emits to the TestLogger.
+func (t TestEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+ t.Logf(format, v...)
+}
+
+// Logger is a high-level logging interface. It is in fact, not used within the
+// log package. Rather it is provided for others to provide contextual loggers
+// that may append some addition information to log statement. BasicLogger
+// satisfies this interface, and may be passed around as a Logger.
+type Logger interface {
+ // Debugf logs a debug statement.
+ Debugf(format string, v ...interface{})
+
+ // Infof logs at an info level.
+ Infof(format string, v ...interface{})
+
+ // Warningf logs at a warning level.
+ Warningf(format string, v ...interface{})
+
+ // IsLogging returns true iff this level is being logged. This may be
+ // used to short-circuit expensive operations for debugging calls.
+ IsLogging(level Level) bool
+}
+
+// BasicLogger is the default implementation of Logger.
+type BasicLogger struct {
+ Level
+ Emitter
+}
+
+// Debugf implements logger.Debugf.
+func (l *BasicLogger) Debugf(format string, v ...interface{}) {
+ if l.IsLogging(Debug) {
+ l.Emit(Debug, time.Now(), format, v...)
+ }
+}
+
+// Infof implements logger.Infof.
+func (l *BasicLogger) Infof(format string, v ...interface{}) {
+ if l.IsLogging(Info) {
+ l.Emit(Info, time.Now(), format, v...)
+ }
+}
+
+// Warningf implements logger.Warningf.
+func (l *BasicLogger) Warningf(format string, v ...interface{}) {
+ if l.IsLogging(Warning) {
+ l.Emit(Warning, time.Now(), format, v...)
+ }
+}
+
+// IsLogging implements logger.IsLogging.
+func (l *BasicLogger) IsLogging(level Level) bool {
+ return atomic.LoadUint32((*uint32)(&l.Level)) >= uint32(level)
+}
+
+// SetLevel sets the logging level.
+func (l *BasicLogger) SetLevel(level Level) {
+ atomic.StoreUint32((*uint32)(&l.Level), uint32(level))
+}
+
+// logMu protects Log below. We use atomic operations to read the value, but
+// updates require logMu to ensure consistency.
+var logMu sync.Mutex
+
+// log is the default logger.
+var log atomic.Value
+
+// Log retieves the global logger.
+func Log() *BasicLogger {
+ return log.Load().(*BasicLogger)
+}
+
+// SetTarget sets the log target.
+//
+// This is not thread safe and shouldn't be called concurrently with any
+// logging calls.
+func SetTarget(target Emitter) {
+ logMu.Lock()
+ defer logMu.Unlock()
+ oldLog := Log()
+ log.Store(&BasicLogger{Level: oldLog.Level, Emitter: target})
+}
+
+// SetLevel sets the log level.
+func SetLevel(newLevel Level) {
+ Log().SetLevel(newLevel)
+}
+
+// Debugf logs to the global logger.
+func Debugf(format string, v ...interface{}) {
+ Log().Debugf(format, v...)
+}
+
+// Infof logs to the global logger.
+func Infof(format string, v ...interface{}) {
+ Log().Infof(format, v...)
+}
+
+// Warningf logs to the global logger.
+func Warningf(format string, v ...interface{}) {
+ Log().Warningf(format, v...)
+}
+
+// defaultStackSize is the default buffer size to allocate for stack traces.
+const defaultStackSize = 1 << 16 // 64KB
+
+// maxStackSize is the maximum buffer size to allocate for stack traces.
+const maxStackSize = 1 << 26 // 64MB
+
+// Stacks returns goroutine stacks, like panic.
+func Stacks(all bool) []byte {
+ var trace []byte
+ for s := defaultStackSize; s <= maxStackSize; s *= 4 {
+ trace = make([]byte, s)
+ nbytes := runtime.Stack(trace, all)
+ if nbytes == s {
+ continue
+ }
+ return trace[:nbytes]
+ }
+ trace = append(trace, []byte("\n\n...<too large, truncated>")...)
+ return trace
+}
+
+// Traceback logs the given message and dumps a stacktrace of the current
+// goroutine.
+//
+// This will be print a traceback, tb, as Warningf(format+":\n%s", v..., tb).
+func Traceback(format string, v ...interface{}) {
+ v = append(v, Stacks(false))
+ Warningf(format+":\n%s", v...)
+}
+
+// TracebackAll logs the given message and dumps a stacktrace of all goroutines.
+//
+// This will be print a traceback, tb, as Warningf(format+":\n%s", v..., tb).
+func TracebackAll(format string, v ...interface{}) {
+ v = append(v, Stacks(true))
+ Warningf(format+":\n%s", v...)
+}
+
+// IsLogging returns whether the global logger is logging.
+func IsLogging(level Level) bool {
+ return Log().IsLogging(level)
+}
+
+// CopyStandardLogTo redirects the stdlib log package global output to the global
+// logger for the specified level.
+func CopyStandardLogTo(l Level) error {
+ var f func(string, ...interface{})
+
+ switch l {
+ case Debug:
+ f = Debugf
+ case Info:
+ f = Infof
+ case Warning:
+ f = Warningf
+ default:
+ return fmt.Errorf("Unknown log level %v", l)
+ }
+
+ stdlog.SetOutput(linewriter.NewWriter(func(p []byte) {
+ // We must not retain p, but log formatting is not required to
+ // be synchronous (though the in-package implementations are),
+ // so we must make a copy.
+ b := make([]byte, len(p))
+ copy(b, p)
+
+ f("%s", b)
+ }))
+
+ return nil
+}
+
+func init() {
+ // Store the initial value for the log.
+ log.Store(&BasicLogger{Level: Info, Emitter: GoogleEmitter{&Writer{Next: os.Stderr}}})
+}
diff --git a/pkg/log/log_state_autogen.go b/pkg/log/log_state_autogen.go
new file mode 100755
index 000000000..010b760a5
--- /dev/null
+++ b/pkg/log/log_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package log
+
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
new file mode 100644
index 000000000..803709cc4
--- /dev/null
+++ b/pkg/metric/metric.go
@@ -0,0 +1,250 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package metric provides primitives for collecting metrics.
+package metric
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/eventchannel"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ pb "gvisor.googlesource.com/gvisor/pkg/metric/metric_go_proto"
+)
+
+var (
+ // ErrNameInUse indicates that another metric is already defined for
+ // the given name.
+ ErrNameInUse = errors.New("metric name already in use")
+
+ // ErrInitializationDone indicates that the caller tried to create a
+ // new metric after initialization.
+ ErrInitializationDone = errors.New("metric cannot be created after initialization is complete")
+)
+
+// Uint64Metric encapsulates a uint64 that represents some kind of metric to be
+// monitored.
+//
+// All metrics must be cumulative, meaning that their values will only increase
+// over time.
+//
+// Metrics are not saved across save/restore and thus reset to zero on restore.
+//
+// TODO(b/67298402): Support non-cumulative metrics.
+// TODO(b/67298427): Support metric fields.
+//
+type Uint64Metric struct {
+ // value is the actual value of the metric. It must be accessed
+ // atomically.
+ value uint64
+}
+
+var (
+ // initialized indicates that all metrics are registered. allMetrics is
+ // immutable once initialized is true.
+ initialized bool
+
+ // allMetrics are the registered metrics.
+ allMetrics = makeMetricSet()
+)
+
+// Initialize sends a metric registration event over the event channel.
+//
+// Precondition:
+// * All metrics are registered.
+// * Initialize/Disable has not been called.
+func Initialize() {
+ if initialized {
+ panic("Initialize/Disable called more than once")
+ }
+ initialized = true
+
+ m := pb.MetricRegistration{}
+ for _, v := range allMetrics.m {
+ m.Metrics = append(m.Metrics, v.metadata)
+ }
+ eventchannel.Emit(&m)
+}
+
+// Disable sends an empty metric registration event over the event channel,
+// disabling metric collection.
+//
+// Precondition:
+// * All metrics are registered.
+// * Initialize/Disable has not been called.
+func Disable() {
+ if initialized {
+ panic("Initialize/Disable called more than once")
+ }
+ initialized = true
+
+ m := pb.MetricRegistration{}
+ if err := eventchannel.Emit(&m); err != nil {
+ panic("unable to emit metric disable event: " + err.Error())
+ }
+}
+
+type customUint64Metric struct {
+ // metadata describes the metric. It is immutable.
+ metadata *pb.MetricMetadata
+
+ // value returns the current value of the metric.
+ value func() uint64
+}
+
+// RegisterCustomUint64Metric registers a metric with the given name.
+//
+// Register must only be called at init and will return and error if called
+// after Initialized.
+//
+// All metrics must be cumulative, meaning that the return values of value must
+// only increase over time.
+//
+// Preconditions:
+// * name must be globally unique.
+// * Initialize/Disable have not been called.
+func RegisterCustomUint64Metric(name string, sync bool, description string, value func() uint64) error {
+ if initialized {
+ return ErrInitializationDone
+ }
+
+ if _, ok := allMetrics.m[name]; ok {
+ return ErrNameInUse
+ }
+
+ allMetrics.m[name] = customUint64Metric{
+ metadata: &pb.MetricMetadata{
+ Name: name,
+ Description: description,
+ Cumulative: true,
+ Sync: sync,
+ Type: pb.MetricMetadata_UINT64,
+ },
+ value: value,
+ }
+ return nil
+}
+
+// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric and panics
+// if it returns an error.
+func MustRegisterCustomUint64Metric(name string, sync bool, description string, value func() uint64) {
+ if err := RegisterCustomUint64Metric(name, sync, description, value); err != nil {
+ panic(fmt.Sprintf("Unable to register metric %q: %v", name, err))
+ }
+}
+
+// NewUint64Metric creates and registers a new metric with the given name.
+//
+// Metrics must be statically defined (i.e., at init).
+func NewUint64Metric(name string, sync bool, description string) (*Uint64Metric, error) {
+ var m Uint64Metric
+ return &m, RegisterCustomUint64Metric(name, sync, description, m.Value)
+}
+
+// MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an
+// error.
+func MustCreateNewUint64Metric(name string, sync bool, description string) *Uint64Metric {
+ m, err := NewUint64Metric(name, sync, description)
+ if err != nil {
+ panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
+ }
+ return m
+}
+
+// Value returns the current value of the metric.
+func (m *Uint64Metric) Value() uint64 {
+ return atomic.LoadUint64(&m.value)
+}
+
+// Increment increments the metric by 1.
+func (m *Uint64Metric) Increment() {
+ atomic.AddUint64(&m.value, 1)
+}
+
+// IncrementBy increments the metric by v.
+func (m *Uint64Metric) IncrementBy(v uint64) {
+ atomic.AddUint64(&m.value, v)
+}
+
+// metricSet holds named metrics.
+type metricSet struct {
+ m map[string]customUint64Metric
+}
+
+// makeMetricSet returns a new metricSet.
+func makeMetricSet() metricSet {
+ return metricSet{
+ m: make(map[string]customUint64Metric),
+ }
+}
+
+// Values returns a snapshot of all values in m.
+func (m *metricSet) Values() metricValues {
+ vals := make(metricValues)
+ for k, v := range m.m {
+ vals[k] = v.value()
+ }
+ return vals
+}
+
+// metricValues contains a copy of the values of all metrics.
+type metricValues map[string]uint64
+
+var (
+ // emitMu protects metricsAtLastEmit and ensures that all emitted
+ // metrics are strongly ordered (older metrics are never emitted after
+ // newer metrics).
+ emitMu sync.Mutex
+
+ // metricsAtLastEmit contains the state of the metrics at the last emit event.
+ metricsAtLastEmit metricValues
+)
+
+// EmitMetricUpdate emits a MetricUpdate over the event channel.
+//
+// Only metrics that have changed since the last call are emitted.
+//
+// EmitMetricUpdate is thread-safe.
+//
+// Preconditions:
+// * Initialize has been called.
+func EmitMetricUpdate() {
+ emitMu.Lock()
+ defer emitMu.Unlock()
+
+ snapshot := allMetrics.Values()
+
+ m := pb.MetricUpdate{}
+ for k, v := range snapshot {
+ // On the first call metricsAtLastEmit will be empty. Include
+ // all metrics then.
+ if prev, ok := metricsAtLastEmit[k]; !ok || prev != v {
+ m.Metrics = append(m.Metrics, &pb.MetricValue{
+ Name: k,
+ Value: &pb.MetricValue_Uint64Value{v},
+ })
+ }
+ }
+
+ metricsAtLastEmit = snapshot
+ if len(m.Metrics) == 0 {
+ return
+ }
+
+ log.Debugf("Emitting metrics: %v", m)
+ eventchannel.Emit(&m)
+}
diff --git a/pkg/metric/metric_go_proto/metric.pb.go b/pkg/metric/metric_go_proto/metric.pb.go
new file mode 100755
index 000000000..553236535
--- /dev/null
+++ b/pkg/metric/metric_go_proto/metric.pb.go
@@ -0,0 +1,297 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: pkg/metric/metric.proto
+
+package gvisor
+
+import (
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type MetricMetadata_Type int32
+
+const (
+ MetricMetadata_UINT64 MetricMetadata_Type = 0
+)
+
+var MetricMetadata_Type_name = map[int32]string{
+ 0: "UINT64",
+}
+
+var MetricMetadata_Type_value = map[string]int32{
+ "UINT64": 0,
+}
+
+func (x MetricMetadata_Type) String() string {
+ return proto.EnumName(MetricMetadata_Type_name, int32(x))
+}
+
+func (MetricMetadata_Type) EnumDescriptor() ([]byte, []int) {
+ return fileDescriptor_87b8778a4ff2ab5c, []int{0, 0}
+}
+
+type MetricMetadata struct {
+ Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
+ Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty"`
+ Cumulative bool `protobuf:"varint,3,opt,name=cumulative,proto3" json:"cumulative,omitempty"`
+ Sync bool `protobuf:"varint,4,opt,name=sync,proto3" json:"sync,omitempty"`
+ Type MetricMetadata_Type `protobuf:"varint,5,opt,name=type,proto3,enum=gvisor.MetricMetadata_Type" json:"type,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *MetricMetadata) Reset() { *m = MetricMetadata{} }
+func (m *MetricMetadata) String() string { return proto.CompactTextString(m) }
+func (*MetricMetadata) ProtoMessage() {}
+func (*MetricMetadata) Descriptor() ([]byte, []int) {
+ return fileDescriptor_87b8778a4ff2ab5c, []int{0}
+}
+
+func (m *MetricMetadata) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_MetricMetadata.Unmarshal(m, b)
+}
+func (m *MetricMetadata) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_MetricMetadata.Marshal(b, m, deterministic)
+}
+func (m *MetricMetadata) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_MetricMetadata.Merge(m, src)
+}
+func (m *MetricMetadata) XXX_Size() int {
+ return xxx_messageInfo_MetricMetadata.Size(m)
+}
+func (m *MetricMetadata) XXX_DiscardUnknown() {
+ xxx_messageInfo_MetricMetadata.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_MetricMetadata proto.InternalMessageInfo
+
+func (m *MetricMetadata) GetName() string {
+ if m != nil {
+ return m.Name
+ }
+ return ""
+}
+
+func (m *MetricMetadata) GetDescription() string {
+ if m != nil {
+ return m.Description
+ }
+ return ""
+}
+
+func (m *MetricMetadata) GetCumulative() bool {
+ if m != nil {
+ return m.Cumulative
+ }
+ return false
+}
+
+func (m *MetricMetadata) GetSync() bool {
+ if m != nil {
+ return m.Sync
+ }
+ return false
+}
+
+func (m *MetricMetadata) GetType() MetricMetadata_Type {
+ if m != nil {
+ return m.Type
+ }
+ return MetricMetadata_UINT64
+}
+
+type MetricRegistration struct {
+ Metrics []*MetricMetadata `protobuf:"bytes,1,rep,name=metrics,proto3" json:"metrics,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *MetricRegistration) Reset() { *m = MetricRegistration{} }
+func (m *MetricRegistration) String() string { return proto.CompactTextString(m) }
+func (*MetricRegistration) ProtoMessage() {}
+func (*MetricRegistration) Descriptor() ([]byte, []int) {
+ return fileDescriptor_87b8778a4ff2ab5c, []int{1}
+}
+
+func (m *MetricRegistration) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_MetricRegistration.Unmarshal(m, b)
+}
+func (m *MetricRegistration) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_MetricRegistration.Marshal(b, m, deterministic)
+}
+func (m *MetricRegistration) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_MetricRegistration.Merge(m, src)
+}
+func (m *MetricRegistration) XXX_Size() int {
+ return xxx_messageInfo_MetricRegistration.Size(m)
+}
+func (m *MetricRegistration) XXX_DiscardUnknown() {
+ xxx_messageInfo_MetricRegistration.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_MetricRegistration proto.InternalMessageInfo
+
+func (m *MetricRegistration) GetMetrics() []*MetricMetadata {
+ if m != nil {
+ return m.Metrics
+ }
+ return nil
+}
+
+type MetricValue struct {
+ Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
+ // Types that are valid to be assigned to Value:
+ // *MetricValue_Uint64Value
+ Value isMetricValue_Value `protobuf_oneof:"value"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *MetricValue) Reset() { *m = MetricValue{} }
+func (m *MetricValue) String() string { return proto.CompactTextString(m) }
+func (*MetricValue) ProtoMessage() {}
+func (*MetricValue) Descriptor() ([]byte, []int) {
+ return fileDescriptor_87b8778a4ff2ab5c, []int{2}
+}
+
+func (m *MetricValue) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_MetricValue.Unmarshal(m, b)
+}
+func (m *MetricValue) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_MetricValue.Marshal(b, m, deterministic)
+}
+func (m *MetricValue) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_MetricValue.Merge(m, src)
+}
+func (m *MetricValue) XXX_Size() int {
+ return xxx_messageInfo_MetricValue.Size(m)
+}
+func (m *MetricValue) XXX_DiscardUnknown() {
+ xxx_messageInfo_MetricValue.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_MetricValue proto.InternalMessageInfo
+
+func (m *MetricValue) GetName() string {
+ if m != nil {
+ return m.Name
+ }
+ return ""
+}
+
+type isMetricValue_Value interface {
+ isMetricValue_Value()
+}
+
+type MetricValue_Uint64Value struct {
+ Uint64Value uint64 `protobuf:"varint,2,opt,name=uint64_value,json=uint64Value,proto3,oneof"`
+}
+
+func (*MetricValue_Uint64Value) isMetricValue_Value() {}
+
+func (m *MetricValue) GetValue() isMetricValue_Value {
+ if m != nil {
+ return m.Value
+ }
+ return nil
+}
+
+func (m *MetricValue) GetUint64Value() uint64 {
+ if x, ok := m.GetValue().(*MetricValue_Uint64Value); ok {
+ return x.Uint64Value
+ }
+ return 0
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*MetricValue) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*MetricValue_Uint64Value)(nil),
+ }
+}
+
+type MetricUpdate struct {
+ Metrics []*MetricValue `protobuf:"bytes,1,rep,name=metrics,proto3" json:"metrics,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *MetricUpdate) Reset() { *m = MetricUpdate{} }
+func (m *MetricUpdate) String() string { return proto.CompactTextString(m) }
+func (*MetricUpdate) ProtoMessage() {}
+func (*MetricUpdate) Descriptor() ([]byte, []int) {
+ return fileDescriptor_87b8778a4ff2ab5c, []int{3}
+}
+
+func (m *MetricUpdate) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_MetricUpdate.Unmarshal(m, b)
+}
+func (m *MetricUpdate) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_MetricUpdate.Marshal(b, m, deterministic)
+}
+func (m *MetricUpdate) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_MetricUpdate.Merge(m, src)
+}
+func (m *MetricUpdate) XXX_Size() int {
+ return xxx_messageInfo_MetricUpdate.Size(m)
+}
+func (m *MetricUpdate) XXX_DiscardUnknown() {
+ xxx_messageInfo_MetricUpdate.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_MetricUpdate proto.InternalMessageInfo
+
+func (m *MetricUpdate) GetMetrics() []*MetricValue {
+ if m != nil {
+ return m.Metrics
+ }
+ return nil
+}
+
+func init() {
+ proto.RegisterEnum("gvisor.MetricMetadata_Type", MetricMetadata_Type_name, MetricMetadata_Type_value)
+ proto.RegisterType((*MetricMetadata)(nil), "gvisor.MetricMetadata")
+ proto.RegisterType((*MetricRegistration)(nil), "gvisor.MetricRegistration")
+ proto.RegisterType((*MetricValue)(nil), "gvisor.MetricValue")
+ proto.RegisterType((*MetricUpdate)(nil), "gvisor.MetricUpdate")
+}
+
+func init() { proto.RegisterFile("pkg/metric/metric.proto", fileDescriptor_87b8778a4ff2ab5c) }
+
+var fileDescriptor_87b8778a4ff2ab5c = []byte{
+ // 288 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x91, 0xc1, 0x4b, 0xc3, 0x30,
+ 0x14, 0xc6, 0x17, 0xd7, 0x75, 0xfa, 0x3a, 0x86, 0x44, 0xd0, 0x80, 0x20, 0xa5, 0x5e, 0x7a, 0xb1,
+ 0x93, 0x39, 0x76, 0xf3, 0xe2, 0x41, 0xf4, 0x30, 0x85, 0xb0, 0x79, 0x95, 0xd8, 0x86, 0x12, 0x5c,
+ 0xdb, 0xd0, 0xa4, 0x85, 0xfe, 0x75, 0xfe, 0x6b, 0xd2, 0x17, 0x95, 0x4d, 0x76, 0xca, 0xcb, 0xfb,
+ 0xde, 0xf7, 0xf1, 0xcb, 0x0b, 0x5c, 0xe8, 0xcf, 0x7c, 0x56, 0x48, 0x5b, 0xab, 0xf4, 0xe7, 0x48,
+ 0x74, 0x5d, 0xd9, 0x8a, 0xfa, 0x79, 0xab, 0x4c, 0x55, 0x47, 0x5f, 0x04, 0xa6, 0x2b, 0x14, 0x56,
+ 0xd2, 0x8a, 0x4c, 0x58, 0x41, 0x29, 0x78, 0xa5, 0x28, 0x24, 0x23, 0x21, 0x89, 0x4f, 0x38, 0xd6,
+ 0x34, 0x84, 0x20, 0x93, 0x26, 0xad, 0x95, 0xb6, 0xaa, 0x2a, 0xd9, 0x11, 0x4a, 0xbb, 0x2d, 0x7a,
+ 0x05, 0x90, 0x36, 0x45, 0xb3, 0x15, 0x56, 0xb5, 0x92, 0x0d, 0x43, 0x12, 0x1f, 0xf3, 0x9d, 0x4e,
+ 0x9f, 0x6a, 0xba, 0x32, 0x65, 0x1e, 0x2a, 0x58, 0xd3, 0x19, 0x78, 0xb6, 0xd3, 0x92, 0x8d, 0x42,
+ 0x12, 0x4f, 0xe7, 0x97, 0x89, 0x63, 0x4a, 0xf6, 0x79, 0x92, 0x75, 0xa7, 0x25, 0xc7, 0xc1, 0x88,
+ 0x82, 0xd7, 0xdf, 0x28, 0x80, 0xbf, 0x79, 0x7e, 0x59, 0x2f, 0x17, 0xa7, 0x83, 0xe8, 0x11, 0xa8,
+ 0x33, 0x70, 0x99, 0x2b, 0x63, 0x6b, 0x81, 0x38, 0xb7, 0x30, 0x76, 0xef, 0x35, 0x8c, 0x84, 0xc3,
+ 0x38, 0x98, 0x9f, 0x1f, 0x4e, 0xe7, 0xbf, 0x63, 0xd1, 0x2b, 0x04, 0x4e, 0x7a, 0x13, 0xdb, 0x46,
+ 0x1e, 0xdc, 0xc2, 0x35, 0x4c, 0x1a, 0x55, 0xda, 0xe5, 0xe2, 0xbd, 0xed, 0x67, 0x70, 0x0d, 0xde,
+ 0xd3, 0x80, 0x07, 0xae, 0x8b, 0xc6, 0x87, 0x31, 0x8c, 0x50, 0x8d, 0xee, 0x61, 0xe2, 0x02, 0x37,
+ 0x3a, 0x13, 0x56, 0xd2, 0x9b, 0xff, 0x48, 0x67, 0xfb, 0x48, 0x68, 0xff, 0xe3, 0xf9, 0xf0, 0xf1,
+ 0xa3, 0xee, 0xbe, 0x03, 0x00, 0x00, 0xff, 0xff, 0xcb, 0x7f, 0xcb, 0x46, 0xc3, 0x01, 0x00, 0x00,
+}
diff --git a/pkg/metric/metric_state_autogen.go b/pkg/metric/metric_state_autogen.go
new file mode 100755
index 000000000..985c28832
--- /dev/null
+++ b/pkg/metric/metric_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package metric
+
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go
new file mode 100644
index 000000000..249536d8a
--- /dev/null
+++ b/pkg/p9/buffer.go
@@ -0,0 +1,263 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "encoding/binary"
+)
+
+// encoder is used for messages and 9P primitives.
+type encoder interface {
+ // Decode decodes from the given buffer. Decode may be called more than once
+ // to reuse the instance. It must clear any previous state.
+ //
+ // This may not fail, exhaustion will be recorded in the buffer.
+ Decode(b *buffer)
+
+ // Encode encodes to the given buffer.
+ //
+ // This may not fail.
+ Encode(b *buffer)
+}
+
+// order is the byte order used for encoding.
+var order = binary.LittleEndian
+
+// buffer is a slice that is consumed.
+//
+// This is passed to the encoder methods.
+type buffer struct {
+ // data is the underlying data. This may grow during Encode.
+ data []byte
+
+ // overflow indicates whether an overflow has occurred.
+ overflow bool
+}
+
+// append appends n bytes to the buffer and returns a slice pointing to the
+// newly appended bytes.
+func (b *buffer) append(n int) []byte {
+ b.data = append(b.data, make([]byte, n)...)
+ return b.data[len(b.data)-n:]
+}
+
+// consume consumes n bytes from the buffer.
+func (b *buffer) consume(n int) ([]byte, bool) {
+ if !b.has(n) {
+ b.markOverrun()
+ return nil, false
+ }
+ rval := b.data[:n]
+ b.data = b.data[n:]
+ return rval, true
+}
+
+// has returns true if n bytes are available.
+func (b *buffer) has(n int) bool {
+ return len(b.data) >= n
+}
+
+// markOverrun immediately marks this buffer as overrun.
+//
+// This is used by ReadString, since some invalid data implies the rest of the
+// buffer is no longer valid either.
+func (b *buffer) markOverrun() {
+ b.overflow = true
+}
+
+// isOverrun returns true if this buffer has run past the end.
+func (b *buffer) isOverrun() bool {
+ return b.overflow
+}
+
+// Read8 reads a byte from the buffer.
+func (b *buffer) Read8() uint8 {
+ v, ok := b.consume(1)
+ if !ok {
+ return 0
+ }
+ return uint8(v[0])
+}
+
+// Read16 reads a 16-bit value from the buffer.
+func (b *buffer) Read16() uint16 {
+ v, ok := b.consume(2)
+ if !ok {
+ return 0
+ }
+ return order.Uint16(v)
+}
+
+// Read32 reads a 32-bit value from the buffer.
+func (b *buffer) Read32() uint32 {
+ v, ok := b.consume(4)
+ if !ok {
+ return 0
+ }
+ return order.Uint32(v)
+}
+
+// Read64 reads a 64-bit value from the buffer.
+func (b *buffer) Read64() uint64 {
+ v, ok := b.consume(8)
+ if !ok {
+ return 0
+ }
+ return order.Uint64(v)
+}
+
+// ReadQIDType reads a QIDType value.
+func (b *buffer) ReadQIDType() QIDType {
+ return QIDType(b.Read8())
+}
+
+// ReadTag reads a Tag value.
+func (b *buffer) ReadTag() Tag {
+ return Tag(b.Read16())
+}
+
+// ReadFID reads a FID value.
+func (b *buffer) ReadFID() FID {
+ return FID(b.Read32())
+}
+
+// ReadUID reads a UID value.
+func (b *buffer) ReadUID() UID {
+ return UID(b.Read32())
+}
+
+// ReadGID reads a GID value.
+func (b *buffer) ReadGID() GID {
+ return GID(b.Read32())
+}
+
+// ReadPermissions reads a file mode value and applies the mask for permissions.
+func (b *buffer) ReadPermissions() FileMode {
+ return b.ReadFileMode() & permissionsMask
+}
+
+// ReadFileMode reads a file mode value.
+func (b *buffer) ReadFileMode() FileMode {
+ return FileMode(b.Read32())
+}
+
+// ReadOpenFlags reads an OpenFlags.
+func (b *buffer) ReadOpenFlags() OpenFlags {
+ return OpenFlags(b.Read32())
+}
+
+// ReadConnectFlags reads a ConnectFlags.
+func (b *buffer) ReadConnectFlags() ConnectFlags {
+ return ConnectFlags(b.Read32())
+}
+
+// ReadMsgType writes a MsgType.
+func (b *buffer) ReadMsgType() MsgType {
+ return MsgType(b.Read8())
+}
+
+// ReadString deserializes a string.
+func (b *buffer) ReadString() string {
+ l := b.Read16()
+ if !b.has(int(l)) {
+ // Mark the buffer as corrupted.
+ b.markOverrun()
+ return ""
+ }
+
+ bs := make([]byte, l)
+ for i := 0; i < int(l); i++ {
+ bs[i] = byte(b.Read8())
+ }
+ return string(bs)
+}
+
+// Write8 writes a byte to the buffer.
+func (b *buffer) Write8(v uint8) {
+ b.append(1)[0] = byte(v)
+}
+
+// Write16 writes a 16-bit value to the buffer.
+func (b *buffer) Write16(v uint16) {
+ order.PutUint16(b.append(2), v)
+}
+
+// Write32 writes a 32-bit value to the buffer.
+func (b *buffer) Write32(v uint32) {
+ order.PutUint32(b.append(4), v)
+}
+
+// Write64 writes a 64-bit value to the buffer.
+func (b *buffer) Write64(v uint64) {
+ order.PutUint64(b.append(8), v)
+}
+
+// WriteQIDType writes a QIDType value.
+func (b *buffer) WriteQIDType(qidType QIDType) {
+ b.Write8(uint8(qidType))
+}
+
+// WriteTag writes a Tag value.
+func (b *buffer) WriteTag(tag Tag) {
+ b.Write16(uint16(tag))
+}
+
+// WriteFID writes a FID value.
+func (b *buffer) WriteFID(fid FID) {
+ b.Write32(uint32(fid))
+}
+
+// WriteUID writes a UID value.
+func (b *buffer) WriteUID(uid UID) {
+ b.Write32(uint32(uid))
+}
+
+// WriteGID writes a GID value.
+func (b *buffer) WriteGID(gid GID) {
+ b.Write32(uint32(gid))
+}
+
+// WritePermissions applies a permissions mask and writes the FileMode.
+func (b *buffer) WritePermissions(perm FileMode) {
+ b.WriteFileMode(perm & permissionsMask)
+}
+
+// WriteFileMode writes a FileMode.
+func (b *buffer) WriteFileMode(mode FileMode) {
+ b.Write32(uint32(mode))
+}
+
+// WriteOpenFlags writes an OpenFlags.
+func (b *buffer) WriteOpenFlags(flags OpenFlags) {
+ b.Write32(uint32(flags))
+}
+
+// WriteConnectFlags writes a ConnectFlags.
+func (b *buffer) WriteConnectFlags(flags ConnectFlags) {
+ b.Write32(uint32(flags))
+}
+
+// WriteMsgType writes a MsgType.
+func (b *buffer) WriteMsgType(t MsgType) {
+ b.Write8(uint8(t))
+}
+
+// WriteString serializes the given string.
+func (b *buffer) WriteString(s string) {
+ b.Write16(uint16(len(s)))
+ for i := 0; i < len(s); i++ {
+ b.Write8(byte(s[i]))
+ }
+}
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
new file mode 100644
index 000000000..56587e2cf
--- /dev/null
+++ b/pkg/p9/client.go
@@ -0,0 +1,307 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+// ErrOutOfTags indicates no tags are available.
+var ErrOutOfTags = errors.New("out of tags -- messages lost?")
+
+// ErrOutOfFIDs indicates no more FIDs are available.
+var ErrOutOfFIDs = errors.New("out of FIDs -- messages lost?")
+
+// ErrUnexpectedTag indicates a response with an unexpected tag was received.
+var ErrUnexpectedTag = errors.New("unexpected tag in response")
+
+// ErrVersionsExhausted indicates that all versions to negotiate have been exhausted.
+var ErrVersionsExhausted = errors.New("exhausted all versions to negotiate")
+
+// ErrBadVersionString indicates that the version string is malformed or unsupported.
+var ErrBadVersionString = errors.New("bad version string")
+
+// ErrBadResponse indicates the response didn't match the request.
+type ErrBadResponse struct {
+ Got MsgType
+ Want MsgType
+}
+
+// Error returns a highly descriptive error.
+func (e *ErrBadResponse) Error() string {
+ return fmt.Sprintf("unexpected message type: got %v, want %v", e.Got, e.Want)
+}
+
+// response is the asynchronous return from recv.
+//
+// This is used in the pending map below.
+type response struct {
+ r message
+ done chan error
+}
+
+var responsePool = sync.Pool{
+ New: func() interface{} {
+ return &response{
+ done: make(chan error, 1),
+ }
+ },
+}
+
+// Client is at least a 9P2000.L client.
+type Client struct {
+ // socket is the connected socket.
+ socket *unet.Socket
+
+ // tagPool is the collection of available tags.
+ tagPool pool
+
+ // fidPool is the collection of available fids.
+ fidPool pool
+
+ // pending is the set of pending messages.
+ pending map[Tag]*response
+ pendingMu sync.Mutex
+
+ // sendMu is the lock for sending a request.
+ sendMu sync.Mutex
+
+ // recvr is essentially a mutex for calling recv.
+ //
+ // Whoever writes to this channel is permitted to call recv. When
+ // finished calling recv, this channel should be emptied.
+ recvr chan bool
+
+ // messageSize is the maximum total size of a message.
+ messageSize uint32
+
+ // payloadSize is the maximum payload size of a read or write
+ // request. For large reads and writes this means that the
+ // read or write is broken up into buffer-size/payloadSize
+ // requests.
+ payloadSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+}
+
+// NewClient creates a new client. It performs a Tversion exchange with
+// the server to assert that messageSize is ok to use.
+//
+// You should not use the same socket for multiple clients.
+func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
+ // Need at least one byte of payload.
+ if messageSize <= msgRegistry.largestFixedSize {
+ return nil, &ErrMessageTooLarge{
+ size: messageSize,
+ msize: msgRegistry.largestFixedSize,
+ }
+ }
+
+ // Compute a payload size and round to 512 (normal block size)
+ // if it's larger than a single block.
+ payloadSize := messageSize - msgRegistry.largestFixedSize
+ if payloadSize > 512 && payloadSize%512 != 0 {
+ payloadSize -= (payloadSize % 512)
+ }
+ c := &Client{
+ socket: socket,
+ tagPool: pool{start: 1, limit: uint64(NoTag)},
+ fidPool: pool{start: 1, limit: uint64(NoFID)},
+ pending: make(map[Tag]*response),
+ recvr: make(chan bool, 1),
+ messageSize: messageSize,
+ payloadSize: payloadSize,
+ }
+ // Agree upon a version.
+ requested, ok := parseVersion(version)
+ if !ok {
+ return nil, ErrBadVersionString
+ }
+ for {
+ rversion := Rversion{}
+ err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion)
+
+ // The server told us to try again with a lower version.
+ if err == syscall.EAGAIN {
+ if requested == lowestSupportedVersion {
+ return nil, ErrVersionsExhausted
+ }
+ requested--
+ continue
+ }
+
+ // We requested an impossible version or our other parameters were bogus.
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse the version.
+ version, ok := parseVersion(rversion.Version)
+ if !ok {
+ // The server gave us a bad version. We return a generically worrisome error.
+ log.Warningf("server returned bad version string %q", rversion.Version)
+ return nil, ErrBadVersionString
+ }
+ c.version = version
+ break
+ }
+ return c, nil
+}
+
+// handleOne handles a single incoming message.
+//
+// This should only be called with the token from recvr. Note that the received
+// tag will automatically be cleared from pending.
+func (c *Client) handleOne() {
+ tag, r, err := recv(c.socket, c.messageSize, func(tag Tag, t MsgType) (message, error) {
+ c.pendingMu.Lock()
+ resp := c.pending[tag]
+ c.pendingMu.Unlock()
+
+ // Not expecting this message?
+ if resp == nil {
+ log.Warningf("client received unexpected tag %v, ignoring", tag)
+ return nil, ErrUnexpectedTag
+ }
+
+ // Is it an error? We specifically allow this to
+ // go through, and then we deserialize below.
+ if t == MsgRlerror {
+ return &Rlerror{}, nil
+ }
+
+ // Does it match expectations?
+ if t != resp.r.Type() {
+ return nil, &ErrBadResponse{Got: t, Want: resp.r.Type()}
+ }
+
+ // Return the response.
+ return resp.r, nil
+ })
+
+ if err != nil {
+ // No tag was extracted (probably a socket error).
+ //
+ // Likely catastrophic. Notify all waiters and clear pending.
+ c.pendingMu.Lock()
+ for _, resp := range c.pending {
+ resp.done <- err
+ }
+ c.pending = make(map[Tag]*response)
+ c.pendingMu.Unlock()
+ } else {
+ // Process the tag.
+ //
+ // We know that is is contained in the map because our lookup function
+ // above must have succeeded (found the tag) to return nil err.
+ c.pendingMu.Lock()
+ resp := c.pending[tag]
+ delete(c.pending, tag)
+ c.pendingMu.Unlock()
+ resp.r = r
+ resp.done <- err
+ }
+}
+
+// waitAndRecv co-ordinates with other receivers to handle responses.
+func (c *Client) waitAndRecv(done chan error) error {
+ for {
+ select {
+ case err := <-done:
+ return err
+ case c.recvr <- true:
+ select {
+ case err := <-done:
+ // It's possible that we got the token, despite
+ // done also being available. Check for that.
+ <-c.recvr
+ return err
+ default:
+ // Handle receiving one tag.
+ c.handleOne()
+
+ // Return the token.
+ <-c.recvr
+ }
+ }
+ }
+}
+
+// sendRecv performs a roundtrip message exchange.
+//
+// This is called by internal functions.
+func (c *Client) sendRecv(t message, r message) error {
+ tag, ok := c.tagPool.Get()
+ if !ok {
+ return ErrOutOfTags
+ }
+ defer c.tagPool.Put(tag)
+
+ // Indicate we're expecting a response.
+ //
+ // Note that the tag will be cleared from pending
+ // automatically (see handleOne for details).
+ resp := responsePool.Get().(*response)
+ defer responsePool.Put(resp)
+ resp.r = r
+ c.pendingMu.Lock()
+ c.pending[Tag(tag)] = resp
+ c.pendingMu.Unlock()
+
+ // Send the request over the wire.
+ c.sendMu.Lock()
+ err := send(c.socket, Tag(tag), t)
+ c.sendMu.Unlock()
+ if err != nil {
+ return err
+ }
+
+ // Co-ordinate with other receivers.
+ if err := c.waitAndRecv(resp.done); err != nil {
+ return err
+ }
+
+ // Is it an error message?
+ //
+ // For convenience, we transform these directly
+ // into errors. Handlers need not handle this case.
+ if rlerr, ok := resp.r.(*Rlerror); ok {
+ return syscall.Errno(rlerr.Error)
+ }
+
+ // At this point, we know it matches.
+ //
+ // Per recv call above, we will only allow a type
+ // match (and give our r) or an instance of Rlerror.
+ return nil
+}
+
+// Version returns the negotiated 9P2000.L.Google version number.
+func (c *Client) Version() uint32 {
+ return c.version
+}
+
+// Close closes the underlying socket.
+func (c *Client) Close() error {
+ return c.socket.Close()
+}
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
new file mode 100644
index 000000000..258080f67
--- /dev/null
+++ b/pkg/p9/client_file.go
@@ -0,0 +1,632 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "io"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+)
+
+// Attach attaches to a server.
+//
+// Note that authentication is not currently supported.
+func (c *Client) Attach(name string) (File, error) {
+ fid, ok := c.fidPool.Get()
+ if !ok {
+ return nil, ErrOutOfFIDs
+ }
+
+ rattach := Rattach{}
+ if err := c.sendRecv(&Tattach{FID: FID(fid), Auth: Tauth{AttachName: name, AuthenticationFID: NoFID, UID: NoUID}}, &rattach); err != nil {
+ c.fidPool.Put(fid)
+ return nil, err
+ }
+
+ return c.newFile(FID(fid)), nil
+}
+
+// newFile returns a new client file.
+func (c *Client) newFile(fid FID) *clientFile {
+ cf := &clientFile{
+ client: c,
+ fid: fid,
+ }
+
+ // Make sure the file is closed.
+ runtime.SetFinalizer(cf, (*clientFile).Close)
+
+ return cf
+}
+
+// clientFile is provided to clients.
+//
+// This proxies all of the interfaces found in file.go.
+type clientFile struct {
+ // client is the originating client.
+ client *Client
+
+ // fid is the FID for this file.
+ fid FID
+
+ // closed indicates whether this file has been closed.
+ closed uint32
+}
+
+// Walk implements File.Walk.
+func (c *clientFile) Walk(names []string) ([]QID, File, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, nil, syscall.EBADF
+ }
+
+ fid, ok := c.client.fidPool.Get()
+ if !ok {
+ return nil, nil, ErrOutOfFIDs
+ }
+
+ rwalk := Rwalk{}
+ if err := c.client.sendRecv(&Twalk{FID: c.fid, NewFID: FID(fid), Names: names}, &rwalk); err != nil {
+ c.client.fidPool.Put(fid)
+ return nil, nil, err
+ }
+
+ // Return a new client file.
+ return rwalk.QIDs, c.client.newFile(FID(fid)), nil
+}
+
+// WalkGetAttr implements File.WalkGetAttr.
+func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, Attr, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, nil, AttrMask{}, Attr{}, syscall.EBADF
+ }
+
+ if !versionSupportsTwalkgetattr(c.client.version) {
+ qids, file, err := c.Walk(components)
+ if err != nil {
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ _, valid, attr, err := file.GetAttr(AttrMaskAll())
+ if err != nil {
+ file.Close()
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ return qids, file, valid, attr, nil
+ }
+
+ fid, ok := c.client.fidPool.Get()
+ if !ok {
+ return nil, nil, AttrMask{}, Attr{}, ErrOutOfFIDs
+ }
+
+ rwalkgetattr := Rwalkgetattr{}
+ if err := c.client.sendRecv(&Twalkgetattr{FID: c.fid, NewFID: FID(fid), Names: components}, &rwalkgetattr); err != nil {
+ c.client.fidPool.Put(fid)
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+
+ // Return a new client file.
+ return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil
+}
+
+// StatFS implements File.StatFS.
+func (c *clientFile) StatFS() (FSStat, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return FSStat{}, syscall.EBADF
+ }
+
+ rstatfs := Rstatfs{}
+ if err := c.client.sendRecv(&Tstatfs{FID: c.fid}, &rstatfs); err != nil {
+ return FSStat{}, err
+ }
+
+ return rstatfs.FSStat, nil
+}
+
+// FSync implements File.FSync.
+func (c *clientFile) FSync() error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tfsync{FID: c.fid}, &Rfsync{})
+}
+
+// GetAttr implements File.GetAttr.
+func (c *clientFile) GetAttr(req AttrMask) (QID, AttrMask, Attr, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, AttrMask{}, Attr{}, syscall.EBADF
+ }
+
+ rgetattr := Rgetattr{}
+ if err := c.client.sendRecv(&Tgetattr{FID: c.fid, AttrMask: req}, &rgetattr); err != nil {
+ return QID{}, AttrMask{}, Attr{}, err
+ }
+
+ return rgetattr.QID, rgetattr.Valid, rgetattr.Attr, nil
+}
+
+// SetAttr implements File.SetAttr.
+func (c *clientFile) SetAttr(valid SetAttrMask, attr SetAttr) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tsetattr{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattr{})
+}
+
+// Allocate implements File.Allocate.
+func (c *clientFile) Allocate(mode AllocateMode, offset, length uint64) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsTallocate(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tallocate{FID: c.fid, Mode: mode, Offset: offset, Length: length}, &Rallocate{})
+}
+
+// Remove implements File.Remove.
+//
+// N.B. This method is no longer part of the file interface and should be
+// considered deprecated.
+func (c *clientFile) Remove() error {
+ // Avoid double close.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return syscall.EBADF
+ }
+ runtime.SetFinalizer(c, nil)
+
+ // Send the remove message.
+ if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil {
+ return err
+ }
+
+ // "It is correct to consider remove to be a clunk with the side effect
+ // of removing the file if permissions allow."
+ // https://swtch.com/plan9port/man/man9/remove.html
+
+ // Return the FID to the pool.
+ c.client.fidPool.Put(uint64(c.fid))
+ return nil
+}
+
+// Close implements File.Close.
+func (c *clientFile) Close() error {
+ // Avoid double close.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return syscall.EBADF
+ }
+ runtime.SetFinalizer(c, nil)
+
+ // Send the close message.
+ if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil {
+ // If an error occurred, we toss away the FID. This isn't ideal,
+ // but I'm not sure what else makes sense in this context.
+ log.Warningf("Tclunk failed, losing FID %v: %v", c.fid, err)
+ return err
+ }
+
+ // Return the FID to the pool.
+ c.client.fidPool.Put(uint64(c.fid))
+ return nil
+}
+
+// Open implements File.Open.
+func (c *clientFile) Open(flags OpenFlags) (*fd.FD, QID, uint32, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, QID{}, 0, syscall.EBADF
+ }
+
+ rlopen := Rlopen{}
+ if err := c.client.sendRecv(&Tlopen{FID: c.fid, Flags: flags}, &rlopen); err != nil {
+ return nil, QID{}, 0, err
+ }
+
+ return rlopen.File, rlopen.QID, rlopen.IoUnit, nil
+}
+
+// Connect implements File.Connect.
+func (c *clientFile) Connect(flags ConnectFlags) (*fd.FD, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+
+ if !VersionSupportsConnect(c.client.version) {
+ return nil, syscall.ECONNREFUSED
+ }
+
+ rlconnect := Rlconnect{}
+ if err := c.client.sendRecv(&Tlconnect{FID: c.fid, Flags: flags}, &rlconnect); err != nil {
+ return nil, err
+ }
+
+ return rlconnect.File, nil
+}
+
+// chunk applies fn to p in chunkSize-sized chunks until fn returns a partial result, p is
+// exhausted, or an error is encountered (which may be io.EOF).
+func chunk(chunkSize uint32, fn func([]byte, uint64) (int, error), p []byte, offset uint64) (int, error) {
+ // Some p9.Clients depend on executing fn on zero-byte buffers. Handle this
+ // as a special case (normally it is fine to short-circuit and return (0, nil)).
+ if len(p) == 0 {
+ return fn(p, offset)
+ }
+
+ // total is the cumulative bytes processed.
+ var total int
+ for {
+ var n int
+ var err error
+
+ // We're done, don't bother trying to do anything more.
+ if total == len(p) {
+ return total, nil
+ }
+
+ // Apply fn to a chunkSize-sized (or less) chunk of p.
+ if len(p) < total+int(chunkSize) {
+ n, err = fn(p[total:], offset)
+ } else {
+ n, err = fn(p[total:total+int(chunkSize)], offset)
+ }
+ total += n
+ offset += uint64(n)
+
+ // Return whatever we have processed if we encounter an error. This error
+ // could be io.EOF.
+ if err != nil {
+ return total, err
+ }
+
+ // Did we get a partial result? If so, return it immediately.
+ if n < int(chunkSize) {
+ return total, nil
+ }
+
+ // If we received more bytes than we ever requested, this is a problem.
+ if total > len(p) {
+ panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", total, len(p)))
+ }
+ }
+}
+
+// ReadAt proxies File.ReadAt.
+func (c *clientFile) ReadAt(p []byte, offset uint64) (int, error) {
+ return chunk(c.client.payloadSize, c.readAt, p, offset)
+}
+
+func (c *clientFile) readAt(p []byte, offset uint64) (int, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return 0, syscall.EBADF
+ }
+
+ rread := Rread{Data: p}
+ if err := c.client.sendRecv(&Tread{FID: c.fid, Offset: offset, Count: uint32(len(p))}, &rread); err != nil {
+ return 0, err
+ }
+
+ // The message may have been truncated, or for some reason a new buffer
+ // allocated. This isn't the common path, but we make sure that if the
+ // payload has changed we copy it. See transport.go for more information.
+ if len(p) > 0 && len(rread.Data) > 0 && &rread.Data[0] != &p[0] {
+ copy(p, rread.Data)
+ }
+
+ // io.EOF is not an error that a p9 server can return. Use POSIX semantics to
+ // return io.EOF manually: zero bytes were returned and a non-zero buffer was used.
+ if len(rread.Data) == 0 && len(p) > 0 {
+ return 0, io.EOF
+ }
+
+ return len(rread.Data), nil
+}
+
+// WriteAt proxies File.WriteAt.
+func (c *clientFile) WriteAt(p []byte, offset uint64) (int, error) {
+ return chunk(c.client.payloadSize, c.writeAt, p, offset)
+}
+
+func (c *clientFile) writeAt(p []byte, offset uint64) (int, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return 0, syscall.EBADF
+ }
+
+ rwrite := Rwrite{}
+ if err := c.client.sendRecv(&Twrite{FID: c.fid, Offset: offset, Data: p}, &rwrite); err != nil {
+ return 0, err
+ }
+
+ return int(rwrite.Count), nil
+}
+
+// ReadWriterFile wraps a File and implements io.ReadWriter, io.ReaderAt, and io.WriterAt.
+type ReadWriterFile struct {
+ File File
+ Offset uint64
+}
+
+// Read implements part of the io.ReadWriter interface.
+func (r *ReadWriterFile) Read(p []byte) (int, error) {
+ n, err := r.File.ReadAt(p, r.Offset)
+ r.Offset += uint64(n)
+ if err != nil {
+ return n, err
+ }
+ if n == 0 && len(p) > 0 {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// ReadAt implements the io.ReaderAt interface.
+func (r *ReadWriterFile) ReadAt(p []byte, offset int64) (int, error) {
+ n, err := r.File.ReadAt(p, uint64(offset))
+ if err != nil {
+ return 0, err
+ }
+ if n == 0 && len(p) > 0 {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// Write implements part of the io.ReadWriter interface.
+func (r *ReadWriterFile) Write(p []byte) (int, error) {
+ n, err := r.File.WriteAt(p, r.Offset)
+ r.Offset += uint64(n)
+ if err != nil {
+ return n, err
+ }
+ if n < len(p) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
+// WriteAt implements the io.WriteAt interface.
+func (r *ReadWriterFile) WriteAt(p []byte, offset int64) (int, error) {
+ n, err := r.File.WriteAt(p, uint64(offset))
+ if err != nil {
+ return n, err
+ }
+ if n < len(p) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
+// Rename implements File.Rename.
+func (c *clientFile) Rename(dir File, name string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ clientDir, ok := dir.(*clientFile)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Trename{FID: c.fid, Directory: clientDir.fid, Name: name}, &Rrename{})
+}
+
+// Create implements File.Create.
+func (c *clientFile) Create(name string, openFlags OpenFlags, permissions FileMode, uid UID, gid GID) (*fd.FD, File, QID, uint32, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, nil, QID{}, 0, syscall.EBADF
+ }
+
+ msg := Tlcreate{
+ FID: c.fid,
+ Name: name,
+ OpenFlags: openFlags,
+ Permissions: permissions,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rucreate := Rucreate{}
+ if err := c.client.sendRecv(&Tucreate{Tlcreate: msg, UID: uid}, &rucreate); err != nil {
+ return nil, nil, QID{}, 0, err
+ }
+ return rucreate.File, c, rucreate.QID, rucreate.IoUnit, nil
+ }
+
+ rlcreate := Rlcreate{}
+ if err := c.client.sendRecv(&msg, &rlcreate); err != nil {
+ return nil, nil, QID{}, 0, err
+ }
+
+ return rlcreate.File, c, rlcreate.QID, rlcreate.IoUnit, nil
+}
+
+// Mkdir implements File.Mkdir.
+func (c *clientFile) Mkdir(name string, permissions FileMode, uid UID, gid GID) (QID, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, syscall.EBADF
+ }
+
+ msg := Tmkdir{
+ Directory: c.fid,
+ Name: name,
+ Permissions: permissions,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rumkdir := Rumkdir{}
+ if err := c.client.sendRecv(&Tumkdir{Tmkdir: msg, UID: uid}, &rumkdir); err != nil {
+ return QID{}, err
+ }
+ return rumkdir.QID, nil
+ }
+
+ rmkdir := Rmkdir{}
+ if err := c.client.sendRecv(&msg, &rmkdir); err != nil {
+ return QID{}, err
+ }
+
+ return rmkdir.QID, nil
+}
+
+// Symlink implements File.Symlink.
+func (c *clientFile) Symlink(oldname string, newname string, uid UID, gid GID) (QID, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, syscall.EBADF
+ }
+
+ msg := Tsymlink{
+ Directory: c.fid,
+ Name: newname,
+ Target: oldname,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rusymlink := Rusymlink{}
+ if err := c.client.sendRecv(&Tusymlink{Tsymlink: msg, UID: uid}, &rusymlink); err != nil {
+ return QID{}, err
+ }
+ return rusymlink.QID, nil
+ }
+
+ rsymlink := Rsymlink{}
+ if err := c.client.sendRecv(&msg, &rsymlink); err != nil {
+ return QID{}, err
+ }
+
+ return rsymlink.QID, nil
+}
+
+// Link implements File.Link.
+func (c *clientFile) Link(target File, newname string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ targetFile, ok := target.(*clientFile)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tlink{Directory: c.fid, Name: newname, Target: targetFile.fid}, &Rlink{})
+}
+
+// Mknod implements File.Mknod.
+func (c *clientFile) Mknod(name string, mode FileMode, major uint32, minor uint32, uid UID, gid GID) (QID, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return QID{}, syscall.EBADF
+ }
+
+ msg := Tmknod{
+ Directory: c.fid,
+ Name: name,
+ Mode: mode,
+ Major: major,
+ Minor: minor,
+ GID: NoGID,
+ }
+
+ if versionSupportsTucreation(c.client.version) {
+ msg.GID = gid
+ rumknod := Rumknod{}
+ if err := c.client.sendRecv(&Tumknod{Tmknod: msg, UID: uid}, &rumknod); err != nil {
+ return QID{}, err
+ }
+ return rumknod.QID, nil
+ }
+
+ rmknod := Rmknod{}
+ if err := c.client.sendRecv(&msg, &rmknod); err != nil {
+ return QID{}, err
+ }
+
+ return rmknod.QID, nil
+}
+
+// RenameAt implements File.RenameAt.
+func (c *clientFile) RenameAt(oldname string, newdir File, newname string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ clientNewDir, ok := newdir.(*clientFile)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Trenameat{OldDirectory: c.fid, OldName: oldname, NewDirectory: clientNewDir.fid, NewName: newname}, &Rrenameat{})
+}
+
+// UnlinkAt implements File.UnlinkAt.
+func (c *clientFile) UnlinkAt(name string, flags uint32) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ return c.client.sendRecv(&Tunlinkat{Directory: c.fid, Name: name, Flags: flags}, &Runlinkat{})
+}
+
+// Readdir implements File.Readdir.
+func (c *clientFile) Readdir(offset uint64, count uint32) ([]Dirent, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+
+ rreaddir := Rreaddir{}
+ if err := c.client.sendRecv(&Treaddir{Directory: c.fid, Offset: offset, Count: count}, &rreaddir); err != nil {
+ return nil, err
+ }
+
+ return rreaddir.Entries, nil
+}
+
+// Readlink implements File.Readlink.
+func (c *clientFile) Readlink() (string, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return "", syscall.EBADF
+ }
+
+ rreadlink := Rreadlink{}
+ if err := c.client.sendRecv(&Treadlink{FID: c.fid}, &rreadlink); err != nil {
+ return "", err
+ }
+
+ return rreadlink.Target, nil
+}
+
+// Flush implements File.Flush.
+func (c *clientFile) Flush() error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+
+ if !VersionSupportsTflushf(c.client.version) {
+ return nil
+ }
+
+ return c.client.sendRecv(&Tflushf{FID: c.fid}, &Rflushf{})
+}
+
+// Renamed implements File.Renamed.
+func (c *clientFile) Renamed(newDir File, newName string) {}
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
new file mode 100644
index 000000000..a456e8b3d
--- /dev/null
+++ b/pkg/p9/file.go
@@ -0,0 +1,256 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+)
+
+// Attacher is provided by the server.
+type Attacher interface {
+ // Attach returns a new File.
+ //
+ // The client-side attach will be translate to a series of walks from
+ // the file returned by this Attach call.
+ Attach() (File, error)
+}
+
+// File is a set of operations corresponding to a single node.
+//
+// Note that on the server side, the server logic places constraints on
+// concurrent operations to make things easier. This may reduce the need for
+// complex, error-prone locking and logic in the backend. These are documented
+// for each method.
+//
+// There are three different types of guarantees provided:
+//
+// none: There is no concurrency guarantee. The method may be invoked
+// concurrently with any other method on any other file.
+//
+// read: The method is guaranteed to be exclusive of any write or global
+// operation that is mutating the state of the directory tree starting at this
+// node. For example, this means creating new files, symlinks, directories or
+// renaming a directory entry (or renaming in to this target), but the method
+// may be called concurrently with other read methods.
+//
+// write: The method is guaranteed to be exclusive of any read, write or global
+// operation that is mutating the state of the directory tree starting at this
+// node, as described in read above. There may however, be other write
+// operations executing concurrently on other components in the directory tree.
+//
+// global: The method is guaranteed to be exclusive of any read, write or
+// global operation.
+type File interface {
+ // Walk walks to the path components given in names.
+ //
+ // Walk returns QIDs in the same order that the names were passed in.
+ //
+ // An empty list of arguments should return a copy of the current file.
+ //
+ // On the server, Walk has a read concurrency guarantee.
+ Walk(names []string) ([]QID, File, error)
+
+ // WalkGetAttr walks to the next file and returns its maximal set of
+ // attributes.
+ //
+ // Server-side p9.Files may return syscall.ENOSYS to indicate that Walk
+ // and GetAttr should be used separately to satisfy this request.
+ //
+ // On the server, WalkGetAttr has a read concurrency guarantee.
+ WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error)
+
+ // StatFS returns information about the file system associated with
+ // this file.
+ //
+ // On the server, StatFS has no concurrency guarantee.
+ StatFS() (FSStat, error)
+
+ // GetAttr returns attributes of this node.
+ //
+ // On the server, GetAttr has a read concurrency guarantee.
+ GetAttr(req AttrMask) (QID, AttrMask, Attr, error)
+
+ // SetAttr sets attributes on this node.
+ //
+ // On the server, SetAttr has a write concurrency guarantee.
+ SetAttr(valid SetAttrMask, attr SetAttr) error
+
+ // Allocate allows the caller to directly manipulate the allocated disk space
+ // for the file. See fallocate(2) for more details.
+ Allocate(mode AllocateMode, offset, length uint64) error
+
+ // Close is called when all references are dropped on the server side,
+ // and Close should be called by the client to drop all references.
+ //
+ // For server-side implementations of Close, the error is ignored.
+ //
+ // Close must be called even when Open has not been called.
+ //
+ // On the server, Close has no concurrency guarantee.
+ Close() error
+
+ // Open must be called prior to using Read, Write or Readdir. Once Open
+ // is called, some operations, such as Walk, will no longer work.
+ //
+ // On the client, Open should be called only once. The fd return is
+ // optional, and may be nil.
+ //
+ // On the server, Open has a read concurrency guarantee. If an *fd.FD
+ // is provided, ownership now belongs to the caller. Open is guaranteed
+ // to be called only once.
+ //
+ // N.B. The server must resolve any lazy paths when open is called.
+ // After this point, read and write may be called on files with no
+ // deletion check, so resolving in the data path is not viable.
+ Open(mode OpenFlags) (*fd.FD, QID, uint32, error)
+
+ // Read reads from this file. Open must be called first.
+ //
+ // This may return io.EOF in addition to syscall.Errno values.
+ //
+ // On the server, ReadAt has a read concurrency guarantee. See Open for
+ // additional requirements regarding lazy path resolution.
+ ReadAt(p []byte, offset uint64) (int, error)
+
+ // Write writes to this file. Open must be called first.
+ //
+ // This may return io.EOF in addition to syscall.Errno values.
+ //
+ // On the server, WriteAt has a read concurrency guarantee. See Open
+ // for additional requirements regarding lazy path resolution.
+ WriteAt(p []byte, offset uint64) (int, error)
+
+ // FSync syncs this node. Open must be called first.
+ //
+ // On the server, FSync has a read concurrency guarantee.
+ FSync() error
+
+ // Create creates a new regular file and opens it according to the
+ // flags given. This file is already Open.
+ //
+ // N.B. On the client, the returned file is a reference to the current
+ // file, which now represents the created file. This is not the case on
+ // the server. These semantics are very subtle and can easily lead to
+ // bugs, but are a consequence of the 9P create operation.
+ //
+ // See p9.File.Open for a description of *fd.FD.
+ //
+ // On the server, Create has a write concurrency guarantee.
+ Create(name string, flags OpenFlags, permissions FileMode, uid UID, gid GID) (*fd.FD, File, QID, uint32, error)
+
+ // Mkdir creates a subdirectory.
+ //
+ // On the server, Mkdir has a write concurrency guarantee.
+ Mkdir(name string, permissions FileMode, uid UID, gid GID) (QID, error)
+
+ // Symlink makes a new symbolic link.
+ //
+ // On the server, Symlink has a write concurrency guarantee.
+ Symlink(oldName string, newName string, uid UID, gid GID) (QID, error)
+
+ // Link makes a new hard link.
+ //
+ // On the server, Link has a write concurrency guarantee.
+ Link(target File, newName string) error
+
+ // Mknod makes a new device node.
+ //
+ // On the server, Mknod has a write concurrency guarantee.
+ Mknod(name string, mode FileMode, major uint32, minor uint32, uid UID, gid GID) (QID, error)
+
+ // Rename renames the file.
+ //
+ // Rename will never be called on the server, and RenameAt will always
+ // be used instead.
+ Rename(newDir File, newName string) error
+
+ // RenameAt renames a given file to a new name in a potentially new
+ // directory.
+ //
+ // oldName must be a name relative to this file, which must be a
+ // directory. newName is a name relative to newDir.
+ //
+ // On the server, RenameAt has a global concurrency guarantee.
+ RenameAt(oldName string, newDir File, newName string) error
+
+ // UnlinkAt the given named file.
+ //
+ // name must be a file relative to this directory.
+ //
+ // Flags are implementation-specific (e.g. O_DIRECTORY), but are
+ // generally Linux unlinkat(2) flags.
+ //
+ // On the server, UnlinkAt has a write concurrency guarantee.
+ UnlinkAt(name string, flags uint32) error
+
+ // Readdir reads directory entries.
+ //
+ // This may return io.EOF in addition to syscall.Errno values.
+ //
+ // On the server, Readdir has a read concurrency guarantee.
+ Readdir(offset uint64, count uint32) ([]Dirent, error)
+
+ // Readlink reads the link target.
+ //
+ // On the server, Readlink has a read concurrency guarantee.
+ Readlink() (string, error)
+
+ // Flush is called prior to Close.
+ //
+ // Whereas Close drops all references to the file, Flush cleans up the
+ // file state. Behavior is implementation-specific.
+ //
+ // Flush is not related to flush(9p). Flush is an extension to 9P2000.L,
+ // see version.go.
+ //
+ // On the server, Flush has a read concurrency guarantee.
+ Flush() error
+
+ // Connect establishes a new host-socket backed connection with a
+ // socket. A File does not need to be opened before it can be connected
+ // and it can be connected to multiple times resulting in a unique
+ // *fd.FD each time. In addition, the lifetime of the *fd.FD is
+ // independent from the lifetime of the p9.File and must be managed by
+ // the caller.
+ //
+ // The returned FD must be non-blocking.
+ //
+ // Flags indicates the requested type of socket.
+ //
+ // On the server, Connect has a read concurrency guarantee.
+ Connect(flags ConnectFlags) (*fd.FD, error)
+
+ // Renamed is called when this node is renamed.
+ //
+ // This may not fail. The file will hold a reference to its parent
+ // within the p9 package, and is therefore safe to use for the lifetime
+ // of this File (until Close is called).
+ //
+ // This method should not be called by clients, who should use the
+ // relevant Rename methods. (Although the method will be a no-op.)
+ //
+ // On the server, Renamed has a global concurrency guarantee.
+ Renamed(newDir File, newName string)
+}
+
+// DefaultWalkGetAttr implements File.WalkGetAttr to return ENOSYS for server-side Files.
+type DefaultWalkGetAttr struct{}
+
+// WalkGetAttr implements File.WalkGetAttr.
+func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) {
+ return nil, nil, AttrMask{}, Attr{}, syscall.ENOSYS
+}
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
new file mode 100644
index 000000000..f32368763
--- /dev/null
+++ b/pkg/p9/handlers.go
@@ -0,0 +1,1291 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "path"
+ "strings"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+)
+
+// ExtractErrno extracts a syscall.Errno from a error, best effort.
+func ExtractErrno(err error) syscall.Errno {
+ switch err {
+ case os.ErrNotExist:
+ return syscall.ENOENT
+ case os.ErrExist:
+ return syscall.EEXIST
+ case os.ErrPermission:
+ return syscall.EACCES
+ case os.ErrInvalid:
+ return syscall.EINVAL
+ }
+
+ // Attempt to unwrap.
+ switch e := err.(type) {
+ case syscall.Errno:
+ return e
+ case *os.PathError:
+ return ExtractErrno(e.Err)
+ case *os.SyscallError:
+ return ExtractErrno(e.Err)
+ }
+
+ // Default case.
+ log.Warningf("unknown error: %v", err)
+ return syscall.EIO
+}
+
+// newErr returns a new error message from an error.
+func newErr(err error) *Rlerror {
+ return &Rlerror{Error: uint32(ExtractErrno(err))}
+}
+
+// handler is implemented for server-handled messages.
+//
+// See server.go for call information.
+type handler interface {
+ // Handle handles the given message.
+ //
+ // This may modify the server state. The handle function must return a
+ // message which will be sent back to the client. It may be useful to
+ // use newErr to automatically extract an error message.
+ handle(cs *connState) message
+}
+
+// handle implements handler.handle.
+func (t *Tversion) handle(cs *connState) message {
+ if t.MSize == 0 {
+ return newErr(syscall.EINVAL)
+ }
+ if t.MSize > maximumLength {
+ return newErr(syscall.EINVAL)
+ }
+ atomic.StoreUint32(&cs.messageSize, t.MSize)
+ requested, ok := parseVersion(t.Version)
+ if !ok {
+ return newErr(syscall.EINVAL)
+ }
+ // The server cannot support newer versions that it doesn't know about. In this
+ // case we return EAGAIN to tell the client to try again with a lower version.
+ if requested > highestSupportedVersion {
+ return newErr(syscall.EAGAIN)
+ }
+ // From Tversion(9P): "The server may respond with the client’s version
+ // string, or a version string identifying an earlier defined protocol version".
+ atomic.StoreUint32(&cs.version, requested)
+ return &Rversion{
+ MSize: t.MSize,
+ Version: t.Version,
+ }
+}
+
+// handle implements handler.handle.
+func (t *Tflush) handle(cs *connState) message {
+ cs.WaitTag(t.OldTag)
+ return &Rflush{}
+}
+
+// checkSafeName validates the name and returns nil or returns an error.
+func checkSafeName(name string) error {
+ if name != "" && !strings.Contains(name, "/") && name != "." && name != ".." {
+ return nil
+ }
+ return syscall.EINVAL
+}
+
+// handle implements handler.handle.
+func (t *Tclunk) handle(cs *connState) message {
+ if !cs.DeleteFID(t.FID) {
+ return newErr(syscall.EBADF)
+ }
+ return &Rclunk{}
+}
+
+// handle implements handler.handle.
+func (t *Tremove) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Frustratingly, because we can't be guaranteed that a rename is not
+ // occurring simultaneously with this removal, we need to acquire the
+ // global rename lock for this kind of remove operation to ensure that
+ // ref.parent does not change out from underneath us.
+ //
+ // This is why Tremove is a bad idea, and clients should generally use
+ // Tunlinkat. All p9 clients will use Tunlinkat.
+ err := ref.safelyGlobal(func() error {
+ // Is this a root? Can't remove that.
+ if ref.isRoot() {
+ return syscall.EINVAL
+ }
+
+ // N.B. this remove operation is permitted, even if the file is open.
+ // See also rename below for reasoning.
+
+ // Is this file already deleted?
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ // Retrieve the file's proper name.
+ name := ref.parent.pathNode.nameFor(ref)
+
+ // Attempt the removal.
+ if err := ref.parent.file.UnlinkAt(name, 0); err != nil {
+ return err
+ }
+
+ // Mark all relevant fids as deleted. We don't need to lock any
+ // individual nodes because we already hold the global lock.
+ ref.parent.markChildDeleted(name)
+ return nil
+ })
+
+ // "The remove request asks the file server both to remove the file
+ // represented by fid and to clunk the fid, even if the remove fails."
+ //
+ // "It is correct to consider remove to be a clunk with the side effect
+ // of removing the file if permissions allow."
+ // https://swtch.com/plan9port/man/man9/remove.html
+ if !cs.DeleteFID(t.FID) {
+ return newErr(syscall.EBADF)
+ }
+ if err != nil {
+ return newErr(err)
+ }
+
+ return &Rremove{}
+}
+
+// handle implements handler.handle.
+//
+// We don't support authentication, so this just returns ENOSYS.
+func (t *Tauth) handle(cs *connState) message {
+ return newErr(syscall.ENOSYS)
+}
+
+// handle implements handler.handle.
+func (t *Tattach) handle(cs *connState) message {
+ // Ensure no authentication FID is provided.
+ if t.Auth.AuthenticationFID != NoFID {
+ return newErr(syscall.EINVAL)
+ }
+
+ // Must provide an absolute path.
+ if path.IsAbs(t.Auth.AttachName) {
+ // Trim off the leading / if the path is absolute. We always
+ // treat attach paths as absolute and call attach with the root
+ // argument on the server file for clarity.
+ t.Auth.AttachName = t.Auth.AttachName[1:]
+ }
+
+ // Do the attach on the root.
+ sf, err := cs.server.attacher.Attach()
+ if err != nil {
+ return newErr(err)
+ }
+ qid, valid, attr, err := sf.GetAttr(AttrMaskAll())
+ if err != nil {
+ sf.Close() // Drop file.
+ return newErr(err)
+ }
+ if !valid.Mode {
+ sf.Close() // Drop file.
+ return newErr(syscall.EINVAL)
+ }
+
+ // Build a transient reference.
+ root := &fidRef{
+ server: cs.server,
+ parent: nil,
+ file: sf,
+ refs: 1,
+ mode: attr.Mode.FileType(),
+ pathNode: &cs.server.pathTree,
+ }
+ defer root.DecRef()
+
+ // Attach the root?
+ if len(t.Auth.AttachName) == 0 {
+ cs.InsertFID(t.FID, root)
+ return &Rattach{QID: qid}
+ }
+
+ // We want the same traversal checks to apply on attach, so always
+ // attach at the root and use the regular walk paths.
+ names := strings.Split(t.Auth.AttachName, "/")
+ _, newRef, _, _, err := doWalk(cs, root, names, false)
+ if err != nil {
+ return newErr(err)
+ }
+ defer newRef.DecRef()
+
+ // Insert the FID.
+ cs.InsertFID(t.FID, newRef)
+ return &Rattach{QID: qid}
+}
+
+// CanOpen returns whether this file open can be opened, read and written to.
+//
+// This includes everything except symlinks and sockets.
+func CanOpen(mode FileMode) bool {
+ return mode.IsRegular() || mode.IsDir() || mode.IsNamedPipe() || mode.IsBlockDevice() || mode.IsCharacterDevice()
+}
+
+// handle implements handler.handle.
+func (t *Tlopen) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ ref.openedMu.Lock()
+ defer ref.openedMu.Unlock()
+
+ // Has it been opened already?
+ if ref.opened || !CanOpen(ref.mode) {
+ return newErr(syscall.EINVAL)
+ }
+
+ // Are flags valid?
+ flags := t.Flags &^ OpenFlagsIgnoreMask
+ if flags&^OpenFlagsModeMask != 0 {
+ return newErr(syscall.EINVAL)
+ }
+
+ // Is this an attempt to open a directory as writable? Don't accept.
+ if ref.mode.IsDir() && flags != ReadOnly {
+ return newErr(syscall.EINVAL)
+ }
+
+ var (
+ qid QID
+ ioUnit uint32
+ osFile *fd.FD
+ )
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been deleted already?
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ // Do the open.
+ osFile, qid, ioUnit, err = ref.file.Open(t.Flags)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ // Mark file as opened and set open mode.
+ ref.opened = true
+ ref.openFlags = t.Flags
+
+ return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}
+}
+
+func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
+ // Don't allow complex names.
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var (
+ osFile *fd.FD
+ nsf File
+ qid QID
+ ioUnit uint32
+ newRef *fidRef
+ )
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow creation from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the create.
+ osFile, nsf, qid, ioUnit, err = ref.file.Create(t.Name, t.OpenFlags, t.Permissions, uid, t.GID)
+ if err != nil {
+ return err
+ }
+
+ newRef = &fidRef{
+ server: cs.server,
+ parent: ref,
+ file: nsf,
+ opened: true,
+ openFlags: t.OpenFlags,
+ mode: ModeRegular,
+ pathNode: ref.pathNode.pathNodeFor(t.Name),
+ }
+ ref.pathNode.addChild(newRef, t.Name)
+ ref.IncRef() // Acquire parent reference.
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+
+ // Replace the FID reference.
+ cs.InsertFID(t.FID, newRef)
+
+ return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tlcreate) handle(cs *connState) message {
+ rlcreate, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rlcreate
+}
+
+// handle implements handler.handle.
+func (t *Tsymlink) handle(cs *connState) message {
+ rsymlink, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rsymlink
+}
+
+func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
+ // Don't allow complex names.
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var qid QID
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow symlinks from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the symlink.
+ qid, err = ref.file.Symlink(t.Target, t.Name, uid, t.GID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ return &Rsymlink{QID: qid}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tlink) handle(cs *connState) message {
+ // Don't allow complex names.
+ if err := checkSafeName(t.Name); err != nil {
+ return newErr(err)
+ }
+
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Lookup the other FID.
+ refTarget, ok := cs.LookupFID(t.Target)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer refTarget.DecRef()
+
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow create links from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the link.
+ return ref.file.Link(refTarget.file, t.Name)
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rlink{}
+}
+
+// handle implements handler.handle.
+func (t *Trenameat) handle(cs *connState) message {
+ // Don't allow complex names.
+ if err := checkSafeName(t.OldName); err != nil {
+ return newErr(err)
+ }
+ if err := checkSafeName(t.NewName); err != nil {
+ return newErr(err)
+ }
+
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.OldDirectory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Lookup the other FID.
+ refTarget, ok := cs.LookupFID(t.NewDirectory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer refTarget.DecRef()
+
+ // Perform the rename holding the global lock.
+ if err := ref.safelyGlobal(func() (err error) {
+ // Don't allow renaming across deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() || refTarget.isDeleted() || !refTarget.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Is this the same file? If yes, short-circuit and return success.
+ if ref.pathNode == refTarget.pathNode && t.OldName == t.NewName {
+ return nil
+ }
+
+ // Attempt the actual rename.
+ if err := ref.file.RenameAt(t.OldName, refTarget.file, t.NewName); err != nil {
+ return err
+ }
+
+ // Update the path tree.
+ ref.renameChildTo(t.OldName, refTarget, t.NewName)
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rrenameat{}
+}
+
+// handle implements handler.handle.
+func (t *Tunlinkat) handle(cs *connState) message {
+ // Don't allow complex names.
+ if err := checkSafeName(t.Name); err != nil {
+ return newErr(err)
+ }
+
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow deletion from non-directories or deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Before we do the unlink itself, we need to ensure that there
+ // are no operations in flight on associated path node. The
+ // child's path node lock must be held to ensure that the
+ // unlink at marking the child deleted below is atomic with
+ // respect to any other read or write operations.
+ //
+ // This is one case where we have a lock ordering issue, but
+ // since we always acquire deeper in the hierarchy, we know
+ // that we are free of lock cycles.
+ childPathNode := ref.pathNode.pathNodeFor(t.Name)
+ childPathNode.mu.Lock()
+ defer childPathNode.mu.Unlock()
+
+ // Do the unlink.
+ err = ref.file.UnlinkAt(t.Name, t.Flags)
+ if err != nil {
+ return err
+ }
+
+ // Mark the path as deleted.
+ ref.markChildDeleted(t.Name)
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Runlinkat{}
+}
+
+// handle implements handler.handle.
+func (t *Trename) handle(cs *connState) message {
+ // Don't allow complex names.
+ if err := checkSafeName(t.Name); err != nil {
+ return newErr(err)
+ }
+
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Lookup the target.
+ refTarget, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer refTarget.DecRef()
+
+ if err := ref.safelyGlobal(func() (err error) {
+ // Don't allow a root rename.
+ if ref.isRoot() {
+ return syscall.EINVAL
+ }
+
+ // Don't allow renaming deleting entries, or target non-directories.
+ if ref.isDeleted() || refTarget.isDeleted() || !refTarget.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // If the parent is deleted, but we not, something is seriously wrong.
+ // It's fail to die at this point with an assertion failure.
+ if ref.parent.isDeleted() {
+ panic(fmt.Sprintf("parent %+v deleted, child %+v is not", ref.parent, ref))
+ }
+
+ // N.B. The rename operation is allowed to proceed on open files. It
+ // does impact the state of its parent, but this is merely a sanity
+ // check in any case, and the operation is safe. There may be other
+ // files corresponding to the same path that are renamed anyways.
+
+ // Check for the exact same file and short-circuit.
+ oldName := ref.parent.pathNode.nameFor(ref)
+ if ref.parent.pathNode == refTarget.pathNode && oldName == t.Name {
+ return nil
+ }
+
+ // Call the rename method on the parent.
+ if err := ref.parent.file.RenameAt(oldName, refTarget.file, t.Name); err != nil {
+ return err
+ }
+
+ // Update the path tree.
+ ref.parent.renameChildTo(oldName, refTarget, t.Name)
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rrename{}
+}
+
+// handle implements handler.handle.
+func (t *Treadlink) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var target string
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow readlink on deleted files. There is no need to
+ // check if this file is opened because symlinks cannot be
+ // opened.
+ if ref.isDeleted() || !ref.mode.IsSymlink() {
+ return syscall.EINVAL
+ }
+
+ // Do the read.
+ target, err = ref.file.Readlink()
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rreadlink{target}
+}
+
+// handle implements handler.handle.
+func (t *Tread) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Constrain the size of the read buffer.
+ if int(t.Count) > int(maximumLength) {
+ return newErr(syscall.ENOBUFS)
+ }
+
+ var (
+ data = make([]byte, t.Count)
+ n int
+ )
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been opened already?
+ openFlags, opened := ref.OpenFlags()
+ if !opened {
+ return syscall.EINVAL
+ }
+
+ // Can it be read? Check permissions.
+ if openFlags&OpenFlagsModeMask == WriteOnly {
+ return syscall.EPERM
+ }
+
+ n, err = ref.file.ReadAt(data, t.Offset)
+ return err
+ }); err != nil && err != io.EOF {
+ return newErr(err)
+ }
+
+ return &Rread{Data: data[:n]}
+}
+
+// handle implements handler.handle.
+func (t *Twrite) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var n int
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been opened already?
+ openFlags, opened := ref.OpenFlags()
+ if !opened {
+ return syscall.EINVAL
+ }
+
+ // Can it be written? Check permissions.
+ if openFlags&OpenFlagsModeMask == ReadOnly {
+ return syscall.EPERM
+ }
+
+ n, err = ref.file.WriteAt(t.Data, t.Offset)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rwrite{Count: uint32(n)}
+}
+
+// handle implements handler.handle.
+func (t *Tmknod) handle(cs *connState) message {
+ rmknod, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rmknod
+}
+
+func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) {
+ // Don't allow complex names.
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var qid QID
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow mknod on deleted files.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the mknod.
+ qid, err = ref.file.Mknod(t.Name, t.Mode, t.Major, t.Minor, uid, t.GID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ return &Rmknod{QID: qid}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tmkdir) handle(cs *connState) message {
+ rmkdir, err := t.do(cs, NoUID)
+ if err != nil {
+ return newErr(err)
+ }
+ return rmkdir
+}
+
+func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
+ // Don't allow complex names.
+ if err := checkSafeName(t.Name); err != nil {
+ return nil, err
+ }
+
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer ref.DecRef()
+
+ var qid QID
+ if err := ref.safelyWrite(func() (err error) {
+ // Don't allow mkdir on deleted files.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Not allowed on open directories.
+ if _, opened := ref.OpenFlags(); opened {
+ return syscall.EINVAL
+ }
+
+ // Do the mkdir.
+ qid, err = ref.file.Mkdir(t.Name, t.Permissions, uid, t.GID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+
+ return &Rmkdir{QID: qid}, nil
+}
+
+// handle implements handler.handle.
+func (t *Tgetattr) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // We allow getattr on deleted files. Depending on the backing
+ // implementation, it's possible that races exist that might allow
+ // fetching attributes of other files. But we need to generally allow
+ // refreshing attributes and this is a minor leak, if at all.
+
+ var (
+ qid QID
+ valid AttrMask
+ attr Attr
+ )
+ if err := ref.safelyRead(func() (err error) {
+ qid, valid, attr, err = ref.file.GetAttr(t.AttrMask)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rgetattr{QID: qid, Valid: valid, Attr: attr}
+}
+
+// handle implements handler.handle.
+func (t *Tsetattr) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // We don't allow setattr on files that have been deleted.
+ // This might be technically incorrect, as it's possible that
+ // there were multiple links and you can still change the
+ // corresponding inode information.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ // Set the attributes.
+ return ref.file.SetAttr(t.Valid, t.SetAttr)
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rsetattr{}
+}
+
+// handle implements handler.handle.
+func (t *Tallocate) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // Has it been opened already?
+ openFlags, opened := ref.OpenFlags()
+ if !opened {
+ return syscall.EINVAL
+ }
+
+ // Can it be written? Check permissions.
+ if openFlags&OpenFlagsModeMask == ReadOnly {
+ return syscall.EBADF
+ }
+
+ // We don't allow allocate on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ return ref.file.Allocate(t.Mode, t.Offset, t.Length)
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rallocate{}
+}
+
+// handle implements handler.handle.
+func (t *Txattrwalk) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // We don't support extended attributes.
+ return newErr(syscall.ENODATA)
+}
+
+// handle implements handler.handle.
+func (t *Txattrcreate) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // We don't support extended attributes.
+ return newErr(syscall.ENOSYS)
+}
+
+// handle implements handler.handle.
+func (t *Treaddir) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.Directory)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var entries []Dirent
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow reading deleted directories.
+ if ref.isDeleted() || !ref.mode.IsDir() {
+ return syscall.EINVAL
+ }
+
+ // Has it been opened already?
+ if _, opened := ref.OpenFlags(); !opened {
+ return syscall.EINVAL
+ }
+
+ // Read the entries.
+ entries, err = ref.file.Readdir(t.Offset, t.Count)
+ if err != nil && err != io.EOF {
+ return err
+ }
+ return nil
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rreaddir{Count: t.Count, Entries: entries}
+}
+
+// handle implements handler.handle.
+func (t *Tfsync) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyRead(func() (err error) {
+ // Has it been opened already?
+ if _, opened := ref.OpenFlags(); !opened {
+ return syscall.EINVAL
+ }
+
+ // Perform the sync.
+ return ref.file.FSync()
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rfsync{}
+}
+
+// handle implements handler.handle.
+func (t *Tstatfs) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ st, err := ref.file.StatFS()
+ if err != nil {
+ return newErr(err)
+ }
+
+ return &Rstatfs{st}
+}
+
+// handle implements handler.handle.
+func (t *Tflushf) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyRead(ref.file.Flush); err != nil {
+ return newErr(err)
+ }
+
+ return &Rflushf{}
+}
+
+// walkOne walks zero or one path elements.
+//
+// The slice passed as qids is append and returned.
+func walkOne(qids []QID, from File, names []string, getattr bool) ([]QID, File, AttrMask, Attr, error) {
+ if len(names) > 1 {
+ // We require exactly zero or one elements.
+ return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL
+ }
+ var (
+ localQIDs []QID
+ sf File
+ valid AttrMask
+ attr Attr
+ err error
+ )
+ switch {
+ case getattr:
+ localQIDs, sf, valid, attr, err = from.WalkGetAttr(names)
+ // Can't put fallthrough in the if because Go.
+ if err != syscall.ENOSYS {
+ break
+ }
+ fallthrough
+ default:
+ localQIDs, sf, err = from.Walk(names)
+ if err != nil {
+ // No way to walk this element.
+ break
+ }
+ if getattr {
+ _, valid, attr, err = sf.GetAttr(AttrMaskAll())
+ if err != nil {
+ // Don't leak the file.
+ sf.Close()
+ }
+ }
+ }
+ if err != nil {
+ // Error walking, don't return anything.
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ if len(localQIDs) != 1 {
+ // Expected a single QID.
+ sf.Close()
+ return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL
+ }
+ return append(qids, localQIDs...), sf, valid, attr, nil
+}
+
+// doWalk walks from a given fidRef.
+//
+// This enforces that all intermediate nodes are walkable (directories). The
+// fidRef returned (newRef) has a reference associated with it that is now
+// owned by the caller and must be handled appropriately.
+func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QID, newRef *fidRef, valid AttrMask, attr Attr, err error) {
+ // Check the names.
+ for _, name := range names {
+ err = checkSafeName(name)
+ if err != nil {
+ return
+ }
+ }
+
+ // Has it been opened already?
+ if _, opened := ref.OpenFlags(); opened {
+ err = syscall.EBUSY
+ return
+ }
+
+ // Is this an empty list? Handle specially. We don't actually need to
+ // validate anything since this is always permitted.
+ if len(names) == 0 {
+ var sf File // Temporary.
+ if err := ref.maybeParent().safelyRead(func() (err error) {
+ // Clone the single element.
+ qids, sf, valid, attr, err = walkOne(nil, ref.file, nil, getattr)
+ if err != nil {
+ return err
+ }
+
+ newRef = &fidRef{
+ server: cs.server,
+ parent: ref.parent,
+ file: sf,
+ mode: ref.mode,
+ pathNode: ref.pathNode,
+
+ // For the clone case, the cloned fid must
+ // preserve the deleted property of the
+ // original FID.
+ deleted: ref.deleted,
+ }
+ if !ref.isRoot() {
+ if !newRef.isDeleted() {
+ // Add only if a non-root node; the same node.
+ ref.parent.pathNode.addChild(newRef, ref.parent.pathNode.nameFor(ref))
+ }
+ ref.parent.IncRef() // Acquire parent reference.
+ }
+ // doWalk returns a reference.
+ newRef.IncRef()
+ return nil
+ }); err != nil {
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ // Do not return the new QID.
+ return nil, newRef, valid, attr, nil
+ }
+
+ // Do the walk, one element at a time.
+ walkRef := ref
+ walkRef.IncRef()
+ for i := 0; i < len(names); i++ {
+ // We won't allow beyond past symlinks; stop here if this isn't
+ // a proper directory and we have additional paths to walk.
+ if !walkRef.mode.IsDir() {
+ walkRef.DecRef() // Drop walk reference; no lock required.
+ return nil, nil, AttrMask{}, Attr{}, syscall.EINVAL
+ }
+
+ var sf File // Temporary.
+ if err := walkRef.safelyRead(func() (err error) {
+ // Pass getattr = true to walkOne since we need the file type for
+ // newRef.
+ qids, sf, valid, attr, err = walkOne(qids, walkRef.file, names[i:i+1], true)
+ if err != nil {
+ return err
+ }
+
+ // Note that we don't need to acquire a lock on any of
+ // these individual instances. That's because they are
+ // not actually addressable via a FID. They are
+ // anonymous. They exist in the tree for tracking
+ // purposes.
+ newRef := &fidRef{
+ server: cs.server,
+ parent: walkRef,
+ file: sf,
+ mode: attr.Mode.FileType(),
+ pathNode: walkRef.pathNode.pathNodeFor(names[i]),
+ }
+ walkRef.pathNode.addChild(newRef, names[i])
+ // We allow our walk reference to become the new parent
+ // reference here and so we don't IncRef. Instead, just
+ // set walkRef to the newRef above and acquire a new
+ // walk reference.
+ walkRef = newRef
+ walkRef.IncRef()
+ return nil
+ }); err != nil {
+ walkRef.DecRef() // Drop the old walkRef.
+ return nil, nil, AttrMask{}, Attr{}, err
+ }
+ }
+
+ // Success.
+ return qids, walkRef, valid, attr, nil
+}
+
+// handle implements handler.handle.
+func (t *Twalk) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Do the walk.
+ qids, newRef, _, _, err := doWalk(cs, ref, t.Names, false)
+ if err != nil {
+ return newErr(err)
+ }
+ defer newRef.DecRef()
+
+ // Install the new FID.
+ cs.InsertFID(t.NewFID, newRef)
+ return &Rwalk{QIDs: qids}
+}
+
+// handle implements handler.handle.
+func (t *Twalkgetattr) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ // Do the walk.
+ qids, newRef, valid, attr, err := doWalk(cs, ref, t.Names, true)
+ if err != nil {
+ return newErr(err)
+ }
+ defer newRef.DecRef()
+
+ // Install the new FID.
+ cs.InsertFID(t.NewFID, newRef)
+ return &Rwalkgetattr{QIDs: qids, Valid: valid, Attr: attr}
+}
+
+// handle implements handler.handle.
+func (t *Tucreate) handle(cs *connState) message {
+ rlcreate, err := t.Tlcreate.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rucreate{*rlcreate}
+}
+
+// handle implements handler.handle.
+func (t *Tumkdir) handle(cs *connState) message {
+ rmkdir, err := t.Tmkdir.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rumkdir{*rmkdir}
+}
+
+// handle implements handler.handle.
+func (t *Tusymlink) handle(cs *connState) message {
+ rsymlink, err := t.Tsymlink.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rusymlink{*rsymlink}
+}
+
+// handle implements handler.handle.
+func (t *Tumknod) handle(cs *connState) message {
+ rmknod, err := t.Tmknod.do(cs, t.UID)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rumknod{*rmknod}
+}
+
+// handle implements handler.handle.
+func (t *Tlconnect) handle(cs *connState) message {
+ // Lookup the FID.
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var osFile *fd.FD
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow connecting to deleted files.
+ if ref.isDeleted() || !ref.mode.IsSocket() {
+ return syscall.EINVAL
+ }
+
+ // Do the connect.
+ osFile, err = ref.file.Connect(t.Flags)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ return &Rlconnect{File: osFile}
+}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
new file mode 100644
index 000000000..75d6bc832
--- /dev/null
+++ b/pkg/p9/messages.go
@@ -0,0 +1,2359 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+)
+
+// ErrInvalidMsgType is returned when an unsupported message type is found.
+type ErrInvalidMsgType struct {
+ MsgType
+}
+
+// Error returns a useful string.
+func (e *ErrInvalidMsgType) Error() string {
+ return fmt.Sprintf("invalid message type: %d", e.MsgType)
+}
+
+// message is a generic 9P message.
+type message interface {
+ encoder
+ fmt.Stringer
+
+ // Type returns the message type number.
+ Type() MsgType
+}
+
+// payloader is a special message which may include an inline payload.
+type payloader interface {
+ // FixedSize returns the size of the fixed portion of this message.
+ FixedSize() uint32
+
+ // Payload returns the payload for sending.
+ Payload() []byte
+
+ // SetPayload returns the decoded message.
+ //
+ // This is going to be total message size - FixedSize. But this should
+ // be validated during Decode, which will be called after SetPayload.
+ SetPayload([]byte)
+}
+
+// filer is a message capable of passing a file.
+type filer interface {
+ // FilePayload returns the file payload.
+ FilePayload() *fd.FD
+
+ // SetFilePayload sets the file payload.
+ SetFilePayload(*fd.FD)
+}
+
+// Tversion is a version request.
+type Tversion struct {
+ // MSize is the message size to use.
+ MSize uint32
+
+ // Version is the version string.
+ //
+ // For this implementation, this must be 9P2000.L.
+ Version string
+}
+
+// Decode implements encoder.Decode.
+func (t *Tversion) Decode(b *buffer) {
+ t.MSize = b.Read32()
+ t.Version = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tversion) Encode(b *buffer) {
+ b.Write32(t.MSize)
+ b.WriteString(t.Version)
+}
+
+// Type implements message.Type.
+func (*Tversion) Type() MsgType {
+ return MsgTversion
+}
+
+// String implements fmt.Stringer.
+func (t *Tversion) String() string {
+ return fmt.Sprintf("Tversion{MSize: %d, Version: %s}", t.MSize, t.Version)
+}
+
+// Rversion is a version response.
+type Rversion struct {
+ // MSize is the negotiated size.
+ MSize uint32
+
+ // Version is the negotiated version.
+ Version string
+}
+
+// Decode implements encoder.Decode.
+func (r *Rversion) Decode(b *buffer) {
+ r.MSize = b.Read32()
+ r.Version = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rversion) Encode(b *buffer) {
+ b.Write32(r.MSize)
+ b.WriteString(r.Version)
+}
+
+// Type implements message.Type.
+func (*Rversion) Type() MsgType {
+ return MsgRversion
+}
+
+// String implements fmt.Stringer.
+func (r *Rversion) String() string {
+ return fmt.Sprintf("Rversion{MSize: %d, Version: %s}", r.MSize, r.Version)
+}
+
+// Tflush is a flush request.
+type Tflush struct {
+ // OldTag is the tag to wait on.
+ OldTag Tag
+}
+
+// Decode implements encoder.Decode.
+func (t *Tflush) Decode(b *buffer) {
+ t.OldTag = b.ReadTag()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tflush) Encode(b *buffer) {
+ b.WriteTag(t.OldTag)
+}
+
+// Type implements message.Type.
+func (*Tflush) Type() MsgType {
+ return MsgTflush
+}
+
+// String implements fmt.Stringer.
+func (t *Tflush) String() string {
+ return fmt.Sprintf("Tflush{OldTag: %d}", t.OldTag)
+}
+
+// Rflush is a flush response.
+type Rflush struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Rflush) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rflush) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rflush) Type() MsgType {
+ return MsgRflush
+}
+
+// String implements fmt.Stringer.
+func (r *Rflush) String() string {
+ return fmt.Sprintf("RFlush{}")
+}
+
+// Twalk is a walk request.
+type Twalk struct {
+ // FID is the FID to be walked.
+ FID FID
+
+ // NewFID is the resulting FID.
+ NewFID FID
+
+ // Names are the set of names to be walked.
+ Names []string
+}
+
+// Decode implements encoder.Decode.
+func (t *Twalk) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.NewFID = b.ReadFID()
+ n := b.Read16()
+ t.Names = t.Names[:0]
+ for i := 0; i < int(n); i++ {
+ t.Names = append(t.Names, b.ReadString())
+ }
+}
+
+// Encode implements encoder.Encode.
+func (t *Twalk) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.NewFID)
+ b.Write16(uint16(len(t.Names)))
+ for _, name := range t.Names {
+ b.WriteString(name)
+ }
+}
+
+// Type implements message.Type.
+func (*Twalk) Type() MsgType {
+ return MsgTwalk
+}
+
+// String implements fmt.Stringer.
+func (t *Twalk) String() string {
+ return fmt.Sprintf("Twalk{FID: %d, NewFID: %d, Names: %v}", t.FID, t.NewFID, t.Names)
+}
+
+// Rwalk is a walk response.
+type Rwalk struct {
+ // QIDs are the set of QIDs returned.
+ QIDs []QID
+}
+
+// Decode implements encoder.Decode.
+func (r *Rwalk) Decode(b *buffer) {
+ n := b.Read16()
+ r.QIDs = r.QIDs[:0]
+ for i := 0; i < int(n); i++ {
+ var q QID
+ q.Decode(b)
+ r.QIDs = append(r.QIDs, q)
+ }
+}
+
+// Encode implements encoder.Encode.
+func (r *Rwalk) Encode(b *buffer) {
+ b.Write16(uint16(len(r.QIDs)))
+ for _, q := range r.QIDs {
+ q.Encode(b)
+ }
+}
+
+// Type implements message.Type.
+func (*Rwalk) Type() MsgType {
+ return MsgRwalk
+}
+
+// String implements fmt.Stringer.
+func (r *Rwalk) String() string {
+ return fmt.Sprintf("Rwalk{QIDs: %v}", r.QIDs)
+}
+
+// Tclunk is a close request.
+type Tclunk struct {
+ // FID is the FID to be closed.
+ FID FID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tclunk) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tclunk) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tclunk) Type() MsgType {
+ return MsgTclunk
+}
+
+// String implements fmt.Stringer.
+func (t *Tclunk) String() string {
+ return fmt.Sprintf("Tclunk{FID: %d}", t.FID)
+}
+
+// Rclunk is a close response.
+type Rclunk struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Rclunk) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rclunk) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rclunk) Type() MsgType {
+ return MsgRclunk
+}
+
+// String implements fmt.Stringer.
+func (r *Rclunk) String() string {
+ return fmt.Sprintf("Rclunk{}")
+}
+
+// Tremove is a remove request.
+//
+// This will eventually be replaced by Tunlinkat.
+type Tremove struct {
+ // FID is the FID to be removed.
+ FID FID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tremove) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tremove) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tremove) Type() MsgType {
+ return MsgTremove
+}
+
+// String implements fmt.Stringer.
+func (t *Tremove) String() string {
+ return fmt.Sprintf("Tremove{FID: %d}", t.FID)
+}
+
+// Rremove is a remove response.
+type Rremove struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Rremove) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rremove) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rremove) Type() MsgType {
+ return MsgRremove
+}
+
+// String implements fmt.Stringer.
+func (r *Rremove) String() string {
+ return fmt.Sprintf("Rremove{}")
+}
+
+// Rlerror is an error response.
+//
+// Note that this replaces the error code used in 9p.
+type Rlerror struct {
+ Error uint32
+}
+
+// Decode implements encoder.Decode.
+func (r *Rlerror) Decode(b *buffer) {
+ r.Error = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rlerror) Encode(b *buffer) {
+ b.Write32(r.Error)
+}
+
+// Type implements message.Type.
+func (*Rlerror) Type() MsgType {
+ return MsgRlerror
+}
+
+// String implements fmt.Stringer.
+func (r *Rlerror) String() string {
+ return fmt.Sprintf("Rlerror{Error: %d}", r.Error)
+}
+
+// Tauth is an authentication request.
+type Tauth struct {
+ // AuthenticationFID is the FID to attach the authentication result.
+ AuthenticationFID FID
+
+ // UserName is the user to attach.
+ UserName string
+
+ // AttachName is the attach name.
+ AttachName string
+
+ // UserID is the numeric identifier for UserName.
+ UID UID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tauth) Decode(b *buffer) {
+ t.AuthenticationFID = b.ReadFID()
+ t.UserName = b.ReadString()
+ t.AttachName = b.ReadString()
+ t.UID = b.ReadUID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tauth) Encode(b *buffer) {
+ b.WriteFID(t.AuthenticationFID)
+ b.WriteString(t.UserName)
+ b.WriteString(t.AttachName)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (*Tauth) Type() MsgType {
+ return MsgTauth
+}
+
+// String implements fmt.Stringer.
+func (t *Tauth) String() string {
+ return fmt.Sprintf("Tauth{AuthFID: %d, UserName: %s, AttachName: %s, UID: %d", t.AuthenticationFID, t.UserName, t.AttachName, t.UID)
+}
+
+// Rauth is an authentication response.
+//
+// Encode, Decode and Length are inherited directly from QID.
+type Rauth struct {
+ QID
+}
+
+// Type implements message.Type.
+func (*Rauth) Type() MsgType {
+ return MsgRauth
+}
+
+// String implements fmt.Stringer.
+func (r *Rauth) String() string {
+ return fmt.Sprintf("Rauth{QID: %s}", r.QID)
+}
+
+// Tattach is an attach request.
+type Tattach struct {
+ // FID is the FID to be attached.
+ FID FID
+
+ // Auth is the embedded authentication request.
+ //
+ // See client.Attach for information regarding authentication.
+ Auth Tauth
+}
+
+// Decode implements encoder.Decode.
+func (t *Tattach) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Auth.Decode(b)
+}
+
+// Encode implements encoder.Encode.
+func (t *Tattach) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Auth.Encode(b)
+}
+
+// Type implements message.Type.
+func (*Tattach) Type() MsgType {
+ return MsgTattach
+}
+
+// String implements fmt.Stringer.
+func (t *Tattach) String() string {
+ return fmt.Sprintf("Tattach{FID: %d, AuthFID: %d, UserName: %s, AttachName: %s, UID: %d}", t.FID, t.Auth.AuthenticationFID, t.Auth.UserName, t.Auth.AttachName, t.Auth.UID)
+}
+
+// Rattach is an attach response.
+type Rattach struct {
+ QID
+}
+
+// Type implements message.Type.
+func (*Rattach) Type() MsgType {
+ return MsgRattach
+}
+
+// String implements fmt.Stringer.
+func (r *Rattach) String() string {
+ return fmt.Sprintf("Rattach{QID: %s}", r.QID)
+}
+
+// Tlopen is an open request.
+type Tlopen struct {
+ // FID is the FID to be opened.
+ FID FID
+
+ // Flags are the open flags.
+ Flags OpenFlags
+}
+
+// Decode implements encoder.Decode.
+func (t *Tlopen) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Flags = b.ReadOpenFlags()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tlopen) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteOpenFlags(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tlopen) Type() MsgType {
+ return MsgTlopen
+}
+
+// String implements fmt.Stringer.
+func (t *Tlopen) String() string {
+ return fmt.Sprintf("Tlopen{FID: %d, Flags: %v}", t.FID, t.Flags)
+}
+
+// Rlopen is a open response.
+type Rlopen struct {
+ // QID is the file's QID.
+ QID QID
+
+ // IoUnit is the recommended I/O unit.
+ IoUnit uint32
+
+ // File may be attached via the socket.
+ //
+ // This is an extension specific to this package.
+ File *fd.FD
+}
+
+// Decode implements encoder.Decode.
+func (r *Rlopen) Decode(b *buffer) {
+ r.QID.Decode(b)
+ r.IoUnit = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rlopen) Encode(b *buffer) {
+ r.QID.Encode(b)
+ b.Write32(r.IoUnit)
+}
+
+// Type implements message.Type.
+func (*Rlopen) Type() MsgType {
+ return MsgRlopen
+}
+
+// FilePayload returns the file payload.
+func (r *Rlopen) FilePayload() *fd.FD {
+ return r.File
+}
+
+// SetFilePayload sets the received file.
+func (r *Rlopen) SetFilePayload(file *fd.FD) {
+ r.File = file
+}
+
+// String implements fmt.Stringer.
+func (r *Rlopen) String() string {
+ return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
+}
+
+// Tlcreate is a create request.
+type Tlcreate struct {
+ // FID is the parent FID.
+ //
+ // This becomes the new file.
+ FID FID
+
+ // Name is the file name to create.
+ Name string
+
+ // Mode is the open mode (O_RDWR, etc.).
+ //
+ // Note that flags like O_TRUNC are ignored, as is O_EXCL. All
+ // create operations are exclusive.
+ OpenFlags OpenFlags
+
+ // Permissions is the set of permission bits.
+ Permissions FileMode
+
+ // GID is the group ID to use for creating the file.
+ GID GID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tlcreate) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.OpenFlags = b.ReadOpenFlags()
+ t.Permissions = b.ReadPermissions()
+ t.GID = b.ReadGID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tlcreate) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.WriteOpenFlags(t.OpenFlags)
+ b.WritePermissions(t.Permissions)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tlcreate) Type() MsgType {
+ return MsgTlcreate
+}
+
+// String implements fmt.Stringer.
+func (t *Tlcreate) String() string {
+ return fmt.Sprintf("Tlcreate{FID: %d, Name: %s, OpenFlags: %s, Permissions: 0o%o, GID: %d}", t.FID, t.Name, t.OpenFlags, t.Permissions, t.GID)
+}
+
+// Rlcreate is a create response.
+//
+// The Encode, Decode, etc. methods are inherited from Rlopen.
+type Rlcreate struct {
+ Rlopen
+}
+
+// Type implements message.Type.
+func (*Rlcreate) Type() MsgType {
+ return MsgRlcreate
+}
+
+// String implements fmt.Stringer.
+func (r *Rlcreate) String() string {
+ return fmt.Sprintf("Rlcreate{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
+}
+
+// Tsymlink is a symlink request.
+type Tsymlink struct {
+ // Directory is the directory FID.
+ Directory FID
+
+ // Name is the new in the directory.
+ Name string
+
+ // Target is the symlink target.
+ Target string
+
+ // GID is the owning group.
+ GID GID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tsymlink) Decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Target = b.ReadString()
+ t.GID = b.ReadGID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tsymlink) Encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.WriteString(t.Target)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tsymlink) Type() MsgType {
+ return MsgTsymlink
+}
+
+// String implements fmt.Stringer.
+func (t *Tsymlink) String() string {
+ return fmt.Sprintf("Tsymlink{DirectoryFID: %d, Name: %s, Target: %s, GID: %d}", t.Directory, t.Name, t.Target, t.GID)
+}
+
+// Rsymlink is a symlink response.
+type Rsymlink struct {
+ // QID is the new symlink's QID.
+ QID QID
+}
+
+// Decode implements encoder.Decode.
+func (r *Rsymlink) Decode(b *buffer) {
+ r.QID.Decode(b)
+}
+
+// Encode implements encoder.Encode.
+func (r *Rsymlink) Encode(b *buffer) {
+ r.QID.Encode(b)
+}
+
+// Type implements message.Type.
+func (*Rsymlink) Type() MsgType {
+ return MsgRsymlink
+}
+
+// String implements fmt.Stringer.
+func (r *Rsymlink) String() string {
+ return fmt.Sprintf("Rsymlink{QID: %s}", r.QID)
+}
+
+// Tlink is a link request.
+type Tlink struct {
+ // Directory is the directory to contain the link.
+ Directory FID
+
+ // FID is the target.
+ Target FID
+
+ // Name is the new source name.
+ Name string
+}
+
+// Decode implements encoder.Decode.
+func (t *Tlink) Decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Target = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tlink) Encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteFID(t.Target)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Tlink) Type() MsgType {
+ return MsgTlink
+}
+
+// String implements fmt.Stringer.
+func (t *Tlink) String() string {
+ return fmt.Sprintf("Tlink{DirectoryFID: %d, TargetFID: %d, Name: %s}", t.Directory, t.Target, t.Name)
+}
+
+// Rlink is a link response.
+type Rlink struct {
+}
+
+// Type implements message.Type.
+func (*Rlink) Type() MsgType {
+ return MsgRlink
+}
+
+// Decode implements encoder.Decode.
+func (*Rlink) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rlink) Encode(b *buffer) {
+}
+
+// String implements fmt.Stringer.
+func (r *Rlink) String() string {
+ return fmt.Sprintf("Rlink{}")
+}
+
+// Trenameat is a rename request.
+type Trenameat struct {
+ // OldDirectory is the source directory.
+ OldDirectory FID
+
+ // OldName is the source file name.
+ OldName string
+
+ // NewDirectory is the target directory.
+ NewDirectory FID
+
+ // NewName is the new file name.
+ NewName string
+}
+
+// Decode implements encoder.Decode.
+func (t *Trenameat) Decode(b *buffer) {
+ t.OldDirectory = b.ReadFID()
+ t.OldName = b.ReadString()
+ t.NewDirectory = b.ReadFID()
+ t.NewName = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (t *Trenameat) Encode(b *buffer) {
+ b.WriteFID(t.OldDirectory)
+ b.WriteString(t.OldName)
+ b.WriteFID(t.NewDirectory)
+ b.WriteString(t.NewName)
+}
+
+// Type implements message.Type.
+func (*Trenameat) Type() MsgType {
+ return MsgTrenameat
+}
+
+// String implements fmt.Stringer.
+func (t *Trenameat) String() string {
+ return fmt.Sprintf("TrenameAt{OldDirectoryFID: %d, OldName: %s, NewDirectoryFID: %d, NewName: %s}", t.OldDirectory, t.OldName, t.NewDirectory, t.NewName)
+}
+
+// Rrenameat is a rename response.
+type Rrenameat struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Rrenameat) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rrenameat) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rrenameat) Type() MsgType {
+ return MsgRrenameat
+}
+
+// String implements fmt.Stringer.
+func (r *Rrenameat) String() string {
+ return fmt.Sprintf("Rrenameat{}")
+}
+
+// Tunlinkat is an unlink request.
+type Tunlinkat struct {
+ // Directory is the originating directory.
+ Directory FID
+
+ // Name is the name of the entry to unlink.
+ Name string
+
+ // Flags are extra flags (e.g. O_DIRECTORY). These are not interpreted by p9.
+ Flags uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Tunlinkat) Decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Flags = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tunlinkat) Encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tunlinkat) Type() MsgType {
+ return MsgTunlinkat
+}
+
+// String implements fmt.Stringer.
+func (t *Tunlinkat) String() string {
+ return fmt.Sprintf("Tunlinkat{DirectoryFID: %d, Name: %s, Flags: 0x%X}", t.Directory, t.Name, t.Flags)
+}
+
+// Runlinkat is an unlink response.
+type Runlinkat struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Runlinkat) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Runlinkat) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Runlinkat) Type() MsgType {
+ return MsgRunlinkat
+}
+
+// String implements fmt.Stringer.
+func (r *Runlinkat) String() string {
+ return fmt.Sprintf("Runlinkat{}")
+}
+
+// Trename is a rename request.
+//
+// Note that this generally isn't used anymore, and ideally all rename calls
+// should Trenameat below.
+type Trename struct {
+ // FID is the FID to rename.
+ FID FID
+
+ // Directory is the target directory.
+ Directory FID
+
+ // Name is the new file name.
+ Name string
+}
+
+// Decode implements encoder.Decode.
+func (t *Trename) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (t *Trename) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Trename) Type() MsgType {
+ return MsgTrename
+}
+
+// String implements fmt.Stringer.
+func (t *Trename) String() string {
+ return fmt.Sprintf("Trename{FID: %d, DirectoryFID: %d, Name: %s}", t.FID, t.Directory, t.Name)
+}
+
+// Rrename is a rename response.
+type Rrename struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Rrename) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rrename) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rrename) Type() MsgType {
+ return MsgRrename
+}
+
+// String implements fmt.Stringer.
+func (r *Rrename) String() string {
+ return fmt.Sprintf("Rrename{}")
+}
+
+// Treadlink is a readlink request.
+type Treadlink struct {
+ // FID is the symlink.
+ FID FID
+}
+
+// Decode implements encoder.Decode.
+func (t *Treadlink) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Treadlink) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Treadlink) Type() MsgType {
+ return MsgTreadlink
+}
+
+// String implements fmt.Stringer.
+func (t *Treadlink) String() string {
+ return fmt.Sprintf("Treadlink{FID: %d}", t.FID)
+}
+
+// Rreadlink is a readlink response.
+type Rreadlink struct {
+ // Target is the symlink target.
+ Target string
+}
+
+// Decode implements encoder.Decode.
+func (r *Rreadlink) Decode(b *buffer) {
+ r.Target = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rreadlink) Encode(b *buffer) {
+ b.WriteString(r.Target)
+}
+
+// Type implements message.Type.
+func (*Rreadlink) Type() MsgType {
+ return MsgRreadlink
+}
+
+// String implements fmt.Stringer.
+func (r *Rreadlink) String() string {
+ return fmt.Sprintf("Rreadlink{Target: %s}", r.Target)
+}
+
+// Tread is a read request.
+type Tread struct {
+ // FID is the FID to read.
+ FID FID
+
+ // Offset indicates the file offset.
+ Offset uint64
+
+ // Count indicates the number of bytes to read.
+ Count uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Tread) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Offset = b.Read64()
+ t.Count = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tread) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Offset)
+ b.Write32(t.Count)
+}
+
+// Type implements message.Type.
+func (*Tread) Type() MsgType {
+ return MsgTread
+}
+
+// String implements fmt.Stringer.
+func (t *Tread) String() string {
+ return fmt.Sprintf("Tread{FID: %d, Offset: %d, Count: %d}", t.FID, t.Offset, t.Count)
+}
+
+// Rread is the response for a Tread.
+type Rread struct {
+ // Data is the resulting data.
+ Data []byte
+}
+
+// Decode implements encoder.Decode.
+//
+// Data is automatically decoded via Payload.
+func (r *Rread) Decode(b *buffer) {
+ count := b.Read32()
+ if count != uint32(len(r.Data)) {
+ b.markOverrun()
+ }
+}
+
+// Encode implements encoder.Encode.
+//
+// Data is automatically encoded via Payload.
+func (r *Rread) Encode(b *buffer) {
+ b.Write32(uint32(len(r.Data)))
+}
+
+// Type implements message.Type.
+func (*Rread) Type() MsgType {
+ return MsgRread
+}
+
+// FixedSize implements payloader.FixedSize.
+func (*Rread) FixedSize() uint32 {
+ return 4
+}
+
+// Payload implements payloader.Payload.
+func (r *Rread) Payload() []byte {
+ return r.Data
+}
+
+// SetPayload implements payloader.SetPayload.
+func (r *Rread) SetPayload(p []byte) {
+ r.Data = p
+}
+
+// String implements fmt.Stringer.
+func (r *Rread) String() string {
+ return fmt.Sprintf("Rread{len(Data): %d}", len(r.Data))
+}
+
+// Twrite is a write request.
+type Twrite struct {
+ // FID is the FID to read.
+ FID FID
+
+ // Offset indicates the file offset.
+ Offset uint64
+
+ // Data is the data to be written.
+ Data []byte
+}
+
+// Decode implements encoder.Decode.
+func (t *Twrite) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Offset = b.Read64()
+ count := b.Read32()
+ if count != uint32(len(t.Data)) {
+ b.markOverrun()
+ }
+}
+
+// Encode implements encoder.Encode.
+//
+// This uses the buffer payload to avoid a copy.
+func (t *Twrite) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Offset)
+ b.Write32(uint32(len(t.Data)))
+}
+
+// Type implements message.Type.
+func (*Twrite) Type() MsgType {
+ return MsgTwrite
+}
+
+// FixedSize implements payloader.FixedSize.
+func (*Twrite) FixedSize() uint32 {
+ return 16
+}
+
+// Payload implements payloader.Payload.
+func (t *Twrite) Payload() []byte {
+ return t.Data
+}
+
+// SetPayload implements payloader.SetPayload.
+func (t *Twrite) SetPayload(p []byte) {
+ t.Data = p
+}
+
+// String implements fmt.Stringer.
+func (t *Twrite) String() string {
+ return fmt.Sprintf("Twrite{FID: %v, Offset %d, len(Data): %d}", t.FID, t.Offset, len(t.Data))
+}
+
+// Rwrite is the response for a Twrite.
+type Rwrite struct {
+ // Count indicates the number of bytes successfully written.
+ Count uint32
+}
+
+// Decode implements encoder.Decode.
+func (r *Rwrite) Decode(b *buffer) {
+ r.Count = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rwrite) Encode(b *buffer) {
+ b.Write32(r.Count)
+}
+
+// Type implements message.Type.
+func (*Rwrite) Type() MsgType {
+ return MsgRwrite
+}
+
+// String implements fmt.Stringer.
+func (r *Rwrite) String() string {
+ return fmt.Sprintf("Rwrite{Count: %d}", r.Count)
+}
+
+// Tmknod is a mknod request.
+type Tmknod struct {
+ // Directory is the parent directory.
+ Directory FID
+
+ // Name is the device name.
+ Name string
+
+ // Mode is the device mode and permissions.
+ Mode FileMode
+
+ // Major is the device major number.
+ Major uint32
+
+ // Minor is the device minor number.
+ Minor uint32
+
+ // GID is the device GID.
+ GID GID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tmknod) Decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Mode = b.ReadFileMode()
+ t.Major = b.Read32()
+ t.Minor = b.Read32()
+ t.GID = b.ReadGID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tmknod) Encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.WriteFileMode(t.Mode)
+ b.Write32(t.Major)
+ b.Write32(t.Minor)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tmknod) Type() MsgType {
+ return MsgTmknod
+}
+
+// String implements fmt.Stringer.
+func (t *Tmknod) String() string {
+ return fmt.Sprintf("Tmknod{DirectoryFID: %d, Name: %s, Mode: 0o%o, Major: %d, Minor: %d, GID: %d}", t.Directory, t.Name, t.Mode, t.Major, t.Minor, t.GID)
+}
+
+// Rmknod is a mknod response.
+type Rmknod struct {
+ // QID is the resulting QID.
+ QID QID
+}
+
+// Decode implements encoder.Decode.
+func (r *Rmknod) Decode(b *buffer) {
+ r.QID.Decode(b)
+}
+
+// Encode implements encoder.Encode.
+func (r *Rmknod) Encode(b *buffer) {
+ r.QID.Encode(b)
+}
+
+// Type implements message.Type.
+func (*Rmknod) Type() MsgType {
+ return MsgRmknod
+}
+
+// String implements fmt.Stringer.
+func (r *Rmknod) String() string {
+ return fmt.Sprintf("Rmknod{QID: %s}", r.QID)
+}
+
+// Tmkdir is a mkdir request.
+type Tmkdir struct {
+ // Directory is the parent directory.
+ Directory FID
+
+ // Name is the new directory name.
+ Name string
+
+ // Permissions is the set of permission bits.
+ Permissions FileMode
+
+ // GID is the owning group.
+ GID GID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tmkdir) Decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Permissions = b.ReadPermissions()
+ t.GID = b.ReadGID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tmkdir) Encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.WriteString(t.Name)
+ b.WritePermissions(t.Permissions)
+ b.WriteGID(t.GID)
+}
+
+// Type implements message.Type.
+func (*Tmkdir) Type() MsgType {
+ return MsgTmkdir
+}
+
+// String implements fmt.Stringer.
+func (t *Tmkdir) String() string {
+ return fmt.Sprintf("Tmkdir{DirectoryFID: %d, Name: %s, Permissions: 0o%o, GID: %d}", t.Directory, t.Name, t.Permissions, t.GID)
+}
+
+// Rmkdir is a mkdir response.
+type Rmkdir struct {
+ // QID is the resulting QID.
+ QID QID
+}
+
+// Decode implements encoder.Decode.
+func (r *Rmkdir) Decode(b *buffer) {
+ r.QID.Decode(b)
+}
+
+// Encode implements encoder.Encode.
+func (r *Rmkdir) Encode(b *buffer) {
+ r.QID.Encode(b)
+}
+
+// Type implements message.Type.
+func (*Rmkdir) Type() MsgType {
+ return MsgRmkdir
+}
+
+// String implements fmt.Stringer.
+func (r *Rmkdir) String() string {
+ return fmt.Sprintf("Rmkdir{QID: %s}", r.QID)
+}
+
+// Tgetattr is a getattr request.
+type Tgetattr struct {
+ // FID is the FID to get attributes for.
+ FID FID
+
+ // AttrMask is the set of attributes to get.
+ AttrMask AttrMask
+}
+
+// Decode implements encoder.Decode.
+func (t *Tgetattr) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.AttrMask.Decode(b)
+}
+
+// Encode implements encoder.Encode.
+func (t *Tgetattr) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.AttrMask.Encode(b)
+}
+
+// Type implements message.Type.
+func (*Tgetattr) Type() MsgType {
+ return MsgTgetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tgetattr) String() string {
+ return fmt.Sprintf("Tgetattr{FID: %d, AttrMask: %s}", t.FID, t.AttrMask)
+}
+
+// Rgetattr is a getattr response.
+type Rgetattr struct {
+ // Valid indicates which fields are valid.
+ Valid AttrMask
+
+ // QID is the QID for this file.
+ QID
+
+ // Attr is the set of attributes.
+ Attr Attr
+}
+
+// Decode implements encoder.Decode.
+func (r *Rgetattr) Decode(b *buffer) {
+ r.Valid.Decode(b)
+ r.QID.Decode(b)
+ r.Attr.Decode(b)
+}
+
+// Encode implements encoder.Encode.
+func (r *Rgetattr) Encode(b *buffer) {
+ r.Valid.Encode(b)
+ r.QID.Encode(b)
+ r.Attr.Encode(b)
+}
+
+// Type implements message.Type.
+func (*Rgetattr) Type() MsgType {
+ return MsgRgetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rgetattr) String() string {
+ return fmt.Sprintf("Rgetattr{Valid: %v, QID: %s, Attr: %s}", r.Valid, r.QID, r.Attr)
+}
+
+// Tsetattr is a setattr request.
+type Tsetattr struct {
+ // FID is the FID to change.
+ FID FID
+
+ // Valid is the set of bits which will be used.
+ Valid SetAttrMask
+
+ // SetAttr is the set request.
+ SetAttr SetAttr
+}
+
+// Decode implements encoder.Decode.
+func (t *Tsetattr) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Valid.Decode(b)
+ t.SetAttr.Decode(b)
+}
+
+// Encode implements encoder.Encode.
+func (t *Tsetattr) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Valid.Encode(b)
+ t.SetAttr.Encode(b)
+}
+
+// Type implements message.Type.
+func (*Tsetattr) Type() MsgType {
+ return MsgTsetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tsetattr) String() string {
+ return fmt.Sprintf("Tsetattr{FID: %d, Valid: %v, SetAttr: %s}", t.FID, t.Valid, t.SetAttr)
+}
+
+// Rsetattr is a setattr response.
+type Rsetattr struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Rsetattr) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rsetattr) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rsetattr) Type() MsgType {
+ return MsgRsetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rsetattr) String() string {
+ return fmt.Sprintf("Rsetattr{}")
+}
+
+// Tallocate is an allocate request. This is an extension to 9P protocol, not
+// present in the 9P2000.L standard.
+type Tallocate struct {
+ FID FID
+ Mode AllocateMode
+ Offset uint64
+ Length uint64
+}
+
+// Decode implements encoder.Decode.
+func (t *Tallocate) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Mode.Decode(b)
+ t.Offset = b.Read64()
+ t.Length = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tallocate) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Mode.Encode(b)
+ b.Write64(t.Offset)
+ b.Write64(t.Length)
+}
+
+// Type implements message.Type.
+func (*Tallocate) Type() MsgType {
+ return MsgTallocate
+}
+
+// String implements fmt.Stringer.
+func (t *Tallocate) String() string {
+ return fmt.Sprintf("Tallocate{FID: %d, Offset: %d, Length: %d}", t.FID, t.Offset, t.Length)
+}
+
+// Rallocate is an allocate response.
+type Rallocate struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Rallocate) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rallocate) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rallocate) Type() MsgType {
+ return MsgRallocate
+}
+
+// String implements fmt.Stringer.
+func (r *Rallocate) String() string {
+ return fmt.Sprintf("Rallocate{}")
+}
+
+// Txattrwalk walks extended attributes.
+type Txattrwalk struct {
+ // FID is the FID to check for attributes.
+ FID FID
+
+ // NewFID is the new FID associated with the attributes.
+ NewFID FID
+
+ // Name is the attribute name.
+ Name string
+}
+
+// Decode implements encoder.Decode.
+func (t *Txattrwalk) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.NewFID = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (t *Txattrwalk) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.NewFID)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Txattrwalk) Type() MsgType {
+ return MsgTxattrwalk
+}
+
+// String implements fmt.Stringer.
+func (t *Txattrwalk) String() string {
+ return fmt.Sprintf("Txattrwalk{FID: %d, NewFID: %d, Name: %s}", t.FID, t.NewFID, t.Name)
+}
+
+// Rxattrwalk is a xattrwalk response.
+type Rxattrwalk struct {
+ // Size is the size of the extended attribute.
+ Size uint64
+}
+
+// Decode implements encoder.Decode.
+func (r *Rxattrwalk) Decode(b *buffer) {
+ r.Size = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rxattrwalk) Encode(b *buffer) {
+ b.Write64(r.Size)
+}
+
+// Type implements message.Type.
+func (*Rxattrwalk) Type() MsgType {
+ return MsgRxattrwalk
+}
+
+// String implements fmt.Stringer.
+func (r *Rxattrwalk) String() string {
+ return fmt.Sprintf("Rxattrwalk{Size: %d}", r.Size)
+}
+
+// Txattrcreate prepare to set extended attributes.
+type Txattrcreate struct {
+ // FID is input/output parameter, it identifies the file on which
+ // extended attributes will be set but after successful Rxattrcreate
+ // it is used to write the extended attribute value.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+
+ // Size of the attribute value. When the FID is clunked it has to match
+ // the number of bytes written to the FID.
+ AttrSize uint64
+
+ // Linux setxattr(2) flags.
+ Flags uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Txattrcreate) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.AttrSize = b.Read64()
+ t.Flags = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Txattrcreate) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.Write64(t.AttrSize)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Txattrcreate) Type() MsgType {
+ return MsgTxattrcreate
+}
+
+// String implements fmt.Stringer.
+func (t *Txattrcreate) String() string {
+ return fmt.Sprintf("Txattrcreate{FID: %d, Name: %s, AttrSize: %d, Flags: %d}", t.FID, t.Name, t.AttrSize, t.Flags)
+}
+
+// Rxattrcreate is a xattrcreate response.
+type Rxattrcreate struct {
+}
+
+// Decode implements encoder.Decode.
+func (r *Rxattrcreate) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (r *Rxattrcreate) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rxattrcreate) Type() MsgType {
+ return MsgRxattrcreate
+}
+
+// String implements fmt.Stringer.
+func (r *Rxattrcreate) String() string {
+ return fmt.Sprintf("Rxattrcreate{}")
+}
+
+// Treaddir is a readdir request.
+type Treaddir struct {
+ // Directory is the directory FID to read.
+ Directory FID
+
+ // Offset is the offset to read at.
+ Offset uint64
+
+ // Count is the number of bytes to read.
+ Count uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Treaddir) Decode(b *buffer) {
+ t.Directory = b.ReadFID()
+ t.Offset = b.Read64()
+ t.Count = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Treaddir) Encode(b *buffer) {
+ b.WriteFID(t.Directory)
+ b.Write64(t.Offset)
+ b.Write32(t.Count)
+}
+
+// Type implements message.Type.
+func (*Treaddir) Type() MsgType {
+ return MsgTreaddir
+}
+
+// String implements fmt.Stringer.
+func (t *Treaddir) String() string {
+ return fmt.Sprintf("Treaddir{DirectoryFID: %d, Offset: %d, Count: %d}", t.Directory, t.Offset, t.Count)
+}
+
+// Rreaddir is a readdir response.
+type Rreaddir struct {
+ // Count is the byte limit.
+ //
+ // This should always be set from the Treaddir request.
+ Count uint32
+
+ // Entries are the resulting entries.
+ //
+ // This may be constructed in decode.
+ Entries []Dirent
+
+ // payload is the encoded payload.
+ //
+ // This is constructed by encode.
+ payload []byte
+}
+
+// Decode implements encoder.Decode.
+func (r *Rreaddir) Decode(b *buffer) {
+ r.Count = b.Read32()
+ entriesBuf := buffer{data: r.payload}
+ r.Entries = r.Entries[:0]
+ for {
+ var d Dirent
+ d.Decode(&entriesBuf)
+ if entriesBuf.isOverrun() {
+ // Couldn't decode a complete entry.
+ break
+ }
+ r.Entries = append(r.Entries, d)
+ }
+}
+
+// Encode implements encoder.Encode.
+func (r *Rreaddir) Encode(b *buffer) {
+ entriesBuf := buffer{}
+ for _, d := range r.Entries {
+ d.Encode(&entriesBuf)
+ if len(entriesBuf.data) >= int(r.Count) {
+ break
+ }
+ }
+ if len(entriesBuf.data) < int(r.Count) {
+ r.Count = uint32(len(entriesBuf.data))
+ r.payload = entriesBuf.data
+ } else {
+ r.payload = entriesBuf.data[:r.Count]
+ }
+ b.Write32(uint32(r.Count))
+}
+
+// Type implements message.Type.
+func (*Rreaddir) Type() MsgType {
+ return MsgRreaddir
+}
+
+// FixedSize implements payloader.FixedSize.
+func (*Rreaddir) FixedSize() uint32 {
+ return 4
+}
+
+// Payload implements payloader.Payload.
+func (r *Rreaddir) Payload() []byte {
+ return r.payload
+}
+
+// SetPayload implements payloader.SetPayload.
+func (r *Rreaddir) SetPayload(p []byte) {
+ r.payload = p
+}
+
+// String implements fmt.Stringer.
+func (r *Rreaddir) String() string {
+ return fmt.Sprintf("Rreaddir{Count: %d, Entries: %s}", r.Count, r.Entries)
+}
+
+// Tfsync is an fsync request.
+type Tfsync struct {
+ // FID is the fid to sync.
+ FID FID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tfsync) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tfsync) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tfsync) Type() MsgType {
+ return MsgTfsync
+}
+
+// String implements fmt.Stringer.
+func (t *Tfsync) String() string {
+ return fmt.Sprintf("Tfsync{FID: %d}", t.FID)
+}
+
+// Rfsync is an fsync response.
+type Rfsync struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Rfsync) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rfsync) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rfsync) Type() MsgType {
+ return MsgRfsync
+}
+
+// String implements fmt.Stringer.
+func (r *Rfsync) String() string {
+ return fmt.Sprintf("Rfsync{}")
+}
+
+// Tstatfs is a stat request.
+type Tstatfs struct {
+ // FID is the root.
+ FID FID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tstatfs) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tstatfs) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tstatfs) Type() MsgType {
+ return MsgTstatfs
+}
+
+// String implements fmt.Stringer.
+func (t *Tstatfs) String() string {
+ return fmt.Sprintf("Tstatfs{FID: %d}", t.FID)
+}
+
+// Rstatfs is the response for a Tstatfs.
+type Rstatfs struct {
+ // FSStat is the stat result.
+ FSStat FSStat
+}
+
+// Decode implements encoder.Decode.
+func (r *Rstatfs) Decode(b *buffer) {
+ r.FSStat.Decode(b)
+}
+
+// Encode implements encoder.Encode.
+func (r *Rstatfs) Encode(b *buffer) {
+ r.FSStat.Encode(b)
+}
+
+// Type implements message.Type.
+func (*Rstatfs) Type() MsgType {
+ return MsgRstatfs
+}
+
+// String implements fmt.Stringer.
+func (r *Rstatfs) String() string {
+ return fmt.Sprintf("Rstatfs{FSStat: %v}", r.FSStat)
+}
+
+// Tflushf is a flush file request, not to be confused with Tflush.
+type Tflushf struct {
+ // FID is the FID to be flushed.
+ FID FID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tflushf) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tflushf) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+}
+
+// Type implements message.Type.
+func (*Tflushf) Type() MsgType {
+ return MsgTflushf
+}
+
+// String implements fmt.Stringer.
+func (t *Tflushf) String() string {
+ return fmt.Sprintf("Tflushf{FID: %d}", t.FID)
+}
+
+// Rflushf is a flush file response.
+type Rflushf struct {
+}
+
+// Decode implements encoder.Decode.
+func (*Rflushf) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (*Rflushf) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rflushf) Type() MsgType {
+ return MsgRflushf
+}
+
+// String implements fmt.Stringer.
+func (*Rflushf) String() string {
+ return fmt.Sprintf("Rflushf{}")
+}
+
+// Twalkgetattr is a walk request.
+type Twalkgetattr struct {
+ // FID is the FID to be walked.
+ FID FID
+
+ // NewFID is the resulting FID.
+ NewFID FID
+
+ // Names are the set of names to be walked.
+ Names []string
+}
+
+// Decode implements encoder.Decode.
+func (t *Twalkgetattr) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.NewFID = b.ReadFID()
+ n := b.Read16()
+ t.Names = t.Names[:0]
+ for i := 0; i < int(n); i++ {
+ t.Names = append(t.Names, b.ReadString())
+ }
+}
+
+// Encode implements encoder.Encode.
+func (t *Twalkgetattr) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteFID(t.NewFID)
+ b.Write16(uint16(len(t.Names)))
+ for _, name := range t.Names {
+ b.WriteString(name)
+ }
+}
+
+// Type implements message.Type.
+func (*Twalkgetattr) Type() MsgType {
+ return MsgTwalkgetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Twalkgetattr) String() string {
+ return fmt.Sprintf("Twalkgetattr{FID: %d, NewFID: %d, Names: %v}", t.FID, t.NewFID, t.Names)
+}
+
+// Rwalkgetattr is a walk response.
+type Rwalkgetattr struct {
+ // Valid indicates which fields are valid in the Attr below.
+ Valid AttrMask
+
+ // Attr is the set of attributes for the last QID (the file walked to).
+ Attr Attr
+
+ // QIDs are the set of QIDs returned.
+ QIDs []QID
+}
+
+// Decode implements encoder.Decode.
+func (r *Rwalkgetattr) Decode(b *buffer) {
+ r.Valid.Decode(b)
+ r.Attr.Decode(b)
+ n := b.Read16()
+ r.QIDs = r.QIDs[:0]
+ for i := 0; i < int(n); i++ {
+ var q QID
+ q.Decode(b)
+ r.QIDs = append(r.QIDs, q)
+ }
+}
+
+// Encode implements encoder.Encode.
+func (r *Rwalkgetattr) Encode(b *buffer) {
+ r.Valid.Encode(b)
+ r.Attr.Encode(b)
+ b.Write16(uint16(len(r.QIDs)))
+ for _, q := range r.QIDs {
+ q.Encode(b)
+ }
+}
+
+// Type implements message.Type.
+func (*Rwalkgetattr) Type() MsgType {
+ return MsgRwalkgetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rwalkgetattr) String() string {
+ return fmt.Sprintf("Rwalkgetattr{Valid: %s, Attr: %s, QIDs: %v}", r.Valid, r.Attr, r.QIDs)
+}
+
+// Tucreate is a Tlcreate message that includes a UID.
+type Tucreate struct {
+ Tlcreate
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tucreate) Decode(b *buffer) {
+ t.Tlcreate.Decode(b)
+ t.UID = b.ReadUID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tucreate) Encode(b *buffer) {
+ t.Tlcreate.Encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tucreate) Type() MsgType {
+ return MsgTucreate
+}
+
+// String implements fmt.Stringer.
+func (t *Tucreate) String() string {
+ return fmt.Sprintf("Tucreate{Tlcreate: %v, UID: %d}", &t.Tlcreate, t.UID)
+}
+
+// Rucreate is a file creation response.
+type Rucreate struct {
+ Rlcreate
+}
+
+// Type implements message.Type.
+func (*Rucreate) Type() MsgType {
+ return MsgRucreate
+}
+
+// String implements fmt.Stringer.
+func (r *Rucreate) String() string {
+ return fmt.Sprintf("Rucreate{%v}", &r.Rlcreate)
+}
+
+// Tumkdir is a Tmkdir message that includes a UID.
+type Tumkdir struct {
+ Tmkdir
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tumkdir) Decode(b *buffer) {
+ t.Tmkdir.Decode(b)
+ t.UID = b.ReadUID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tumkdir) Encode(b *buffer) {
+ t.Tmkdir.Encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tumkdir) Type() MsgType {
+ return MsgTumkdir
+}
+
+// String implements fmt.Stringer.
+func (t *Tumkdir) String() string {
+ return fmt.Sprintf("Tumkdir{Tmkdir: %v, UID: %d}", &t.Tmkdir, t.UID)
+}
+
+// Rumkdir is a umkdir response.
+type Rumkdir struct {
+ Rmkdir
+}
+
+// Type implements message.Type.
+func (*Rumkdir) Type() MsgType {
+ return MsgRumkdir
+}
+
+// String implements fmt.Stringer.
+func (r *Rumkdir) String() string {
+ return fmt.Sprintf("Rumkdir{%v}", &r.Rmkdir)
+}
+
+// Tumknod is a Tmknod message that includes a UID.
+type Tumknod struct {
+ Tmknod
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tumknod) Decode(b *buffer) {
+ t.Tmknod.Decode(b)
+ t.UID = b.ReadUID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tumknod) Encode(b *buffer) {
+ t.Tmknod.Encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tumknod) Type() MsgType {
+ return MsgTumknod
+}
+
+// String implements fmt.Stringer.
+func (t *Tumknod) String() string {
+ return fmt.Sprintf("Tumknod{Tmknod: %v, UID: %d}", &t.Tmknod, t.UID)
+}
+
+// Rumknod is a umknod response.
+type Rumknod struct {
+ Rmknod
+}
+
+// Type implements message.Type.
+func (*Rumknod) Type() MsgType {
+ return MsgRumknod
+}
+
+// String implements fmt.Stringer.
+func (r *Rumknod) String() string {
+ return fmt.Sprintf("Rumknod{%v}", &r.Rmknod)
+}
+
+// Tusymlink is a Tsymlink message that includes a UID.
+type Tusymlink struct {
+ Tsymlink
+
+ // UID is the UID to use as the effective UID in creation messages.
+ UID UID
+}
+
+// Decode implements encoder.Decode.
+func (t *Tusymlink) Decode(b *buffer) {
+ t.Tsymlink.Decode(b)
+ t.UID = b.ReadUID()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tusymlink) Encode(b *buffer) {
+ t.Tsymlink.Encode(b)
+ b.WriteUID(t.UID)
+}
+
+// Type implements message.Type.
+func (t *Tusymlink) Type() MsgType {
+ return MsgTusymlink
+}
+
+// String implements fmt.Stringer.
+func (t *Tusymlink) String() string {
+ return fmt.Sprintf("Tusymlink{Tsymlink: %v, UID: %d}", &t.Tsymlink, t.UID)
+}
+
+// Rusymlink is a usymlink response.
+type Rusymlink struct {
+ Rsymlink
+}
+
+// Type implements message.Type.
+func (*Rusymlink) Type() MsgType {
+ return MsgRusymlink
+}
+
+// String implements fmt.Stringer.
+func (r *Rusymlink) String() string {
+ return fmt.Sprintf("Rusymlink{%v}", &r.Rsymlink)
+}
+
+// Tlconnect is a connect request.
+type Tlconnect struct {
+ // FID is the FID to be connected.
+ FID FID
+
+ // Flags are the connect flags.
+ Flags ConnectFlags
+}
+
+// Decode implements encoder.Decode.
+func (t *Tlconnect) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Flags = b.ReadConnectFlags()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tlconnect) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteConnectFlags(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tlconnect) Type() MsgType {
+ return MsgTlconnect
+}
+
+// String implements fmt.Stringer.
+func (t *Tlconnect) String() string {
+ return fmt.Sprintf("Tlconnect{FID: %d, Flags: %v}", t.FID, t.Flags)
+}
+
+// Rlconnect is a connect response.
+type Rlconnect struct {
+ // File is a host socket.
+ File *fd.FD
+}
+
+// Decode implements encoder.Decode.
+func (r *Rlconnect) Decode(*buffer) {}
+
+// Encode implements encoder.Encode.
+func (r *Rlconnect) Encode(*buffer) {}
+
+// Type implements message.Type.
+func (*Rlconnect) Type() MsgType {
+ return MsgRlconnect
+}
+
+// FilePayload returns the file payload.
+func (r *Rlconnect) FilePayload() *fd.FD {
+ return r.File
+}
+
+// SetFilePayload sets the received file.
+func (r *Rlconnect) SetFilePayload(file *fd.FD) {
+ r.File = file
+}
+
+// String implements fmt.Stringer.
+func (r *Rlconnect) String() string {
+ return fmt.Sprintf("Rlconnect{File: %v}", r.File)
+}
+
+const maxCacheSize = 3
+
+// msgFactory is used to reduce allocations by caching messages for reuse.
+type msgFactory struct {
+ create func() message
+ cache chan message
+}
+
+// msgRegistry indexes all message factories by type.
+var msgRegistry registry
+
+type registry struct {
+ factories [math.MaxUint8]msgFactory
+
+ // largestFixedSize is computed so that given some message size M, you can
+ // compute the maximum payload size (e.g. for Twrite, Rread) with
+ // M-largestFixedSize. You could do this individual on a per-message basis,
+ // but it's easier to compute a single maximum safe payload.
+ largestFixedSize uint32
+}
+
+// get returns a new message by type.
+//
+// An error is returned in the case of an unknown message.
+//
+// This takes, and ignores, a message tag so that it may be used directly as a
+// lookupTagAndType function for recv (by design).
+func (r *registry) get(_ Tag, t MsgType) (message, error) {
+ entry := &r.factories[t]
+ if entry.create == nil {
+ return nil, &ErrInvalidMsgType{t}
+ }
+
+ select {
+ case msg := <-entry.cache:
+ return msg, nil
+ default:
+ return entry.create(), nil
+ }
+}
+
+func (r *registry) put(msg message) {
+ if p, ok := msg.(payloader); ok {
+ p.SetPayload(nil)
+ }
+ if f, ok := msg.(filer); ok {
+ f.SetFilePayload(nil)
+ }
+
+ entry := &r.factories[msg.Type()]
+ select {
+ case entry.cache <- msg:
+ default:
+ }
+}
+
+// register registers the given message type.
+//
+// This may cause panic on failure and should only be used from init.
+func (r *registry) register(t MsgType, fn func() message) {
+ if int(t) >= len(r.factories) {
+ panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(r.factories)))
+ }
+ if r.factories[t].create != nil {
+ panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, r.factories[t].create(), fn()))
+ }
+ r.factories[t] = msgFactory{
+ create: fn,
+ cache: make(chan message, maxCacheSize),
+ }
+
+ if size := calculateSize(fn()); size > r.largestFixedSize {
+ r.largestFixedSize = size
+ }
+}
+
+func calculateSize(m message) uint32 {
+ if p, ok := m.(payloader); ok {
+ return p.FixedSize()
+ }
+ var dataBuf buffer
+ m.Encode(&dataBuf)
+ return uint32(len(dataBuf.data))
+}
+
+func init() {
+ msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} })
+ msgRegistry.register(MsgTstatfs, func() message { return &Tstatfs{} })
+ msgRegistry.register(MsgRstatfs, func() message { return &Rstatfs{} })
+ msgRegistry.register(MsgTlopen, func() message { return &Tlopen{} })
+ msgRegistry.register(MsgRlopen, func() message { return &Rlopen{} })
+ msgRegistry.register(MsgTlcreate, func() message { return &Tlcreate{} })
+ msgRegistry.register(MsgRlcreate, func() message { return &Rlcreate{} })
+ msgRegistry.register(MsgTsymlink, func() message { return &Tsymlink{} })
+ msgRegistry.register(MsgRsymlink, func() message { return &Rsymlink{} })
+ msgRegistry.register(MsgTmknod, func() message { return &Tmknod{} })
+ msgRegistry.register(MsgRmknod, func() message { return &Rmknod{} })
+ msgRegistry.register(MsgTrename, func() message { return &Trename{} })
+ msgRegistry.register(MsgRrename, func() message { return &Rrename{} })
+ msgRegistry.register(MsgTreadlink, func() message { return &Treadlink{} })
+ msgRegistry.register(MsgRreadlink, func() message { return &Rreadlink{} })
+ msgRegistry.register(MsgTgetattr, func() message { return &Tgetattr{} })
+ msgRegistry.register(MsgRgetattr, func() message { return &Rgetattr{} })
+ msgRegistry.register(MsgTsetattr, func() message { return &Tsetattr{} })
+ msgRegistry.register(MsgRsetattr, func() message { return &Rsetattr{} })
+ msgRegistry.register(MsgTxattrwalk, func() message { return &Txattrwalk{} })
+ msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
+ msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
+ msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
+ msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} })
+ msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} })
+ msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} })
+ msgRegistry.register(MsgRfsync, func() message { return &Rfsync{} })
+ msgRegistry.register(MsgTlink, func() message { return &Tlink{} })
+ msgRegistry.register(MsgRlink, func() message { return &Rlink{} })
+ msgRegistry.register(MsgTmkdir, func() message { return &Tmkdir{} })
+ msgRegistry.register(MsgRmkdir, func() message { return &Rmkdir{} })
+ msgRegistry.register(MsgTrenameat, func() message { return &Trenameat{} })
+ msgRegistry.register(MsgRrenameat, func() message { return &Rrenameat{} })
+ msgRegistry.register(MsgTunlinkat, func() message { return &Tunlinkat{} })
+ msgRegistry.register(MsgRunlinkat, func() message { return &Runlinkat{} })
+ msgRegistry.register(MsgTversion, func() message { return &Tversion{} })
+ msgRegistry.register(MsgRversion, func() message { return &Rversion{} })
+ msgRegistry.register(MsgTauth, func() message { return &Tauth{} })
+ msgRegistry.register(MsgRauth, func() message { return &Rauth{} })
+ msgRegistry.register(MsgTattach, func() message { return &Tattach{} })
+ msgRegistry.register(MsgRattach, func() message { return &Rattach{} })
+ msgRegistry.register(MsgTflush, func() message { return &Tflush{} })
+ msgRegistry.register(MsgRflush, func() message { return &Rflush{} })
+ msgRegistry.register(MsgTwalk, func() message { return &Twalk{} })
+ msgRegistry.register(MsgRwalk, func() message { return &Rwalk{} })
+ msgRegistry.register(MsgTread, func() message { return &Tread{} })
+ msgRegistry.register(MsgRread, func() message { return &Rread{} })
+ msgRegistry.register(MsgTwrite, func() message { return &Twrite{} })
+ msgRegistry.register(MsgRwrite, func() message { return &Rwrite{} })
+ msgRegistry.register(MsgTclunk, func() message { return &Tclunk{} })
+ msgRegistry.register(MsgRclunk, func() message { return &Rclunk{} })
+ msgRegistry.register(MsgTremove, func() message { return &Tremove{} })
+ msgRegistry.register(MsgRremove, func() message { return &Rremove{} })
+ msgRegistry.register(MsgTflushf, func() message { return &Tflushf{} })
+ msgRegistry.register(MsgRflushf, func() message { return &Rflushf{} })
+ msgRegistry.register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} })
+ msgRegistry.register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} })
+ msgRegistry.register(MsgTucreate, func() message { return &Tucreate{} })
+ msgRegistry.register(MsgRucreate, func() message { return &Rucreate{} })
+ msgRegistry.register(MsgTumkdir, func() message { return &Tumkdir{} })
+ msgRegistry.register(MsgRumkdir, func() message { return &Rumkdir{} })
+ msgRegistry.register(MsgTumknod, func() message { return &Tumknod{} })
+ msgRegistry.register(MsgRumknod, func() message { return &Rumknod{} })
+ msgRegistry.register(MsgTusymlink, func() message { return &Tusymlink{} })
+ msgRegistry.register(MsgRusymlink, func() message { return &Rusymlink{} })
+ msgRegistry.register(MsgTlconnect, func() message { return &Tlconnect{} })
+ msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
+ msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} })
+ msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
+}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
new file mode 100644
index 000000000..4039862e6
--- /dev/null
+++ b/pkg/p9/p9.go
@@ -0,0 +1,1141 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package p9 is a 9P2000.L implementation.
+package p9
+
+import (
+ "fmt"
+ "math"
+ "os"
+ "strings"
+ "sync/atomic"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+// OpenFlags is the mode passed to Open and Create operations.
+//
+// These correspond to bits sent over the wire.
+type OpenFlags uint32
+
+const (
+ // ReadOnly is a Topen and Tcreate flag indicating read-only mode.
+ ReadOnly OpenFlags = 0
+
+ // WriteOnly is a Topen and Tcreate flag indicating write-only mode.
+ WriteOnly OpenFlags = 1
+
+ // ReadWrite is a Topen flag indicates read-write mode.
+ ReadWrite OpenFlags = 2
+
+ // OpenFlagsModeMask is a mask of valid OpenFlags mode bits.
+ OpenFlagsModeMask OpenFlags = 3
+
+ // OpenFlagsIgnoreMask is a list of OpenFlags mode bits that are ignored for Tlopen.
+ // Note that syscall.O_LARGEFILE is set to zero, use value from Linux fcntl.h.
+ OpenFlagsIgnoreMask OpenFlags = syscall.O_DIRECTORY | syscall.O_NOATIME | 0100000
+)
+
+// ConnectFlags is the mode passed to Connect operations.
+//
+// These correspond to bits sent over the wire.
+type ConnectFlags uint32
+
+const (
+ // StreamSocket is a Tlconnect flag indicating SOCK_STREAM mode.
+ StreamSocket ConnectFlags = 0
+
+ // DgramSocket is a Tlconnect flag indicating SOCK_DGRAM mode.
+ DgramSocket ConnectFlags = 1
+
+ // SeqpacketSocket is a Tlconnect flag indicating SOCK_SEQPACKET mode.
+ SeqpacketSocket ConnectFlags = 2
+
+ // AnonymousSocket is a Tlconnect flag indicating that the mode does not
+ // matter and that the requester will accept any socket type.
+ AnonymousSocket ConnectFlags = 3
+)
+
+// OSFlags converts a p9.OpenFlags to an int compatible with open(2).
+func (o OpenFlags) OSFlags() int {
+ return int(o & OpenFlagsModeMask)
+}
+
+// String implements fmt.Stringer.
+func (o OpenFlags) String() string {
+ switch o {
+ case ReadOnly:
+ return "ReadOnly"
+ case WriteOnly:
+ return "WriteOnly"
+ case ReadWrite:
+ return "ReadWrite"
+ case OpenFlagsModeMask:
+ return "OpenFlagsModeMask"
+ case OpenFlagsIgnoreMask:
+ return "OpenFlagsIgnoreMask"
+ default:
+ return "UNDEFINED"
+ }
+}
+
+// Tag is a messsage tag.
+type Tag uint16
+
+// FID is a file identifier.
+type FID uint64
+
+// FileMode are flags corresponding to file modes.
+//
+// These correspond to bits sent over the wire.
+// These also correspond to mode_t bits.
+type FileMode uint32
+
+const (
+ // FileModeMask is a mask of all the file mode bits of FileMode.
+ FileModeMask FileMode = 0170000
+
+ // ModeSocket is an (unused) mode bit for a socket.
+ ModeSocket FileMode = 0140000
+
+ // ModeSymlink is a mode bit for a symlink.
+ ModeSymlink FileMode = 0120000
+
+ // ModeRegular is a mode bit for regular files.
+ ModeRegular FileMode = 0100000
+
+ // ModeBlockDevice is a mode bit for block devices.
+ ModeBlockDevice FileMode = 060000
+
+ // ModeDirectory is a mode bit for directories.
+ ModeDirectory FileMode = 040000
+
+ // ModeCharacterDevice is a mode bit for a character device.
+ ModeCharacterDevice FileMode = 020000
+
+ // ModeNamedPipe is a mode bit for a named pipe.
+ ModeNamedPipe FileMode = 010000
+
+ // Read is a mode bit indicating read permission.
+ Read FileMode = 04
+
+ // Write is a mode bit indicating write permission.
+ Write FileMode = 02
+
+ // Exec is a mode bit indicating exec permission.
+ Exec FileMode = 01
+
+ // AllPermissions is a mask with rwx bits set for user, group and others.
+ AllPermissions FileMode = 0777
+
+ // Sticky is a mode bit indicating sticky directories.
+ Sticky FileMode = 01000
+
+ // permissionsMask is the mask to apply to FileModes for permissions. It
+ // includes rwx bits for user, group and others, and sticky bit.
+ permissionsMask FileMode = 01777
+)
+
+// QIDType is the most significant byte of the FileMode word, to be used as the
+// Type field of p9.QID.
+func (m FileMode) QIDType() QIDType {
+ switch {
+ case m.IsDir():
+ return TypeDir
+ case m.IsSocket(), m.IsNamedPipe(), m.IsCharacterDevice():
+ // Best approximation.
+ return TypeAppendOnly
+ case m.IsSymlink():
+ return TypeSymlink
+ default:
+ return TypeRegular
+ }
+}
+
+// FileType returns the file mode without the permission bits.
+func (m FileMode) FileType() FileMode {
+ return m & FileModeMask
+}
+
+// Permissions returns just the permission bits of the mode.
+func (m FileMode) Permissions() FileMode {
+ return m & permissionsMask
+}
+
+// Writable returns the mode with write bits added.
+func (m FileMode) Writable() FileMode {
+ return m | 0222
+}
+
+// IsReadable returns true if m represents a file that can be read.
+func (m FileMode) IsReadable() bool {
+ return m&0444 != 0
+}
+
+// IsWritable returns true if m represents a file that can be written to.
+func (m FileMode) IsWritable() bool {
+ return m&0222 != 0
+}
+
+// IsExecutable returns true if m represents a file that can be executed.
+func (m FileMode) IsExecutable() bool {
+ return m&0111 != 0
+}
+
+// IsRegular returns true if m is a regular file.
+func (m FileMode) IsRegular() bool {
+ return m&FileModeMask == ModeRegular
+}
+
+// IsDir returns true if m represents a directory.
+func (m FileMode) IsDir() bool {
+ return m&FileModeMask == ModeDirectory
+}
+
+// IsNamedPipe returns true if m represents a named pipe.
+func (m FileMode) IsNamedPipe() bool {
+ return m&FileModeMask == ModeNamedPipe
+}
+
+// IsCharacterDevice returns true if m represents a character device.
+func (m FileMode) IsCharacterDevice() bool {
+ return m&FileModeMask == ModeCharacterDevice
+}
+
+// IsBlockDevice returns true if m represents a character device.
+func (m FileMode) IsBlockDevice() bool {
+ return m&FileModeMask == ModeBlockDevice
+}
+
+// IsSocket returns true if m represents a socket.
+func (m FileMode) IsSocket() bool {
+ return m&FileModeMask == ModeSocket
+}
+
+// IsSymlink returns true if m represents a symlink.
+func (m FileMode) IsSymlink() bool {
+ return m&FileModeMask == ModeSymlink
+}
+
+// ModeFromOS returns a FileMode from an os.FileMode.
+func ModeFromOS(mode os.FileMode) FileMode {
+ m := FileMode(mode.Perm())
+ switch {
+ case mode.IsDir():
+ m |= ModeDirectory
+ case mode&os.ModeSymlink != 0:
+ m |= ModeSymlink
+ case mode&os.ModeSocket != 0:
+ m |= ModeSocket
+ case mode&os.ModeNamedPipe != 0:
+ m |= ModeNamedPipe
+ case mode&os.ModeCharDevice != 0:
+ m |= ModeCharacterDevice
+ case mode&os.ModeDevice != 0:
+ m |= ModeBlockDevice
+ default:
+ m |= ModeRegular
+ }
+ return m
+}
+
+// OSMode converts a p9.FileMode to an os.FileMode.
+func (m FileMode) OSMode() os.FileMode {
+ var osMode os.FileMode
+ osMode |= os.FileMode(m.Permissions())
+ switch {
+ case m.IsDir():
+ osMode |= os.ModeDir
+ case m.IsSymlink():
+ osMode |= os.ModeSymlink
+ case m.IsSocket():
+ osMode |= os.ModeSocket
+ case m.IsNamedPipe():
+ osMode |= os.ModeNamedPipe
+ case m.IsCharacterDevice():
+ osMode |= os.ModeCharDevice | os.ModeDevice
+ case m.IsBlockDevice():
+ osMode |= os.ModeDevice
+ }
+ return osMode
+}
+
+// UID represents a user ID.
+type UID uint32
+
+// Ok returns true if uid is not NoUID.
+func (uid UID) Ok() bool {
+ return uid != NoUID
+}
+
+// GID represents a group ID.
+type GID uint32
+
+// Ok returns true if gid is not NoGID.
+func (gid GID) Ok() bool {
+ return gid != NoGID
+}
+
+const (
+ // NoTag is a sentinel used to indicate no valid tag.
+ NoTag Tag = math.MaxUint16
+
+ // NoFID is a sentinel used to indicate no valid FID.
+ NoFID FID = math.MaxUint32
+
+ // NoUID is a sentinel used to indicate no valid UID.
+ NoUID UID = math.MaxUint32
+
+ // NoGID is a sentinel used to indicate no valid GID.
+ NoGID GID = math.MaxUint32
+)
+
+// MsgType is a type identifier.
+type MsgType uint8
+
+// MsgType declarations.
+const (
+ MsgTlerror MsgType = 6
+ MsgRlerror = 7
+ MsgTstatfs = 8
+ MsgRstatfs = 9
+ MsgTlopen = 12
+ MsgRlopen = 13
+ MsgTlcreate = 14
+ MsgRlcreate = 15
+ MsgTsymlink = 16
+ MsgRsymlink = 17
+ MsgTmknod = 18
+ MsgRmknod = 19
+ MsgTrename = 20
+ MsgRrename = 21
+ MsgTreadlink = 22
+ MsgRreadlink = 23
+ MsgTgetattr = 24
+ MsgRgetattr = 25
+ MsgTsetattr = 26
+ MsgRsetattr = 27
+ MsgTxattrwalk = 30
+ MsgRxattrwalk = 31
+ MsgTxattrcreate = 32
+ MsgRxattrcreate = 33
+ MsgTreaddir = 40
+ MsgRreaddir = 41
+ MsgTfsync = 50
+ MsgRfsync = 51
+ MsgTlink = 70
+ MsgRlink = 71
+ MsgTmkdir = 72
+ MsgRmkdir = 73
+ MsgTrenameat = 74
+ MsgRrenameat = 75
+ MsgTunlinkat = 76
+ MsgRunlinkat = 77
+ MsgTversion = 100
+ MsgRversion = 101
+ MsgTauth = 102
+ MsgRauth = 103
+ MsgTattach = 104
+ MsgRattach = 105
+ MsgTflush = 108
+ MsgRflush = 109
+ MsgTwalk = 110
+ MsgRwalk = 111
+ MsgTread = 116
+ MsgRread = 117
+ MsgTwrite = 118
+ MsgRwrite = 119
+ MsgTclunk = 120
+ MsgRclunk = 121
+ MsgTremove = 122
+ MsgRremove = 123
+ MsgTflushf = 124
+ MsgRflushf = 125
+ MsgTwalkgetattr = 126
+ MsgRwalkgetattr = 127
+ MsgTucreate = 128
+ MsgRucreate = 129
+ MsgTumkdir = 130
+ MsgRumkdir = 131
+ MsgTumknod = 132
+ MsgRumknod = 133
+ MsgTusymlink = 134
+ MsgRusymlink = 135
+ MsgTlconnect = 136
+ MsgRlconnect = 137
+ MsgTallocate = 138
+ MsgRallocate = 139
+)
+
+// QIDType represents the file type for QIDs.
+//
+// QIDType corresponds to the high 8 bits of a Plan 9 file mode.
+type QIDType uint8
+
+const (
+ // TypeDir represents a directory type.
+ TypeDir QIDType = 0x80
+
+ // TypeAppendOnly represents an append only file.
+ TypeAppendOnly QIDType = 0x40
+
+ // TypeExclusive represents an exclusive-use file.
+ TypeExclusive QIDType = 0x20
+
+ // TypeMount represents a mounted channel.
+ TypeMount QIDType = 0x10
+
+ // TypeAuth represents an authentication file.
+ TypeAuth QIDType = 0x08
+
+ // TypeTemporary represents a temporary file.
+ TypeTemporary QIDType = 0x04
+
+ // TypeSymlink represents a symlink.
+ TypeSymlink QIDType = 0x02
+
+ // TypeLink represents a hard link.
+ TypeLink QIDType = 0x01
+
+ // TypeRegular represents a regular file.
+ TypeRegular QIDType = 0x00
+)
+
+// QID is a unique file identifier.
+//
+// This may be embedded in other requests and responses.
+type QID struct {
+ // Type is the highest order byte of the file mode.
+ Type QIDType
+
+ // Version is an arbitrary server version number.
+ Version uint32
+
+ // Path is a unique server identifier for this path (e.g. inode).
+ Path uint64
+}
+
+// String implements fmt.Stringer.
+func (q QID) String() string {
+ return fmt.Sprintf("QID{Type: %d, Version: %d, Path: %d}", q.Type, q.Version, q.Path)
+}
+
+// Decode implements encoder.Decode.
+func (q *QID) Decode(b *buffer) {
+ q.Type = b.ReadQIDType()
+ q.Version = b.Read32()
+ q.Path = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (q *QID) Encode(b *buffer) {
+ b.WriteQIDType(q.Type)
+ b.Write32(q.Version)
+ b.Write64(q.Path)
+}
+
+// QIDGenerator is a simple generator for QIDs that atomically increments Path
+// values.
+type QIDGenerator struct {
+ // uids is an ever increasing value that can be atomically incremented
+ // to provide unique Path values for QIDs.
+ uids uint64
+}
+
+// Get returns a new 9P unique ID with a unique Path given a QID type.
+//
+// While the 9P spec allows Version to be incremented every time the file is
+// modified, we currently do not use the Version member for anything. Hence,
+// it is set to 0.
+func (q *QIDGenerator) Get(t QIDType) QID {
+ return QID{
+ Type: t,
+ Version: 0,
+ Path: atomic.AddUint64(&q.uids, 1),
+ }
+}
+
+// FSStat is used by statfs.
+type FSStat struct {
+ // Type is the filesystem type.
+ Type uint32
+
+ // BlockSize is the blocksize.
+ BlockSize uint32
+
+ // Blocks is the number of blocks.
+ Blocks uint64
+
+ // BlocksFree is the number of free blocks.
+ BlocksFree uint64
+
+ // BlocksAvailable is the number of blocks *available*.
+ BlocksAvailable uint64
+
+ // Files is the number of files available.
+ Files uint64
+
+ // FilesFree is the number of free file nodes.
+ FilesFree uint64
+
+ // FSID is the filesystem ID.
+ FSID uint64
+
+ // NameLength is the maximum name length.
+ NameLength uint32
+}
+
+// Decode implements encoder.Decode.
+func (f *FSStat) Decode(b *buffer) {
+ f.Type = b.Read32()
+ f.BlockSize = b.Read32()
+ f.Blocks = b.Read64()
+ f.BlocksFree = b.Read64()
+ f.BlocksAvailable = b.Read64()
+ f.Files = b.Read64()
+ f.FilesFree = b.Read64()
+ f.FSID = b.Read64()
+ f.NameLength = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (f *FSStat) Encode(b *buffer) {
+ b.Write32(f.Type)
+ b.Write32(f.BlockSize)
+ b.Write64(f.Blocks)
+ b.Write64(f.BlocksFree)
+ b.Write64(f.BlocksAvailable)
+ b.Write64(f.Files)
+ b.Write64(f.FilesFree)
+ b.Write64(f.FSID)
+ b.Write32(f.NameLength)
+}
+
+// AttrMask is a mask of attributes for getattr.
+type AttrMask struct {
+ Mode bool
+ NLink bool
+ UID bool
+ GID bool
+ RDev bool
+ ATime bool
+ MTime bool
+ CTime bool
+ INo bool
+ Size bool
+ Blocks bool
+ BTime bool
+ Gen bool
+ DataVersion bool
+}
+
+// Contains returns true if a contains all of the attributes masked as b.
+func (a AttrMask) Contains(b AttrMask) bool {
+ if b.Mode && !a.Mode {
+ return false
+ }
+ if b.NLink && !a.NLink {
+ return false
+ }
+ if b.UID && !a.UID {
+ return false
+ }
+ if b.GID && !a.GID {
+ return false
+ }
+ if b.RDev && !a.RDev {
+ return false
+ }
+ if b.ATime && !a.ATime {
+ return false
+ }
+ if b.MTime && !a.MTime {
+ return false
+ }
+ if b.CTime && !a.CTime {
+ return false
+ }
+ if b.INo && !a.INo {
+ return false
+ }
+ if b.Size && !a.Size {
+ return false
+ }
+ if b.Blocks && !a.Blocks {
+ return false
+ }
+ if b.BTime && !a.BTime {
+ return false
+ }
+ if b.Gen && !a.Gen {
+ return false
+ }
+ if b.DataVersion && !a.DataVersion {
+ return false
+ }
+ return true
+}
+
+// Empty returns true if no fields are masked.
+func (a AttrMask) Empty() bool {
+ return !a.Mode && !a.NLink && !a.UID && !a.GID && !a.RDev && !a.ATime && !a.MTime && !a.CTime && !a.INo && !a.Size && !a.Blocks && !a.BTime && !a.Gen && !a.DataVersion
+}
+
+// AttrMaskAll returns an AttrMask with all fields masked.
+func AttrMaskAll() AttrMask {
+ return AttrMask{
+ Mode: true,
+ NLink: true,
+ UID: true,
+ GID: true,
+ RDev: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ INo: true,
+ Size: true,
+ Blocks: true,
+ BTime: true,
+ Gen: true,
+ DataVersion: true,
+ }
+}
+
+// String implements fmt.Stringer.
+func (a AttrMask) String() string {
+ var masks []string
+ if a.Mode {
+ masks = append(masks, "Mode")
+ }
+ if a.NLink {
+ masks = append(masks, "NLink")
+ }
+ if a.UID {
+ masks = append(masks, "UID")
+ }
+ if a.GID {
+ masks = append(masks, "GID")
+ }
+ if a.RDev {
+ masks = append(masks, "RDev")
+ }
+ if a.ATime {
+ masks = append(masks, "ATime")
+ }
+ if a.MTime {
+ masks = append(masks, "MTime")
+ }
+ if a.CTime {
+ masks = append(masks, "CTime")
+ }
+ if a.INo {
+ masks = append(masks, "INo")
+ }
+ if a.Size {
+ masks = append(masks, "Size")
+ }
+ if a.Blocks {
+ masks = append(masks, "Blocks")
+ }
+ if a.BTime {
+ masks = append(masks, "BTime")
+ }
+ if a.Gen {
+ masks = append(masks, "Gen")
+ }
+ if a.DataVersion {
+ masks = append(masks, "DataVersion")
+ }
+ return fmt.Sprintf("AttrMask{with: %s}", strings.Join(masks, " "))
+}
+
+// Decode implements encoder.Decode.
+func (a *AttrMask) Decode(b *buffer) {
+ mask := b.Read64()
+ a.Mode = mask&0x00000001 != 0
+ a.NLink = mask&0x00000002 != 0
+ a.UID = mask&0x00000004 != 0
+ a.GID = mask&0x00000008 != 0
+ a.RDev = mask&0x00000010 != 0
+ a.ATime = mask&0x00000020 != 0
+ a.MTime = mask&0x00000040 != 0
+ a.CTime = mask&0x00000080 != 0
+ a.INo = mask&0x00000100 != 0
+ a.Size = mask&0x00000200 != 0
+ a.Blocks = mask&0x00000400 != 0
+ a.BTime = mask&0x00000800 != 0
+ a.Gen = mask&0x00001000 != 0
+ a.DataVersion = mask&0x00002000 != 0
+}
+
+// Encode implements encoder.Encode.
+func (a *AttrMask) Encode(b *buffer) {
+ var mask uint64
+ if a.Mode {
+ mask |= 0x00000001
+ }
+ if a.NLink {
+ mask |= 0x00000002
+ }
+ if a.UID {
+ mask |= 0x00000004
+ }
+ if a.GID {
+ mask |= 0x00000008
+ }
+ if a.RDev {
+ mask |= 0x00000010
+ }
+ if a.ATime {
+ mask |= 0x00000020
+ }
+ if a.MTime {
+ mask |= 0x00000040
+ }
+ if a.CTime {
+ mask |= 0x00000080
+ }
+ if a.INo {
+ mask |= 0x00000100
+ }
+ if a.Size {
+ mask |= 0x00000200
+ }
+ if a.Blocks {
+ mask |= 0x00000400
+ }
+ if a.BTime {
+ mask |= 0x00000800
+ }
+ if a.Gen {
+ mask |= 0x00001000
+ }
+ if a.DataVersion {
+ mask |= 0x00002000
+ }
+ b.Write64(mask)
+}
+
+// Attr is a set of attributes for getattr.
+type Attr struct {
+ Mode FileMode
+ UID UID
+ GID GID
+ NLink uint64
+ RDev uint64
+ Size uint64
+ BlockSize uint64
+ Blocks uint64
+ ATimeSeconds uint64
+ ATimeNanoSeconds uint64
+ MTimeSeconds uint64
+ MTimeNanoSeconds uint64
+ CTimeSeconds uint64
+ CTimeNanoSeconds uint64
+ BTimeSeconds uint64
+ BTimeNanoSeconds uint64
+ Gen uint64
+ DataVersion uint64
+}
+
+// String implements fmt.Stringer.
+func (a Attr) String() string {
+ return fmt.Sprintf("Attr{Mode: 0o%o, UID: %d, GID: %d, NLink: %d, RDev: %d, Size: %d, BlockSize: %d, Blocks: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}, CTime: {Sec: %d, NanoSec: %d}, BTime: {Sec: %d, NanoSec: %d}, Gen: %d, DataVersion: %d}",
+ a.Mode, a.UID, a.GID, a.NLink, a.RDev, a.Size, a.BlockSize, a.Blocks, a.ATimeSeconds, a.ATimeNanoSeconds, a.MTimeSeconds, a.MTimeNanoSeconds, a.CTimeSeconds, a.CTimeNanoSeconds, a.BTimeSeconds, a.BTimeNanoSeconds, a.Gen, a.DataVersion)
+}
+
+// Encode implements encoder.Encode.
+func (a *Attr) Encode(b *buffer) {
+ b.WriteFileMode(a.Mode)
+ b.WriteUID(a.UID)
+ b.WriteGID(a.GID)
+ b.Write64(a.NLink)
+ b.Write64(a.RDev)
+ b.Write64(a.Size)
+ b.Write64(a.BlockSize)
+ b.Write64(a.Blocks)
+ b.Write64(a.ATimeSeconds)
+ b.Write64(a.ATimeNanoSeconds)
+ b.Write64(a.MTimeSeconds)
+ b.Write64(a.MTimeNanoSeconds)
+ b.Write64(a.CTimeSeconds)
+ b.Write64(a.CTimeNanoSeconds)
+ b.Write64(a.BTimeSeconds)
+ b.Write64(a.BTimeNanoSeconds)
+ b.Write64(a.Gen)
+ b.Write64(a.DataVersion)
+}
+
+// Decode implements encoder.Decode.
+func (a *Attr) Decode(b *buffer) {
+ a.Mode = b.ReadFileMode()
+ a.UID = b.ReadUID()
+ a.GID = b.ReadGID()
+ a.NLink = b.Read64()
+ a.RDev = b.Read64()
+ a.Size = b.Read64()
+ a.BlockSize = b.Read64()
+ a.Blocks = b.Read64()
+ a.ATimeSeconds = b.Read64()
+ a.ATimeNanoSeconds = b.Read64()
+ a.MTimeSeconds = b.Read64()
+ a.MTimeNanoSeconds = b.Read64()
+ a.CTimeSeconds = b.Read64()
+ a.CTimeNanoSeconds = b.Read64()
+ a.BTimeSeconds = b.Read64()
+ a.BTimeNanoSeconds = b.Read64()
+ a.Gen = b.Read64()
+ a.DataVersion = b.Read64()
+}
+
+// StatToAttr converts a Linux syscall stat structure to an Attr.
+func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) {
+ attr := Attr{
+ UID: NoUID,
+ GID: NoGID,
+ }
+ if req.Mode {
+ // p9.FileMode corresponds to Linux mode_t.
+ attr.Mode = FileMode(s.Mode)
+ }
+ if req.NLink {
+ attr.NLink = s.Nlink
+ }
+ if req.UID {
+ attr.UID = UID(s.Uid)
+ }
+ if req.GID {
+ attr.GID = GID(s.Gid)
+ }
+ if req.RDev {
+ attr.RDev = s.Dev
+ }
+ if req.ATime {
+ attr.ATimeSeconds = uint64(s.Atim.Sec)
+ attr.ATimeNanoSeconds = uint64(s.Atim.Nsec)
+ }
+ if req.MTime {
+ attr.MTimeSeconds = uint64(s.Mtim.Sec)
+ attr.MTimeNanoSeconds = uint64(s.Mtim.Nsec)
+ }
+ if req.CTime {
+ attr.CTimeSeconds = uint64(s.Ctim.Sec)
+ attr.CTimeNanoSeconds = uint64(s.Ctim.Nsec)
+ }
+ if req.Size {
+ attr.Size = uint64(s.Size)
+ }
+ if req.Blocks {
+ attr.BlockSize = uint64(s.Blksize)
+ attr.Blocks = uint64(s.Blocks)
+ }
+
+ // Use the req field because we already have it.
+ req.BTime = false
+ req.Gen = false
+ req.DataVersion = false
+
+ return attr, req
+}
+
+// SetAttrMask specifies a valid mask for setattr.
+type SetAttrMask struct {
+ Permissions bool
+ UID bool
+ GID bool
+ Size bool
+ ATime bool
+ MTime bool
+ CTime bool
+ ATimeNotSystemTime bool
+ MTimeNotSystemTime bool
+}
+
+// IsSubsetOf returns whether s is a subset of m.
+func (s SetAttrMask) IsSubsetOf(m SetAttrMask) bool {
+ sb := s.bitmask()
+ sm := m.bitmask()
+ return sm|sb == sm
+}
+
+// String implements fmt.Stringer.
+func (s SetAttrMask) String() string {
+ var masks []string
+ if s.Permissions {
+ masks = append(masks, "Permissions")
+ }
+ if s.UID {
+ masks = append(masks, "UID")
+ }
+ if s.GID {
+ masks = append(masks, "GID")
+ }
+ if s.Size {
+ masks = append(masks, "Size")
+ }
+ if s.ATime {
+ masks = append(masks, "ATime")
+ }
+ if s.MTime {
+ masks = append(masks, "MTime")
+ }
+ if s.CTime {
+ masks = append(masks, "CTime")
+ }
+ if s.ATimeNotSystemTime {
+ masks = append(masks, "ATimeNotSystemTime")
+ }
+ if s.MTimeNotSystemTime {
+ masks = append(masks, "MTimeNotSystemTime")
+ }
+ return fmt.Sprintf("SetAttrMask{with: %s}", strings.Join(masks, " "))
+}
+
+// Empty returns true if no fields are masked.
+func (s SetAttrMask) Empty() bool {
+ return !s.Permissions && !s.UID && !s.GID && !s.Size && !s.ATime && !s.MTime && !s.CTime && !s.ATimeNotSystemTime && !s.MTimeNotSystemTime
+}
+
+// Decode implements encoder.Decode.
+func (s *SetAttrMask) Decode(b *buffer) {
+ mask := b.Read32()
+ s.Permissions = mask&0x00000001 != 0
+ s.UID = mask&0x00000002 != 0
+ s.GID = mask&0x00000004 != 0
+ s.Size = mask&0x00000008 != 0
+ s.ATime = mask&0x00000010 != 0
+ s.MTime = mask&0x00000020 != 0
+ s.CTime = mask&0x00000040 != 0
+ s.ATimeNotSystemTime = mask&0x00000080 != 0
+ s.MTimeNotSystemTime = mask&0x00000100 != 0
+}
+
+func (s SetAttrMask) bitmask() uint32 {
+ var mask uint32
+ if s.Permissions {
+ mask |= 0x00000001
+ }
+ if s.UID {
+ mask |= 0x00000002
+ }
+ if s.GID {
+ mask |= 0x00000004
+ }
+ if s.Size {
+ mask |= 0x00000008
+ }
+ if s.ATime {
+ mask |= 0x00000010
+ }
+ if s.MTime {
+ mask |= 0x00000020
+ }
+ if s.CTime {
+ mask |= 0x00000040
+ }
+ if s.ATimeNotSystemTime {
+ mask |= 0x00000080
+ }
+ if s.MTimeNotSystemTime {
+ mask |= 0x00000100
+ }
+ return mask
+}
+
+// Encode implements encoder.Encode.
+func (s *SetAttrMask) Encode(b *buffer) {
+ b.Write32(s.bitmask())
+}
+
+// SetAttr specifies a set of attributes for a setattr.
+type SetAttr struct {
+ Permissions FileMode
+ UID UID
+ GID GID
+ Size uint64
+ ATimeSeconds uint64
+ ATimeNanoSeconds uint64
+ MTimeSeconds uint64
+ MTimeNanoSeconds uint64
+}
+
+// String implements fmt.Stringer.
+func (s SetAttr) String() string {
+ return fmt.Sprintf("SetAttr{Permissions: 0o%o, UID: %d, GID: %d, Size: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}}", s.Permissions, s.UID, s.GID, s.Size, s.ATimeSeconds, s.ATimeNanoSeconds, s.MTimeSeconds, s.MTimeNanoSeconds)
+}
+
+// Decode implements encoder.Decode.
+func (s *SetAttr) Decode(b *buffer) {
+ s.Permissions = b.ReadPermissions()
+ s.UID = b.ReadUID()
+ s.GID = b.ReadGID()
+ s.Size = b.Read64()
+ s.ATimeSeconds = b.Read64()
+ s.ATimeNanoSeconds = b.Read64()
+ s.MTimeSeconds = b.Read64()
+ s.MTimeNanoSeconds = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (s *SetAttr) Encode(b *buffer) {
+ b.WritePermissions(s.Permissions)
+ b.WriteUID(s.UID)
+ b.WriteGID(s.GID)
+ b.Write64(s.Size)
+ b.Write64(s.ATimeSeconds)
+ b.Write64(s.ATimeNanoSeconds)
+ b.Write64(s.MTimeSeconds)
+ b.Write64(s.MTimeNanoSeconds)
+}
+
+// Apply applies this to the given Attr.
+func (a *Attr) Apply(mask SetAttrMask, attr SetAttr) {
+ if mask.Permissions {
+ a.Mode = a.Mode&^permissionsMask | (attr.Permissions & permissionsMask)
+ }
+ if mask.UID {
+ a.UID = attr.UID
+ }
+ if mask.GID {
+ a.GID = attr.GID
+ }
+ if mask.Size {
+ a.Size = attr.Size
+ }
+ if mask.ATime {
+ a.ATimeSeconds = attr.ATimeSeconds
+ a.ATimeNanoSeconds = attr.ATimeNanoSeconds
+ }
+ if mask.MTime {
+ a.MTimeSeconds = attr.MTimeSeconds
+ a.MTimeNanoSeconds = attr.MTimeNanoSeconds
+ }
+}
+
+// Dirent is used for readdir.
+type Dirent struct {
+ // QID is the entry QID.
+ QID QID
+
+ // Offset is the offset in the directory.
+ //
+ // This will be communicated back the original caller.
+ Offset uint64
+
+ // Type is the 9P type.
+ Type QIDType
+
+ // Name is the name of the entry (i.e. basename).
+ Name string
+}
+
+// String implements fmt.Stringer.
+func (d Dirent) String() string {
+ return fmt.Sprintf("Dirent{QID: %d, Offset: %d, Type: 0x%X, Name: %s}", d.QID, d.Offset, d.Type, d.Name)
+}
+
+// Decode implements encoder.Decode.
+func (d *Dirent) Decode(b *buffer) {
+ d.QID.Decode(b)
+ d.Offset = b.Read64()
+ d.Type = b.ReadQIDType()
+ d.Name = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (d *Dirent) Encode(b *buffer) {
+ d.QID.Encode(b)
+ b.Write64(d.Offset)
+ b.WriteQIDType(d.Type)
+ b.WriteString(d.Name)
+}
+
+// AllocateMode are possible modes to p9.File.Allocate().
+type AllocateMode struct {
+ KeepSize bool
+ PunchHole bool
+ NoHideStale bool
+ CollapseRange bool
+ ZeroRange bool
+ InsertRange bool
+ Unshare bool
+}
+
+// ToLinux converts to a value compatible with fallocate(2)'s mode.
+func (a *AllocateMode) ToLinux() uint32 {
+ rv := uint32(0)
+ if a.KeepSize {
+ rv |= unix.FALLOC_FL_KEEP_SIZE
+ }
+ if a.PunchHole {
+ rv |= unix.FALLOC_FL_PUNCH_HOLE
+ }
+ if a.NoHideStale {
+ rv |= unix.FALLOC_FL_NO_HIDE_STALE
+ }
+ if a.CollapseRange {
+ rv |= unix.FALLOC_FL_COLLAPSE_RANGE
+ }
+ if a.ZeroRange {
+ rv |= unix.FALLOC_FL_ZERO_RANGE
+ }
+ if a.InsertRange {
+ rv |= unix.FALLOC_FL_INSERT_RANGE
+ }
+ if a.Unshare {
+ rv |= unix.FALLOC_FL_UNSHARE_RANGE
+ }
+ return rv
+}
+
+// Decode implements encoder.Decode.
+func (a *AllocateMode) Decode(b *buffer) {
+ mask := b.Read32()
+ a.KeepSize = mask&0x01 != 0
+ a.PunchHole = mask&0x02 != 0
+ a.NoHideStale = mask&0x04 != 0
+ a.CollapseRange = mask&0x08 != 0
+ a.ZeroRange = mask&0x10 != 0
+ a.InsertRange = mask&0x20 != 0
+ a.Unshare = mask&0x40 != 0
+}
+
+// Encode implements encoder.Encode.
+func (a *AllocateMode) Encode(b *buffer) {
+ mask := uint32(0)
+ if a.KeepSize {
+ mask |= 0x01
+ }
+ if a.PunchHole {
+ mask |= 0x02
+ }
+ if a.NoHideStale {
+ mask |= 0x04
+ }
+ if a.CollapseRange {
+ mask |= 0x08
+ }
+ if a.ZeroRange {
+ mask |= 0x10
+ }
+ if a.InsertRange {
+ mask |= 0x20
+ }
+ if a.Unshare {
+ mask |= 0x40
+ }
+ b.Write32(mask)
+}
diff --git a/pkg/p9/p9_state_autogen.go b/pkg/p9/p9_state_autogen.go
new file mode 100755
index 000000000..0b9556862
--- /dev/null
+++ b/pkg/p9/p9_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package p9
+
diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go
new file mode 100644
index 000000000..f37ad4ab2
--- /dev/null
+++ b/pkg/p9/path_tree.go
@@ -0,0 +1,109 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "sync"
+)
+
+// pathNode is a single node in a path traversal.
+//
+// These are shared by all fidRefs that point to the same path.
+//
+// These are not synchronized because we allow certain operations (file walk)
+// to proceed without having to acquire a write lock. The lock in this
+// structure exists to synchronize high-level, semantic operations, such as the
+// simultaneous creation and deletion of a file.
+//
+// (+) below is the path component string.
+type pathNode struct {
+ mu sync.RWMutex // See above.
+ fidRefs sync.Map // => map[*fidRef]string(+)
+ children sync.Map // => map[string(+)]*pathNode
+ count int64
+}
+
+// pathNodeFor returns the path node for the given name, or a new one.
+//
+// Precondition: mu must be held in a readable fashion.
+func (p *pathNode) pathNodeFor(name string) *pathNode {
+ // Load the existing path node.
+ if pn, ok := p.children.Load(name); ok {
+ return pn.(*pathNode)
+ }
+
+ // Create a new pathNode for shared use.
+ pn, _ := p.children.LoadOrStore(name, new(pathNode))
+ return pn.(*pathNode)
+}
+
+// nameFor returns the name for the given fidRef.
+//
+// Precondition: mu must be held in a readable fashion.
+func (p *pathNode) nameFor(ref *fidRef) string {
+ if s, ok := p.fidRefs.Load(ref); ok {
+ return s.(string)
+ }
+
+ // This should not happen, don't proceed.
+ panic(fmt.Sprintf("expected name for %+v, none found", ref))
+}
+
+// addChild adds a child to the given pathNode.
+//
+// This applies only to an individual fidRef.
+//
+// Precondition: mu must be held in a writable fashion.
+func (p *pathNode) addChild(ref *fidRef, name string) {
+ if s, ok := p.fidRefs.Load(ref); ok {
+ // This should not happen, don't proceed.
+ panic(fmt.Sprintf("unexpected fidRef %+v with path %q, wanted %q", ref, s, name))
+ }
+
+ p.fidRefs.Store(ref, name)
+}
+
+// removeChild removes the given child.
+//
+// This applies only to an individual fidRef.
+//
+// Precondition: mu must be held in a writable fashion.
+func (p *pathNode) removeChild(ref *fidRef) {
+ p.fidRefs.Delete(ref)
+}
+
+// removeWithName removes all references with the given name.
+//
+// The original pathNode is returned by this function, and removed from this
+// pathNode. Any operations on the removed tree must use this value.
+//
+// The provided function is executed after removal.
+//
+// Precondition: mu must be held in a writable fashion.
+func (p *pathNode) removeWithName(name string, fn func(ref *fidRef)) *pathNode {
+ p.fidRefs.Range(func(key, value interface{}) bool {
+ if value.(string) == name {
+ p.fidRefs.Delete(key)
+ fn(key.(*fidRef))
+ }
+ return true
+ })
+
+ // Return the original path node.
+ origPathNode := p.pathNodeFor(name)
+ p.children.Delete(name)
+ return origPathNode
+}
diff --git a/pkg/p9/pool.go b/pkg/p9/pool.go
new file mode 100644
index 000000000..52de889e1
--- /dev/null
+++ b/pkg/p9/pool.go
@@ -0,0 +1,68 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "sync"
+)
+
+// pool is a simple allocator.
+//
+// It is used for both tags and FIDs.
+type pool struct {
+ mu sync.Mutex
+
+ // cache is the set of returned values.
+ cache []uint64
+
+ // start is the starting value (if needed).
+ start uint64
+
+ // max is the current maximum issued.
+ max uint64
+
+ // limit is the upper limit.
+ limit uint64
+}
+
+// Get gets a value from the pool.
+func (p *pool) Get() (uint64, bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Anything cached?
+ if len(p.cache) > 0 {
+ v := p.cache[len(p.cache)-1]
+ p.cache = p.cache[:len(p.cache)-1]
+ return v, true
+ }
+
+ // Over the limit?
+ if p.start == p.limit {
+ return 0, false
+ }
+
+ // Generate a new value.
+ v := p.start
+ p.start++
+ return v, true
+}
+
+// Put returns a value to the pool.
+func (p *pool) Put(v uint64) {
+ p.mu.Lock()
+ p.cache = append(p.cache, v)
+ p.mu.Unlock()
+}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
new file mode 100644
index 000000000..f377a6557
--- /dev/null
+++ b/pkg/p9/server.go
@@ -0,0 +1,575 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "io"
+ "runtime/debug"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+// Server is a 9p2000.L server.
+type Server struct {
+ // attacher provides the attach function.
+ attacher Attacher
+
+ // pathTree is the full set of paths opened on this server.
+ //
+ // These may be across different connections, but rename operations
+ // must be serialized globally for safely. There is a single pathTree
+ // for the entire server, and not per connection.
+ pathTree pathNode
+
+ // renameMu is a global lock protecting rename operations. With this
+ // lock, we can be certain that any given rename operation can safely
+ // acquire two path nodes in any order, as all other concurrent
+ // operations acquire at most a single node.
+ renameMu sync.RWMutex
+}
+
+// NewServer returns a new server.
+//
+func NewServer(attacher Attacher) *Server {
+ return &Server{
+ attacher: attacher,
+ }
+}
+
+// connState is the state for a single connection.
+type connState struct {
+ // server is the backing server.
+ server *Server
+
+ // sendMu is the send lock.
+ sendMu sync.Mutex
+
+ // conn is the connection.
+ conn *unet.Socket
+
+ // fids is the set of active FIDs.
+ //
+ // This is used to find FIDs for files.
+ fidMu sync.Mutex
+ fids map[FID]*fidRef
+
+ // tags is the set of active tags.
+ //
+ // The given channel is closed when the
+ // tag is finished with processing.
+ tagMu sync.Mutex
+ tags map[Tag]chan struct{}
+
+ // messageSize is the maximum message size. The server does not
+ // do automatic splitting of messages.
+ messageSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // recvOkay indicates that a receive may start.
+ recvOkay chan bool
+
+ // recvDone is signalled when a message is received.
+ recvDone chan error
+
+ // sendDone is signalled when a send is finished.
+ sendDone chan error
+}
+
+// fidRef wraps a node and tracks references.
+type fidRef struct {
+ // server is the associated server.
+ server *Server
+
+ // file is the associated File.
+ file File
+
+ // refs is an active refence count.
+ //
+ // The node above will be closed only when refs reaches zero.
+ refs int64
+
+ // openedMu protects opened and openFlags.
+ openedMu sync.Mutex
+
+ // opened indicates whether this has been opened already.
+ //
+ // This is updated in handlers.go.
+ opened bool
+
+ // mode is the fidRef's mode from the walk. Only the type bits are
+ // valid, the permissions may change. This is used to sanity check
+ // operations on this element, and prevent walks across
+ // non-directories.
+ mode FileMode
+
+ // openFlags is the mode used in the open.
+ //
+ // This is updated in handlers.go.
+ openFlags OpenFlags
+
+ // pathNode is the current pathNode for this FID.
+ pathNode *pathNode
+
+ // parent is the parent fidRef. We hold on to a parent reference to
+ // ensure that hooks, such as Renamed, can be executed safely by the
+ // server code.
+ //
+ // Note that parent cannot be changed without holding both the global
+ // rename lock and a writable lock on the associated pathNode for this
+ // fidRef. Holding either of these locks is sufficient to examine
+ // parent safely.
+ //
+ // The parent will be nil for root fidRefs, and non-nil otherwise. The
+ // method maybeParent can be used to return a cyclical reference, and
+ // isRoot should be used to check for root over looking at parent
+ // directly.
+ parent *fidRef
+
+ // deleted indicates that the backing file has been deleted. We stop
+ // many operations at the API level if they are incompatible with a
+ // file that has already been unlinked.
+ deleted uint32
+}
+
+// OpenFlags returns the flags the file was opened with and true iff the fid was opened previously.
+func (f *fidRef) OpenFlags() (OpenFlags, bool) {
+ f.openedMu.Lock()
+ defer f.openedMu.Unlock()
+ return f.openFlags, f.opened
+}
+
+// IncRef increases the references on a fid.
+func (f *fidRef) IncRef() {
+ atomic.AddInt64(&f.refs, 1)
+}
+
+// DecRef should be called when you're finished with a fid.
+func (f *fidRef) DecRef() {
+ if atomic.AddInt64(&f.refs, -1) == 0 {
+ f.file.Close()
+
+ // Drop the parent reference.
+ //
+ // Since this fidRef is guaranteed to be non-discoverable when
+ // the references reach zero, we don't need to worry about
+ // clearing the parent.
+ if f.parent != nil {
+ // If we've been previously deleted, this removing this
+ // ref is a no-op. That's expected.
+ f.parent.pathNode.removeChild(f)
+ f.parent.DecRef()
+ }
+ }
+}
+
+// isDeleted returns true if this fidRef has been deleted.
+func (f *fidRef) isDeleted() bool {
+ return atomic.LoadUint32(&f.deleted) != 0
+}
+
+// isRoot indicates whether this is a root fid.
+func (f *fidRef) isRoot() bool {
+ return f.parent == nil
+}
+
+// maybeParent returns a cyclic reference for roots, and the parent otherwise.
+func (f *fidRef) maybeParent() *fidRef {
+ if f.parent != nil {
+ return f.parent
+ }
+ return f // Root has itself.
+}
+
+// notifyDelete marks all fidRefs as deleted.
+//
+// Precondition: the write lock must be held on the given pathNode.
+func notifyDelete(pn *pathNode) {
+ // Call on all local references.
+ pn.fidRefs.Range(func(key, _ interface{}) bool {
+ ref := key.(*fidRef)
+ atomic.StoreUint32(&ref.deleted, 1)
+ return true
+ })
+
+ // Call on all subtrees.
+ pn.children.Range(func(_, value interface{}) bool {
+ notifyDelete(value.(*pathNode))
+ return true
+ })
+}
+
+// markChildDeleted marks all children below the given name as deleted.
+//
+// Precondition: this must be called via safelyWrite or safelyGlobal.
+func (f *fidRef) markChildDeleted(name string) {
+ origPathNode := f.pathNode.removeWithName(name, func(ref *fidRef) {
+ atomic.StoreUint32(&ref.deleted, 1)
+ })
+
+ // Mark everything below as deleted.
+ notifyDelete(origPathNode)
+}
+
+// notifyNameChange calls the relevant Renamed method on all nodes in the path,
+// recursively. Note that this applies only for subtrees, as these
+// notifications do not apply to the actual file whose name has changed.
+//
+// Precondition: the write lock must be held on the given pathNode.
+func notifyNameChange(pn *pathNode) {
+ // Call on all local references.
+ pn.fidRefs.Range(func(key, value interface{}) bool {
+ ref := key.(*fidRef)
+ name := value.(string)
+ ref.file.Renamed(ref.parent.file, name)
+ return true
+ })
+
+ // Call on all subtrees.
+ pn.children.Range(func(_, value interface{}) bool {
+ notifyNameChange(value.(*pathNode))
+ return true
+ })
+}
+
+// renameChildTo renames the given child to the target.
+//
+// Precondition: this must be called via safelyGlobal.
+func (f *fidRef) renameChildTo(oldName string, target *fidRef, newName string) {
+ target.markChildDeleted(newName)
+ origPathNode := f.pathNode.removeWithName(oldName, func(ref *fidRef) {
+ ref.parent.DecRef() // Drop original reference.
+ ref.parent = target // Change parent.
+ ref.parent.IncRef() // Acquire new one.
+ target.pathNode.addChild(ref, newName)
+ ref.file.Renamed(target.file, newName)
+ })
+
+ // Replace the previous (now deleted) path node.
+ f.pathNode.children.Store(newName, origPathNode)
+
+ // Call Renamed on everything above.
+ notifyNameChange(origPathNode)
+}
+
+// safelyRead executes the given operation with the local path node locked.
+// This implies that paths will not change during the operation.
+func (f *fidRef) safelyRead(fn func() error) (err error) {
+ f.server.renameMu.RLock()
+ defer f.server.renameMu.RUnlock()
+ f.pathNode.mu.RLock()
+ defer f.pathNode.mu.RUnlock()
+ return fn()
+}
+
+// safelyWrite executes the given operation with the local path node locked in
+// a writable fashion. This implies some paths may change.
+func (f *fidRef) safelyWrite(fn func() error) (err error) {
+ f.server.renameMu.RLock()
+ defer f.server.renameMu.RUnlock()
+ f.pathNode.mu.Lock()
+ defer f.pathNode.mu.Unlock()
+ return fn()
+}
+
+// safelyGlobal executes the given operation with the global path lock held.
+func (f *fidRef) safelyGlobal(fn func() error) (err error) {
+ f.server.renameMu.Lock()
+ defer f.server.renameMu.Unlock()
+ return fn()
+}
+
+// LookupFID finds the given FID.
+//
+// You should call fid.DecRef when you are finished using the fid.
+func (cs *connState) LookupFID(fid FID) (*fidRef, bool) {
+ cs.fidMu.Lock()
+ defer cs.fidMu.Unlock()
+ fidRef, ok := cs.fids[fid]
+ if ok {
+ fidRef.IncRef()
+ return fidRef, true
+ }
+ return nil, false
+}
+
+// InsertFID installs the given FID.
+//
+// This fid starts with a reference count of one. If a FID exists in
+// the slot already it is closed, per the specification.
+func (cs *connState) InsertFID(fid FID, newRef *fidRef) {
+ cs.fidMu.Lock()
+ defer cs.fidMu.Unlock()
+ origRef, ok := cs.fids[fid]
+ if ok {
+ defer origRef.DecRef()
+ }
+ newRef.IncRef()
+ cs.fids[fid] = newRef
+}
+
+// DeleteFID removes the given FID.
+//
+// This simply removes it from the map and drops a reference.
+func (cs *connState) DeleteFID(fid FID) bool {
+ cs.fidMu.Lock()
+ defer cs.fidMu.Unlock()
+ fidRef, ok := cs.fids[fid]
+ if !ok {
+ return false
+ }
+ delete(cs.fids, fid)
+ fidRef.DecRef()
+ return true
+}
+
+// StartTag starts handling the tag.
+//
+// False is returned if this tag is already active.
+func (cs *connState) StartTag(t Tag) bool {
+ cs.tagMu.Lock()
+ defer cs.tagMu.Unlock()
+ _, ok := cs.tags[t]
+ if ok {
+ return false
+ }
+ cs.tags[t] = make(chan struct{})
+ return true
+}
+
+// ClearTag finishes handling a tag.
+func (cs *connState) ClearTag(t Tag) {
+ cs.tagMu.Lock()
+ defer cs.tagMu.Unlock()
+ ch, ok := cs.tags[t]
+ if !ok {
+ // Should never happen.
+ panic("unused tag cleared")
+ }
+ delete(cs.tags, t)
+
+ // Notify.
+ close(ch)
+}
+
+// WaitTag waits for a tag to finish.
+func (cs *connState) WaitTag(t Tag) {
+ cs.tagMu.Lock()
+ ch, ok := cs.tags[t]
+ cs.tagMu.Unlock()
+ if !ok {
+ return
+ }
+
+ // Wait for close.
+ <-ch
+}
+
+// handleRequest handles a single request.
+//
+// The recvDone channel is signaled when recv is done (with a error if
+// necessary). The sendDone channel is signaled with the result of the send.
+func (cs *connState) handleRequest() {
+ messageSize := atomic.LoadUint32(&cs.messageSize)
+ if messageSize == 0 {
+ // Default or not yet negotiated.
+ messageSize = maximumLength
+ }
+
+ // Receive a message.
+ tag, m, err := recv(cs.conn, messageSize, msgRegistry.get)
+ if errSocket, ok := err.(ErrSocket); ok {
+ // Connection problem; stop serving.
+ cs.recvDone <- errSocket.error
+ return
+ }
+
+ // Signal receive is done.
+ cs.recvDone <- nil
+
+ // Deal with other errors.
+ if err != nil && err != io.EOF {
+ // If it's not a connection error, but some other protocol error,
+ // we can send a response immediately.
+ cs.sendMu.Lock()
+ err := send(cs.conn, tag, newErr(err))
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
+ return
+ }
+
+ // Try to start the tag.
+ if !cs.StartTag(tag) {
+ // Nothing we can do at this point; client is bogus.
+ log.Debugf("no valid tag [%05d]", tag)
+ cs.sendDone <- ErrNoValidMessage
+ return
+ }
+
+ // Handle the message.
+ var r message // r is the response.
+ defer func() {
+ if r == nil {
+ // Don't allow a panic to propagate.
+ recover()
+
+ // Include a useful log message.
+ log.Warningf("panic in handler: %s", debug.Stack())
+
+ // Wrap in an EFAULT error; we don't really have a
+ // better way to describe this kind of error. It will
+ // usually manifest as a result of the test framework.
+ r = newErr(syscall.EFAULT)
+ }
+
+ // Clear the tag before sending. That's because as soon as this
+ // hits the wire, the client can legally send another message
+ // with the same tag.
+ cs.ClearTag(tag)
+
+ // Send back the result.
+ cs.sendMu.Lock()
+ err = send(cs.conn, tag, r)
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
+ }()
+ if handler, ok := m.(handler); ok {
+ // Call the message handler.
+ r = handler.handle(cs)
+ } else {
+ // Produce an ENOSYS error.
+ r = newErr(syscall.ENOSYS)
+ }
+ msgRegistry.put(m)
+ m = nil // 'm' should not be touched after this point.
+}
+
+func (cs *connState) handleRequests() {
+ for range cs.recvOkay {
+ cs.handleRequest()
+ }
+}
+
+func (cs *connState) stop() {
+ // Close all channels.
+ close(cs.recvOkay)
+ close(cs.recvDone)
+ close(cs.sendDone)
+
+ for _, fidRef := range cs.fids {
+ // Drop final reference in the FID table. Note this should
+ // always close the file, since we've ensured that there are no
+ // handlers running via the wait for Pending => 0 below.
+ fidRef.DecRef()
+ }
+
+ // Ensure the connection is closed.
+ cs.conn.Close()
+}
+
+// service services requests concurrently.
+func (cs *connState) service() error {
+ // Pending is the number of handlers that have finished receiving but
+ // not finished processing requests. These must be waiting on properly
+ // below. See the next comment for an explanation of the loop.
+ pending := 0
+
+ // Start the first request handler.
+ go cs.handleRequests() // S/R-SAFE: Irrelevant.
+ cs.recvOkay <- true
+
+ // We loop and make sure there's always one goroutine waiting for a new
+ // request. We process all the data for a single request in one
+ // goroutine however, to ensure the best turnaround time possible.
+ for {
+ select {
+ case err := <-cs.recvDone:
+ if err != nil {
+ // Wait for pending handlers.
+ for i := 0; i < pending; i++ {
+ <-cs.sendDone
+ }
+ return err
+ }
+
+ // This handler is now pending.
+ pending++
+
+ // Kick the next receiver, or start a new handler
+ // if no receiver is currently waiting.
+ select {
+ case cs.recvOkay <- true:
+ default:
+ go cs.handleRequests() // S/R-SAFE: Irrelevant.
+ cs.recvOkay <- true
+ }
+
+ case <-cs.sendDone:
+ // This handler is finished.
+ pending--
+
+ // Error sending a response? Nothing can be done.
+ //
+ // We don't terminate on a send error though, since
+ // we still have a pending receive. The error would
+ // have been logged above, we just ignore it here.
+ }
+ }
+}
+
+// Handle handles a single connection.
+func (s *Server) Handle(conn *unet.Socket) error {
+ cs := &connState{
+ server: s,
+ conn: conn,
+ fids: make(map[FID]*fidRef),
+ tags: make(map[Tag]chan struct{}),
+ recvOkay: make(chan bool),
+ recvDone: make(chan error, 10),
+ sendDone: make(chan error, 10),
+ }
+ defer cs.stop()
+ return cs.service()
+}
+
+// Serve handles requests from the bound socket.
+//
+// The passed serverSocket _must_ be created in packet mode.
+func (s *Server) Serve(serverSocket *unet.ServerSocket) error {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ for {
+ conn, err := serverSocket.Accept()
+ if err != nil {
+ // Something went wrong.
+ //
+ // Socket closed?
+ return err
+ }
+
+ wg.Add(1)
+ go func(conn *unet.Socket) { // S/R-SAFE: Irrelevant.
+ s.Handle(conn)
+ wg.Done()
+ }(conn)
+ }
+}
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
new file mode 100644
index 000000000..ef59077ff
--- /dev/null
+++ b/pkg/p9/transport.go
@@ -0,0 +1,342 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+// ErrSocket is returned in cases of a socket issue.
+//
+// This may be treated differently than other errors.
+type ErrSocket struct {
+ // error is the socket error.
+ error
+}
+
+// ErrMessageTooLarge indicates the size was larger than reasonable.
+type ErrMessageTooLarge struct {
+ size uint32
+ msize uint32
+}
+
+// Error returns a sensible error.
+func (e *ErrMessageTooLarge) Error() string {
+ return fmt.Sprintf("message too large for fixed buffer: size is %d, limit is %d", e.size, e.msize)
+}
+
+// ErrNoValidMessage indicates no valid message could be decoded.
+var ErrNoValidMessage = errors.New("buffer contained no valid message")
+
+const (
+ // headerLength is the number of bytes required for a header.
+ headerLength uint32 = 7
+
+ // maximumLength is the largest possible message.
+ maximumLength uint32 = 4 * 1024 * 1024
+
+ // initialBufferLength is the initial data buffer we allocate.
+ initialBufferLength uint32 = 64
+)
+
+var dataPool = sync.Pool{
+ New: func() interface{} {
+ // These buffers are used for decoding without a payload.
+ return make([]byte, initialBufferLength)
+ },
+}
+
+// send sends the given message over the socket.
+func send(s *unet.Socket, tag Tag, m message) error {
+ data := dataPool.Get().([]byte)
+ dataBuf := buffer{data: data[:0]}
+
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
+ }
+
+ // Encode the message. The buffer will grow automatically.
+ m.Encode(&dataBuf)
+
+ // Get our vectors to send.
+ var hdr [headerLength]byte
+ vecs := make([][]byte, 0, 3)
+ vecs = append(vecs, hdr[:])
+ if len(dataBuf.data) > 0 {
+ vecs = append(vecs, dataBuf.data)
+ }
+ totalLength := headerLength + uint32(len(dataBuf.data))
+
+ // Is there a payload?
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ if len(p) > 0 {
+ vecs = append(vecs, p)
+ totalLength += uint32(len(p))
+ }
+ }
+
+ // Construct the header.
+ headerBuf := buffer{data: hdr[:0]}
+ headerBuf.Write32(totalLength)
+ headerBuf.WriteMsgType(m.Type())
+ headerBuf.WriteTag(tag)
+
+ // Pack any files if necessary.
+ w := s.Writer(true)
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ defer f.Close()
+ // Pack the file into the message.
+ w.PackFDs(f.FD())
+ }
+ }
+
+ for n := 0; n < int(totalLength); {
+ cur, err := w.WriteVec(vecs)
+ if err != nil {
+ return ErrSocket{err}
+ }
+ n += cur
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(vecs[0]) <= cur-consumed {
+ consumed += len(vecs[0])
+ vecs = vecs[1:]
+ } else {
+ vecs[0] = vecs[0][cur-consumed:]
+ break
+ }
+ }
+
+ if n > 0 && n < int(totalLength) {
+ // Don't resend any control message.
+ w.UnpackFDs()
+ }
+ }
+
+ // All set.
+ dataPool.Put(dataBuf.data)
+ return nil
+}
+
+// lookupTagAndType looks up an existing message or creates a new one.
+//
+// This is called by recv after decoding the header. Any error returned will be
+// propagating back to the caller. You may use messageByType directly as a
+// lookupTagAndType function (by design).
+type lookupTagAndType func(tag Tag, t MsgType) (message, error)
+
+// recv decodes a message from the socket.
+//
+// This is done in two parts, and is thus not safe for multiple callers.
+//
+// On a socket error, the special error type ErrSocket is returned.
+//
+// The tag value NoTag will always be returned if err is non-nil.
+func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, error) {
+ // Read a header.
+ //
+ // Since the send above is atomic, we must always receive control
+ // messages along with the header. This means we need to be careful
+ // about closing FDs during errors to prevent leaks.
+ var hdr [headerLength]byte
+ r := s.Reader(true)
+ r.EnableFDs(1)
+
+ n, err := r.ReadVec([][]byte{hdr[:]})
+ if err != nil && (n == 0 || err != io.EOF) {
+ r.CloseFDs()
+ return NoTag, nil, ErrSocket{err}
+ }
+
+ fds, err := r.ExtractFDs()
+ if err != nil {
+ return NoTag, nil, ErrSocket{err}
+ }
+ defer func() {
+ // Close anything left open. The case where
+ // fds are caught and used is handled below,
+ // and the fds variable will be set to nil.
+ for _, fd := range fds {
+ syscall.Close(fd)
+ }
+ }()
+ r.EnableFDs(0)
+
+ // Continuing reading for a short header.
+ for n < int(headerLength) {
+ cur, err := r.ReadVec([][]byte{hdr[n:]})
+ if err != nil && (cur == 0 || err != io.EOF) {
+ return NoTag, nil, ErrSocket{err}
+ }
+ n += cur
+ }
+
+ // Decode the header.
+ headerBuf := buffer{data: hdr[:]}
+ size := headerBuf.Read32()
+ t := headerBuf.ReadMsgType()
+ tag := headerBuf.ReadTag()
+ if size < headerLength {
+ // The message is too small.
+ //
+ // See above: it's probably screwed.
+ return NoTag, nil, ErrSocket{ErrNoValidMessage}
+ }
+ if size > maximumLength || size > msize {
+ // The message is too big.
+ return NoTag, nil, ErrSocket{&ErrMessageTooLarge{size, msize}}
+ }
+ remaining := size - headerLength
+
+ // Find our message to decode.
+ m, err := lookup(tag, t)
+ if err != nil {
+ // Throw away the contents of this message.
+ if remaining > 0 {
+ io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)})
+ }
+ return tag, nil, err
+ }
+
+ // Not yet initialized.
+ var dataBuf buffer
+
+ // Read the rest of the payload.
+ //
+ // This requires some special care to ensure that the vectors all line
+ // up the way they should. We do this to minimize copying data around.
+ var vecs [][]byte
+ if payloader, ok := m.(payloader); ok {
+ fixedSize := payloader.FixedSize()
+
+ // Do we need more than there is?
+ if fixedSize > remaining {
+ // This is not a valid message.
+ if remaining > 0 {
+ io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)})
+ }
+ return NoTag, nil, ErrNoValidMessage
+ }
+
+ if fixedSize != 0 {
+ // Pull a data buffer from the pool.
+ data := dataPool.Get().([]byte)
+ if int(fixedSize) > len(data) {
+ // Create a larger data buffer, ensuring
+ // sufficient capicity for the message.
+ data = make([]byte, fixedSize)
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data}
+ vecs = append(vecs, data)
+ } else {
+ // Limit the data buffer, and make sure it
+ // gets filled before the payload buffer.
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data[:fixedSize]}
+ vecs = append(vecs, data[:fixedSize])
+ }
+ }
+
+ // Include the payload.
+ p := payloader.Payload()
+ if p == nil || len(p) != int(remaining-fixedSize) {
+ p = make([]byte, remaining-fixedSize)
+ payloader.SetPayload(p)
+ }
+ if len(p) > 0 {
+ vecs = append(vecs, p)
+ }
+ } else if remaining != 0 {
+ // Pull a data buffer from the pool.
+ data := dataPool.Get().([]byte)
+ if int(remaining) > len(data) {
+ // Create a larger data buffer.
+ data = make([]byte, remaining)
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data}
+ vecs = append(vecs, data)
+ } else {
+ // Limit the data buffer.
+ defer dataPool.Put(data)
+ dataBuf = buffer{data: data[:remaining]}
+ vecs = append(vecs, data[:remaining])
+ }
+ }
+
+ if len(vecs) > 0 {
+ // Read the rest of the message.
+ //
+ // No need to handle a control message.
+ r := s.Reader(true)
+ for n := 0; n < int(remaining); {
+ cur, err := r.ReadVec(vecs)
+ if err != nil && (cur == 0 || err != io.EOF) {
+ return NoTag, nil, ErrSocket{err}
+ }
+ n += cur
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(vecs[0]) <= cur-consumed {
+ consumed += len(vecs[0])
+ vecs = vecs[1:]
+ } else {
+ vecs[0] = vecs[0][cur-consumed:]
+ break
+ }
+ }
+ }
+ }
+
+ // Decode the message data.
+ m.Decode(&dataBuf)
+ if dataBuf.isOverrun() {
+ // No need to drain the socket.
+ return NoTag, nil, ErrNoValidMessage
+ }
+
+ // Save the file, if any came out.
+ if filer, ok := m.(filer); ok && len(fds) > 0 {
+ // Set the file object.
+ filer.SetFilePayload(fd.New(fds[0]))
+
+ // Close the rest. We support only one.
+ for i := 1; i < len(fds); i++ {
+ syscall.Close(fds[i])
+ }
+
+ // Don't close in the defer.
+ fds = nil
+ }
+
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
+ }
+
+ // All set.
+ return tag, m, nil
+}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
new file mode 100644
index 000000000..c2a2885ae
--- /dev/null
+++ b/pkg/p9/version.go
@@ -0,0 +1,150 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+const (
+ // highestSupportedVersion is the highest supported version X in a
+ // version string of the format 9P2000.L.Google.X.
+ //
+ // Clients are expected to start requesting this version number and
+ // to continuously decrement it until a Tversion request succeeds.
+ highestSupportedVersion uint32 = 7
+
+ // lowestSupportedVersion is the lowest supported version X in a
+ // version string of the format 9P2000.L.Google.X.
+ //
+ // Clients are free to send a Tversion request at a version below this
+ // value but are expected to encounter an Rlerror in response.
+ lowestSupportedVersion uint32 = 0
+
+ // baseVersion is the base version of 9P that this package must always
+ // support. It is equivalent to 9P2000.L.Google.0.
+ baseVersion = "9P2000.L"
+)
+
+// HighestVersionString returns the highest possible version string that a client
+// may request or a server may support.
+func HighestVersionString() string {
+ return versionString(highestSupportedVersion)
+}
+
+// parseVersion parses a Tversion version string into a numeric version number
+// if the version string is supported by p9. Otherwise returns (0, false).
+//
+// From Tversion(9P): "Version strings are defined such that, if the client string
+// contains one or more period characters, the initial substring up to but not
+// including any single period in the version string defines a version of the protocol."
+//
+// p9 intentionally diverges from this and always requires that the version string
+// start with 9P2000.L to express that it is always compatible with 9P2000.L. The
+// only supported versions extensions are of the format 9p2000.L.Google.X where X
+// is an ever increasing version counter.
+//
+// Version 9P2000.L.Google.0 implies 9P2000.L.
+//
+// New versions must always be a strict superset of 9P2000.L. A version increase must
+// define a predicate representing the feature extension introduced by that version. The
+// predicate must be commented and should take the format:
+//
+// // VersionSupportsX returns true if version v supports X and must be checked when ...
+// func VersionSupportsX(v int32) bool {
+// ...
+// )
+func parseVersion(str string) (uint32, bool) {
+ // Special case the base version which lacks the ".Google.X" suffix. This
+ // version always means version 0.
+ if str == baseVersion {
+ return 0, true
+ }
+ substr := strings.Split(str, ".")
+ if len(substr) != 4 {
+ return 0, false
+ }
+ if substr[0] != "9P2000" || substr[1] != "L" || substr[2] != "Google" || len(substr[3]) == 0 {
+ return 0, false
+ }
+ version, err := strconv.ParseUint(substr[3], 10, 32)
+ if err != nil {
+ return 0, false
+ }
+ return uint32(version), true
+}
+
+// versionString formats a p9 version number into a Tversion version string.
+func versionString(version uint32) string {
+ // Special case the base version so that clients expecting this string
+ // instead of the 9P2000.L.Google.0 equivalent get it. This is important
+ // for backwards compatibility with legacy servers that check for exactly
+ // the baseVersion and allow nothing else.
+ if version == 0 {
+ return baseVersion
+ }
+ return fmt.Sprintf("9P2000.L.Google.%d", version)
+}
+
+// VersionSupportsTflushf returns true if version v supports the Tflushf message.
+// This predicate must be checked by clients before attempting to make a Tflushf
+// request. If this predicate returns false, then clients may safely no-op.
+func VersionSupportsTflushf(v uint32) bool {
+ return v >= 1
+}
+
+// versionSupportsTwalkgetattr returns true if version v supports the
+// Twalkgetattr message. This predicate must be checked by clients before
+// attempting to make a Twalkgetattr request.
+func versionSupportsTwalkgetattr(v uint32) bool {
+ return v >= 2
+}
+
+// versionSupportsTucreation returns true if version v supports the Tucreation
+// messages (Tucreate, Tusymlink, Tumkdir, Tumknod). This predicate must be
+// checked by clients before attempting to make a Tucreation request.
+// If Tucreation messages are not supported, their non-UID supporting
+// counterparts (Tlcreate, Tsymlink, Tmkdir, Tmknod) should be used.
+func versionSupportsTucreation(v uint32) bool {
+ return v >= 3
+}
+
+// VersionSupportsConnect returns true if version v supports the Tlconnect
+// message. This predicate must be checked by clients
+// before attempting to make a Tlconnect request. If Tlconnect messages are not
+// supported, Tlopen should be used.
+func VersionSupportsConnect(v uint32) bool {
+ return v >= 4
+}
+
+// VersionSupportsAnonymous returns true if version v supports Tlconnect
+// with the AnonymousSocket mode. This predicate must be checked by clients
+// before attempting to use the AnonymousSocket Tlconnect mode.
+func VersionSupportsAnonymous(v uint32) bool {
+ return v >= 5
+}
+
+// VersionSupportsMultiUser returns true if version v supports multi-user fake
+// directory permissions and ID values.
+func VersionSupportsMultiUser(v uint32) bool {
+ return v >= 6
+}
+
+// versionSupportsTallocate returns true if version v supports Allocate().
+func versionSupportsTallocate(v uint32) bool {
+ return v >= 7
+}
diff --git a/pkg/rand/rand.go b/pkg/rand/rand.go
new file mode 100644
index 000000000..a2714784d
--- /dev/null
+++ b/pkg/rand/rand.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux
+
+// Package rand implements a cryptographically secure pseudorandom number
+// generator.
+package rand
+
+import "crypto/rand"
+
+// Reader is the default reader.
+var Reader = rand.Reader
+
+// Read implements io.Reader.Read.
+func Read(b []byte) (int, error) {
+ return rand.Read(b)
+}
diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go
new file mode 100644
index 000000000..2b92db3e6
--- /dev/null
+++ b/pkg/rand/rand_linux.go
@@ -0,0 +1,62 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package rand implements a cryptographically secure pseudorandom number
+// generator.
+package rand
+
+import (
+ "crypto/rand"
+ "io"
+ "sync"
+
+ "golang.org/x/sys/unix"
+)
+
+// reader implements an io.Reader that returns pseudorandom bytes.
+type reader struct {
+ once sync.Once
+ useGetrandom bool
+}
+
+// Read implements io.Reader.Read.
+func (r *reader) Read(p []byte) (int, error) {
+ r.once.Do(func() {
+ _, err := unix.Getrandom(p, 0)
+ if err != unix.ENOSYS {
+ r.useGetrandom = true
+ }
+ })
+
+ if r.useGetrandom {
+ return unix.Getrandom(p, 0)
+ }
+ return rand.Read(p)
+}
+
+// Reader is the default reader.
+var Reader io.Reader = &reader{}
+
+// Read reads from the default reader.
+func Read(b []byte) (int, error) {
+ return io.ReadFull(Reader, b)
+}
+
+// Init can be called to make sure /dev/urandom is pre-opened on kernels that
+// do not support getrandom(2).
+func Init() error {
+ p := make([]byte, 1)
+ _, err := Read(p)
+ return err
+}
diff --git a/pkg/rand/rand_state_autogen.go b/pkg/rand/rand_state_autogen.go
new file mode 100755
index 000000000..e46e9ec7e
--- /dev/null
+++ b/pkg/rand/rand_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package rand
+
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
new file mode 100644
index 000000000..20f515391
--- /dev/null
+++ b/pkg/refs/refcounter.go
@@ -0,0 +1,303 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package refs defines an interface for reference counted objects. It
+// also provides a drop-in implementation called AtomicRefCount.
+package refs
+
+import (
+ "reflect"
+ "sync"
+ "sync/atomic"
+)
+
+// RefCounter is the interface to be implemented by objects that are reference
+// counted.
+type RefCounter interface {
+ // IncRef increments the reference counter on the object.
+ IncRef()
+
+ // DecRef decrements the reference counter on the object.
+ //
+ // Note that AtomicRefCounter.DecRef() does not support destructors.
+ // If a type has a destructor, it must implement its own DecRef()
+ // method and call AtomicRefCounter.DecRefWithDestructor(destructor).
+ DecRef()
+
+ // TryIncRef attempts to increase the reference counter on the object,
+ // but may fail if all references have already been dropped. This
+ // should be used only in special circumstances, such as WeakRefs.
+ TryIncRef() bool
+
+ // addWeakRef adds the given weak reference. Note that you should have a
+ // reference to the object when calling this method.
+ addWeakRef(*WeakRef)
+
+ // dropWeakRef drops the given weak reference. Note that you should have
+ // a reference to the object when calling this method.
+ dropWeakRef(*WeakRef)
+}
+
+// A WeakRefUser is notified when the last non-weak reference is dropped.
+type WeakRefUser interface {
+ // WeakRefGone is called when the last non-weak reference is dropped.
+ WeakRefGone()
+}
+
+// WeakRef is a weak reference.
+//
+// +stateify savable
+type WeakRef struct {
+ weakRefEntry `state:"nosave"`
+
+ // obj is an atomic value that points to the refCounter.
+ obj atomic.Value `state:".(savedReference)"`
+
+ // user is notified when the weak ref is zapped by the object getting
+ // destroyed.
+ user WeakRefUser
+}
+
+// weakRefPool is a pool of weak references to avoid allocations on the hot path.
+var weakRefPool = sync.Pool{
+ New: func() interface{} {
+ return &WeakRef{}
+ },
+}
+
+// NewWeakRef acquires a weak reference for the given object.
+//
+// An optional user will be notified when the last non-weak reference is
+// dropped.
+//
+// Note that you must hold a reference to the object prior to getting a weak
+// reference. (But you may drop the non-weak reference after that.)
+func NewWeakRef(rc RefCounter, u WeakRefUser) *WeakRef {
+ w := weakRefPool.Get().(*WeakRef)
+ w.init(rc, u)
+ return w
+}
+
+// get attempts to get a normal reference to the underlying object, and returns
+// the object. If this weak reference has already been zapped (the object has
+// been destroyed) then false is returned. If the object still exists, then
+// true is returned.
+func (w *WeakRef) get() (RefCounter, bool) {
+ rc := w.obj.Load().(RefCounter)
+ if v := reflect.ValueOf(rc); v == reflect.Zero(v.Type()) {
+ // This pointer has already been zapped by zap() below. We do
+ // this to ensure that the GC can collect the underlying
+ // RefCounter objects and they don't hog resources.
+ return nil, false
+ }
+ if !rc.TryIncRef() {
+ return nil, true
+ }
+ return rc, true
+}
+
+// Get attempts to get a normal reference to the underlying object, and returns
+// the object. If this fails (the object no longer exists), then nil will be
+// returned instead.
+func (w *WeakRef) Get() RefCounter {
+ rc, _ := w.get()
+ return rc
+}
+
+// Drop drops this weak reference. You should always call drop when you are
+// finished with the weak reference. You may not use this object after calling
+// drop.
+func (w *WeakRef) Drop() {
+ rc, ok := w.get()
+ if !ok {
+ // We've been zapped already. When the refcounter has called
+ // zap, we're guaranteed it's not holding references.
+ weakRefPool.Put(w)
+ return
+ }
+ if rc == nil {
+ // The object is in the process of being destroyed. We can't
+ // remove this from the object's list, nor can we return this
+ // object to the pool. It'll just be garbage collected. This is
+ // a rare edge case, so it's not a big deal.
+ return
+ }
+
+ // At this point, we have a reference on the object. So destruction
+ // of the object (and zapping this weak reference) can't race here.
+ rc.dropWeakRef(w)
+
+ // And now aren't on the object's list of weak references. So it won't
+ // zap us if this causes the reference count to drop to zero.
+ rc.DecRef()
+
+ // Return to the pool.
+ weakRefPool.Put(w)
+}
+
+// init initializes this weak reference.
+func (w *WeakRef) init(rc RefCounter, u WeakRefUser) {
+ // Reset the contents of the weak reference.
+ // This is important because we are reseting the atomic value type.
+ // Otherwise, we could panic here if obj is different than what it was
+ // the last time this was used.
+ *w = WeakRef{}
+ w.user = u
+ w.obj.Store(rc)
+
+ // In the load path, we may already have a nil value. So we need to
+ // check whether or not that is the case before calling addWeakRef.
+ if v := reflect.ValueOf(rc); v != reflect.Zero(v.Type()) {
+ rc.addWeakRef(w)
+ }
+}
+
+// zap zaps this weak reference.
+func (w *WeakRef) zap() {
+ // We need to be careful about types here.
+ // So reflect is involved. But it's not that bad.
+ rc := w.obj.Load()
+ typ := reflect.TypeOf(rc)
+ w.obj.Store(reflect.Zero(typ).Interface())
+}
+
+// AtomicRefCount keeps a reference count using atomic operations and calls the
+// destructor when the count reaches zero.
+//
+// N.B. To allow the zero-object to be initialized, the count is offset by
+// 1, that is, when refCount is n, there are really n+1 references.
+//
+// +stateify savable
+type AtomicRefCount struct {
+ // refCount is composed of two fields:
+ //
+ // [32-bit speculative references]:[32-bit real references]
+ //
+ // Speculative references are used for TryIncRef, to avoid a
+ // CompareAndSwap loop. See IncRef, DecRef and TryIncRef for details of
+ // how these fields are used.
+ refCount int64
+
+ // mu protects the list below.
+ mu sync.Mutex `state:"nosave"`
+
+ // weakRefs is our collection of weak references.
+ weakRefs weakRefList `state:"nosave"`
+}
+
+// ReadRefs returns the current number of references. The returned count is
+// inherently racy and is unsafe to use without external synchronization.
+func (r *AtomicRefCount) ReadRefs() int64 {
+ // Account for the internal -1 offset on refcounts.
+ return atomic.LoadInt64(&r.refCount) + 1
+}
+
+// IncRef increments this object's reference count. While the count is kept
+// greater than zero, the destructor doesn't get called.
+//
+// The sanity check here is limited to real references, since if they have
+// dropped beneath zero then the object should have been destroyed.
+func (r *AtomicRefCount) IncRef() {
+ if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
+ panic("Incrementing non-positive ref count")
+ }
+}
+
+// TryIncRef attempts to increment the reference count, *unless the count has
+// already reached zero*. If false is returned, then the object has already
+// been destroyed, and the weak reference is no longer valid. If true if
+// returned then a valid reference is now held on the object.
+//
+// To do this safely without a loop, a speculative reference is first acquired
+// on the object. This allows multiple concurrent TryIncRef calls to
+// distinguish other TryIncRef calls from genuine references held.
+func (r *AtomicRefCount) TryIncRef() bool {
+ const speculativeRef = 1 << 32
+ v := atomic.AddInt64(&r.refCount, speculativeRef)
+ if int32(v) < 0 {
+ // This object has already been freed.
+ atomic.AddInt64(&r.refCount, -speculativeRef)
+ return false
+ }
+
+ // Turn into a real reference.
+ atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ return true
+}
+
+// addWeakRef adds the given weak reference.
+func (r *AtomicRefCount) addWeakRef(w *WeakRef) {
+ r.mu.Lock()
+ r.weakRefs.PushBack(w)
+ r.mu.Unlock()
+}
+
+// dropWeakRef drops the given weak reference.
+func (r *AtomicRefCount) dropWeakRef(w *WeakRef) {
+ r.mu.Lock()
+ r.weakRefs.Remove(w)
+ r.mu.Unlock()
+}
+
+// DecRefWithDestructor decrements the object's reference count. If the
+// resulting count is negative and the destructor is not nil, then the
+// destructor will be called.
+//
+// Note that speculative references are counted here. Since they were added
+// prior to real references reaching zero, they will successfully convert to
+// real references. In other words, we see speculative references only in the
+// following case:
+//
+// A: TryIncRef [speculative increase => sees non-negative references]
+// B: DecRef [real decrease]
+// A: TryIncRef [transform speculative to real]
+//
+func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
+ switch v := atomic.AddInt64(&r.refCount, -1); {
+ case v < -1:
+ panic("Decrementing non-positive ref count")
+
+ case v == -1:
+ // Zap weak references. Note that at this point, all weak
+ // references are already invalid. That is, TryIncRef() will
+ // return false due to the reference count check.
+ r.mu.Lock()
+ for !r.weakRefs.Empty() {
+ w := r.weakRefs.Front()
+ // Capture the callback because w cannot be touched
+ // after it's zapped -- the owner is free it reuse it
+ // after that.
+ user := w.user
+ r.weakRefs.Remove(w)
+ w.zap()
+
+ if user != nil {
+ r.mu.Unlock()
+ user.WeakRefGone()
+ r.mu.Lock()
+ }
+ }
+ r.mu.Unlock()
+
+ // Call the destructor.
+ if destroy != nil {
+ destroy()
+ }
+ }
+}
+
+// DecRef decrements this object's reference count.
+func (r *AtomicRefCount) DecRef() {
+ r.DecRefWithDestructor(nil)
+}
diff --git a/pkg/refs/refcounter_state.go b/pkg/refs/refcounter_state.go
new file mode 100644
index 000000000..7c99fd2b5
--- /dev/null
+++ b/pkg/refs/refcounter_state.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package refs
+
+// +stateify savable
+type savedReference struct {
+ obj interface{}
+}
+
+func (w *WeakRef) saveObj() savedReference {
+ // We load the object directly, because it is typed. This will be
+ // serialized and loaded as a typed value.
+ return savedReference{w.obj.Load()}
+}
+
+func (w *WeakRef) loadObj(v savedReference) {
+ // See note above. This will be serialized and loaded typed. So we're okay
+ // as long as refs aren't changing during save and load (which they should
+ // not be).
+ //
+ // w.user is loaded before loadObj is called.
+ w.init(v.obj.(RefCounter), w.user)
+}
diff --git a/pkg/refs/refs_state_autogen.go b/pkg/refs/refs_state_autogen.go
new file mode 100755
index 000000000..cc788a4fd
--- /dev/null
+++ b/pkg/refs/refs_state_autogen.go
@@ -0,0 +1,77 @@
+// automatically generated by stateify.
+
+package refs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *WeakRef) beforeSave() {}
+func (x *WeakRef) save(m state.Map) {
+ x.beforeSave()
+ var obj savedReference = x.saveObj()
+ m.SaveValue("obj", obj)
+ m.Save("user", &x.user)
+}
+
+func (x *WeakRef) afterLoad() {}
+func (x *WeakRef) load(m state.Map) {
+ m.Load("user", &x.user)
+ m.LoadValue("obj", new(savedReference), func(y interface{}) { x.loadObj(y.(savedReference)) })
+}
+
+func (x *AtomicRefCount) beforeSave() {}
+func (x *AtomicRefCount) save(m state.Map) {
+ x.beforeSave()
+ m.Save("refCount", &x.refCount)
+}
+
+func (x *AtomicRefCount) afterLoad() {}
+func (x *AtomicRefCount) load(m state.Map) {
+ m.Load("refCount", &x.refCount)
+}
+
+func (x *savedReference) beforeSave() {}
+func (x *savedReference) save(m state.Map) {
+ x.beforeSave()
+ m.Save("obj", &x.obj)
+}
+
+func (x *savedReference) afterLoad() {}
+func (x *savedReference) load(m state.Map) {
+ m.Load("obj", &x.obj)
+}
+
+func (x *weakRefList) beforeSave() {}
+func (x *weakRefList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *weakRefList) afterLoad() {}
+func (x *weakRefList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *weakRefEntry) beforeSave() {}
+func (x *weakRefEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *weakRefEntry) afterLoad() {}
+func (x *weakRefEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("refs.WeakRef", (*WeakRef)(nil), state.Fns{Save: (*WeakRef).save, Load: (*WeakRef).load})
+ state.Register("refs.AtomicRefCount", (*AtomicRefCount)(nil), state.Fns{Save: (*AtomicRefCount).save, Load: (*AtomicRefCount).load})
+ state.Register("refs.savedReference", (*savedReference)(nil), state.Fns{Save: (*savedReference).save, Load: (*savedReference).load})
+ state.Register("refs.weakRefList", (*weakRefList)(nil), state.Fns{Save: (*weakRefList).save, Load: (*weakRefList).load})
+ state.Register("refs.weakRefEntry", (*weakRefEntry)(nil), state.Fns{Save: (*weakRefEntry).save, Load: (*weakRefEntry).load})
+}
diff --git a/pkg/refs/weak_ref_list.go b/pkg/refs/weak_ref_list.go
new file mode 100755
index 000000000..df8e98bf5
--- /dev/null
+++ b/pkg/refs/weak_ref_list.go
@@ -0,0 +1,173 @@
+package refs
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type weakRefElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (weakRefElementMapper) linkerFor(elem *WeakRef) *WeakRef { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type weakRefList struct {
+ head *WeakRef
+ tail *WeakRef
+}
+
+// Reset resets list l to the empty state.
+func (l *weakRefList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *weakRefList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *weakRefList) Front() *WeakRef {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *weakRefList) Back() *WeakRef {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *weakRefList) PushFront(e *WeakRef) {
+ weakRefElementMapper{}.linkerFor(e).SetNext(l.head)
+ weakRefElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ weakRefElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *weakRefList) PushBack(e *WeakRef) {
+ weakRefElementMapper{}.linkerFor(e).SetNext(nil)
+ weakRefElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ weakRefElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *weakRefList) PushBackList(m *weakRefList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ weakRefElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ weakRefElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *weakRefList) InsertAfter(b, e *WeakRef) {
+ a := weakRefElementMapper{}.linkerFor(b).Next()
+ weakRefElementMapper{}.linkerFor(e).SetNext(a)
+ weakRefElementMapper{}.linkerFor(e).SetPrev(b)
+ weakRefElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ weakRefElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *weakRefList) InsertBefore(a, e *WeakRef) {
+ b := weakRefElementMapper{}.linkerFor(a).Prev()
+ weakRefElementMapper{}.linkerFor(e).SetNext(a)
+ weakRefElementMapper{}.linkerFor(e).SetPrev(b)
+ weakRefElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ weakRefElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *weakRefList) Remove(e *WeakRef) {
+ prev := weakRefElementMapper{}.linkerFor(e).Prev()
+ next := weakRefElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ weakRefElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ weakRefElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type weakRefEntry struct {
+ next *WeakRef
+ prev *WeakRef
+}
+
+// Next returns the entry that follows e in the list.
+func (e *weakRefEntry) Next() *WeakRef {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *weakRefEntry) Prev() *WeakRef {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *weakRefEntry) SetNext(elem *WeakRef) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *weakRefEntry) SetPrev(elem *WeakRef) {
+ e.prev = elem
+}
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
new file mode 100644
index 000000000..cc142a497
--- /dev/null
+++ b/pkg/seccomp/seccomp.go
@@ -0,0 +1,375 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package seccomp provides basic seccomp filters for x86_64 (little endian).
+package seccomp
+
+import (
+ "fmt"
+ "reflect"
+ "sort"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/bpf"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+)
+
+const (
+ // skipOneInst is the offset to take for skipping one instruction.
+ skipOneInst = 1
+
+ // defaultLabel is the label for the default action.
+ defaultLabel = "default_action"
+)
+
+// Install generates BPF code based on the set of syscalls provided. It only
+// allows syscalls that conform to the specification. Syscalls that violate the
+// specification will trigger RET_KILL_PROCESS, except for the cases below.
+//
+// RET_TRAP is used in violations, instead of RET_KILL_PROCESS, in the
+// following cases:
+// 1. Kernel doesn't support RET_KILL_PROCESS: RET_KILL_THREAD only kills the
+// offending thread and often keeps the sentry hanging.
+// 2. Debug: RET_TRAP generates a panic followed by a stack trace which is
+// much easier to debug then RET_KILL_PROCESS which can't be caught.
+//
+// Be aware that RET_TRAP sends SIGSYS to the process and it may be ignored,
+// making it possible for the process to continue running after a violation.
+// However, it will leave a SECCOMP audit event trail behind. In any case, the
+// syscall is still blocked from executing.
+func Install(rules SyscallRules) error {
+ defaultAction, err := defaultAction()
+ if err != nil {
+ return err
+ }
+
+ // Uncomment to get stack trace when there is a violation.
+ // defaultAction = linux.BPFAction(linux.SECCOMP_RET_TRAP)
+
+ log.Infof("Installing seccomp filters for %d syscalls (action=%v)", len(rules), defaultAction)
+
+ instrs, err := BuildProgram([]RuleSet{
+ RuleSet{
+ Rules: rules,
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ }, defaultAction)
+ if log.IsLogging(log.Debug) {
+ programStr, errDecode := bpf.DecodeProgram(instrs)
+ if errDecode != nil {
+ programStr = fmt.Sprintf("Error: %v\n%s", errDecode, programStr)
+ }
+ log.Debugf("Seccomp program dump:\n%s", programStr)
+ }
+ if err != nil {
+ return err
+ }
+
+ // Perform the actual installation.
+ if errno := SetFilter(instrs); errno != 0 {
+ return fmt.Errorf("Failed to set filter: %v", errno)
+ }
+
+ log.Infof("Seccomp filters installed.")
+ return nil
+}
+
+func defaultAction() (linux.BPFAction, error) {
+ available, err := isKillProcessAvailable()
+ if err != nil {
+ return 0, err
+ }
+ if available {
+ return linux.SECCOMP_RET_KILL_PROCESS, nil
+ }
+ return linux.SECCOMP_RET_TRAP, nil
+}
+
+// RuleSet is a set of rules and associated action.
+type RuleSet struct {
+ Rules SyscallRules
+ Action linux.BPFAction
+
+ // Vsyscall indicates that a check is made for a function being called
+ // from kernel mappings. This is where the vsyscall page is located
+ // (and typically) emulated, so this RuleSet will not match any
+ // functions not dispatched from the vsyscall page.
+ Vsyscall bool
+}
+
+// SyscallName gives names to system calls. It is used purely for debugging purposes.
+//
+// An alternate namer can be provided to the package at initialization time.
+var SyscallName = func(sysno uintptr) string {
+ return fmt.Sprintf("syscall_%d", sysno)
+}
+
+// BuildProgram builds a BPF program from the given map of actions to matching
+// SyscallRules. The single generated program covers all provided RuleSets.
+func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFInstruction, error) {
+ program := bpf.NewProgramBuilder()
+
+ // Be paranoid and check that syscall is done in the expected architecture.
+ //
+ // A = seccomp_data.arch
+ // if (A != AUDIT_ARCH) goto defaultAction.
+ program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArch)
+ // defaultLabel is at the bottom of the program. The size of program
+ // may exceeds 255 lines, which is the limit of a condition jump.
+ program.AddJump(bpf.Jmp|bpf.Jeq|bpf.K, LINUX_AUDIT_ARCH, skipOneInst, 0)
+ program.AddDirectJumpLabel(defaultLabel)
+ if err := buildIndex(rules, program); err != nil {
+ return nil, err
+ }
+
+ // Exhausted: return defaultAction.
+ if err := program.AddLabel(defaultLabel); err != nil {
+ return nil, err
+ }
+ program.AddStmt(bpf.Ret|bpf.K, uint32(defaultAction))
+
+ return program.Instructions()
+}
+
+// buildIndex builds a BST to quickly search through all syscalls.
+func buildIndex(rules []RuleSet, program *bpf.ProgramBuilder) error {
+ // Build a list of all application system calls, across all given rule
+ // sets. We have a simple BST, but may dispatch individual matchers
+ // with different actions. The matchers are evaluated linearly.
+ requiredSyscalls := make(map[uintptr]struct{})
+ for _, rs := range rules {
+ for sysno := range rs.Rules {
+ requiredSyscalls[sysno] = struct{}{}
+ }
+ }
+ syscalls := make([]uintptr, 0, len(requiredSyscalls))
+ for sysno, _ := range requiredSyscalls {
+ syscalls = append(syscalls, sysno)
+ }
+ sort.Slice(syscalls, func(i, j int) bool { return syscalls[i] < syscalls[j] })
+ for _, sysno := range syscalls {
+ for _, rs := range rules {
+ // Print only if there is a corresponding set of rules.
+ if _, ok := rs.Rules[sysno]; ok {
+ log.Debugf("syscall filter %v: %s => 0x%x", SyscallName(sysno), rs.Rules[sysno], rs.Action)
+ }
+ }
+ }
+
+ root := createBST(syscalls)
+ root.root = true
+
+ // Load syscall number into A and run through BST.
+ //
+ // A = seccomp_data.nr
+ program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetNR)
+ return root.traverse(buildBSTProgram, rules, program)
+}
+
+// createBST converts sorted syscall slice into a balanced BST.
+// Panics if syscalls is empty.
+func createBST(syscalls []uintptr) *node {
+ i := len(syscalls) / 2
+ parent := node{value: syscalls[i]}
+ if i > 0 {
+ parent.left = createBST(syscalls[:i])
+ }
+ if i+1 < len(syscalls) {
+ parent.right = createBST(syscalls[i+1:])
+ }
+ return &parent
+}
+
+func vsyscallViolationLabel(ruleSetIdx int, sysno uintptr) string {
+ return fmt.Sprintf("vsyscallViolation_%v_%v", ruleSetIdx, sysno)
+}
+
+func ruleViolationLabel(ruleSetIdx int, sysno uintptr, idx int) string {
+ return fmt.Sprintf("ruleViolation_%v_%v_%v", ruleSetIdx, sysno, idx)
+}
+
+func checkArgsLabel(sysno uintptr) string {
+ return fmt.Sprintf("checkArgs_%v", sysno)
+}
+
+// addSyscallArgsCheck adds argument checks for a single system call. It does
+// not insert a jump to the default action at the end and it is the
+// responsibility of the caller to insert an appropriate jump after calling
+// this function.
+func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAction, ruleSetIdx int, sysno uintptr) error {
+ for ruleidx, rule := range rules {
+ labelled := false
+ for i, arg := range rule {
+ if arg != nil {
+ switch a := arg.(type) {
+ case AllowAny:
+ case AllowValue:
+ high, low := uint32(a>>32), uint32(a)
+ // assert arg_low == low
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgLow(i))
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ // assert arg_high == high
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i))
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ labelled = true
+ default:
+ return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a))
+ }
+ }
+ }
+
+ // Matched, emit the given action.
+ p.AddStmt(bpf.Ret|bpf.K, uint32(action))
+
+ // Label the end of the rule if necessary. This is added for
+ // the jumps above when the argument check fails.
+ if labelled {
+ if err := p.AddLabel(ruleViolationLabel(ruleSetIdx, sysno, ruleidx)); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// buildBSTProgram converts a binary tree started in 'root' into BPF code. The ouline of the code
+// is as follows:
+//
+// // SYS_PIPE(22), root
+// (A == 22) ? goto argument check : continue
+// (A > 22) ? goto index_35 : goto index_9
+//
+// index_9: // SYS_MMAP(9), leaf
+// A == 9) ? goto argument check : defaultLabel
+//
+// index_35: // SYS_NANOSLEEP(35), single child
+// (A == 35) ? goto argument check : continue
+// (A > 35) ? goto index_50 : goto defaultLabel
+//
+// index_50: // SYS_LISTEN(50), leaf
+// (A == 50) ? goto argument check : goto defaultLabel
+//
+func buildBSTProgram(n *node, rules []RuleSet, program *bpf.ProgramBuilder) error {
+ // Root node is never referenced by label, skip it.
+ if !n.root {
+ if err := program.AddLabel(n.label()); err != nil {
+ return err
+ }
+ }
+
+ sysno := n.value
+ program.AddJumpTrueLabel(bpf.Jmp|bpf.Jeq|bpf.K, uint32(sysno), checkArgsLabel(sysno), 0)
+ if n.left == nil && n.right == nil {
+ // Leaf nodes don't require extra check.
+ program.AddDirectJumpLabel(defaultLabel)
+ } else {
+ // Non-leaf node. Check which turn to take otherwise. Using direct jumps
+ // in case that the offset may exceed the limit of a conditional jump (255)
+ program.AddJump(bpf.Jmp|bpf.Jgt|bpf.K, uint32(sysno), 0, skipOneInst)
+ program.AddDirectJumpLabel(n.right.label())
+ program.AddDirectJumpLabel(n.left.label())
+ }
+
+ if err := program.AddLabel(checkArgsLabel(sysno)); err != nil {
+ return err
+ }
+
+ emitted := false
+ for ruleSetIdx, rs := range rules {
+ if _, ok := rs.Rules[sysno]; ok {
+ // If there are no rules, then this will always match.
+ // Remember we've done this so that we can emit a
+ // sensible error. We can't catch all overlaps, but we
+ // can catch this one at least.
+ if emitted {
+ return fmt.Errorf("unreachable action for %v: 0x%x (rule set %d)", SyscallName(sysno), rs.Action, ruleSetIdx)
+ }
+
+ // Emit a vsyscall check if this rule requires a
+ // Vsyscall match. This rule ensures that the top bit
+ // is set in the instruction pointer, which is where
+ // the vsyscall page will be mapped.
+ if rs.Vsyscall {
+ program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetIPHigh)
+ program.AddJumpFalseLabel(bpf.Jmp|bpf.Jset|bpf.K, 0x80000000, 0, vsyscallViolationLabel(ruleSetIdx, sysno))
+ }
+
+ // Emit matchers.
+ if len(rs.Rules[sysno]) == 0 {
+ // This is a blanket action.
+ program.AddStmt(bpf.Ret|bpf.K, uint32(rs.Action))
+ emitted = true
+ } else {
+ // Add an argument check for these particular
+ // arguments. This will continue execution and
+ // check the next rule set. We need to ensure
+ // that at the very end, we insert a direct
+ // jump label for the unmatched case.
+ if err := addSyscallArgsCheck(program, rs.Rules[sysno], rs.Action, ruleSetIdx, sysno); err != nil {
+ return err
+ }
+ }
+
+ // If there was a Vsyscall check for this rule, then we
+ // need to add an appropriate label for the jump above.
+ if rs.Vsyscall {
+ if err := program.AddLabel(vsyscallViolationLabel(ruleSetIdx, sysno)); err != nil {
+ return err
+ }
+ }
+ }
+ }
+
+ // Not matched? We only need to insert a jump to the default label if
+ // not default action has been emitted for this call.
+ if !emitted {
+ program.AddDirectJumpLabel(defaultLabel)
+ }
+
+ return nil
+}
+
+// node represents a tree node.
+type node struct {
+ value uintptr
+ left *node
+ right *node
+ root bool
+}
+
+// label returns the label corresponding to this node.
+//
+// If n is nil, then the defaultLabel is returned.
+func (n *node) label() string {
+ if n == nil {
+ return defaultLabel
+ }
+ return fmt.Sprintf("index_%v", n.value)
+}
+
+type traverseFunc func(*node, []RuleSet, *bpf.ProgramBuilder) error
+
+func (n *node) traverse(fn traverseFunc, rules []RuleSet, p *bpf.ProgramBuilder) error {
+ if n == nil {
+ return nil
+ }
+ if err := fn(n, rules, p); err != nil {
+ return err
+ }
+ if err := n.left.traverse(fn, rules, p); err != nil {
+ return err
+ }
+ return n.right.traverse(fn, rules, p)
+}
diff --git a/pkg/seccomp/seccomp_amd64.go b/pkg/seccomp/seccomp_amd64.go
new file mode 100644
index 000000000..02dfb8d9f
--- /dev/null
+++ b/pkg/seccomp/seccomp_amd64.go
@@ -0,0 +1,26 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package seccomp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+const (
+ LINUX_AUDIT_ARCH = linux.AUDIT_ARCH_X86_64
+ SYS_SECCOMP = 317
+)
diff --git a/pkg/seccomp/seccomp_arm64.go b/pkg/seccomp/seccomp_arm64.go
new file mode 100644
index 000000000..b575bcdbf
--- /dev/null
+++ b/pkg/seccomp/seccomp_arm64.go
@@ -0,0 +1,26 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package seccomp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+const (
+ LINUX_AUDIT_ARCH = linux.AUDIT_ARCH_AARCH64
+ SYS_SECCOMP = 277
+)
diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go
new file mode 100644
index 000000000..29eec8db1
--- /dev/null
+++ b/pkg/seccomp/seccomp_rules.go
@@ -0,0 +1,132 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccomp
+
+import "fmt"
+
+// The offsets are based on the following struct in include/linux/seccomp.h.
+// struct seccomp_data {
+// int nr;
+// __u32 arch;
+// __u64 instruction_pointer;
+// __u64 args[6];
+// };
+const (
+ seccompDataOffsetNR = 0
+ seccompDataOffsetArch = 4
+ seccompDataOffsetIPLow = 8
+ seccompDataOffsetIPHigh = 12
+ seccompDataOffsetArgs = 16
+)
+
+func seccompDataOffsetArgLow(i int) uint32 {
+ return uint32(seccompDataOffsetArgs + i*8)
+}
+
+func seccompDataOffsetArgHigh(i int) uint32 {
+ return seccompDataOffsetArgLow(i) + 4
+}
+
+// AllowAny is marker to indicate any value will be accepted.
+type AllowAny struct{}
+
+func (a AllowAny) String() (s string) {
+ return "*"
+}
+
+// AllowValue specifies a value that needs to be strictly matched.
+type AllowValue uintptr
+
+func (a AllowValue) String() (s string) {
+ return fmt.Sprintf("%#x ", uintptr(a))
+}
+
+// Rule stores the whitelist of syscall arguments.
+//
+// For example:
+// rule := Rule {
+// AllowValue(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0
+// }
+type Rule [6]interface{}
+
+func (r Rule) String() (s string) {
+ if len(r) == 0 {
+ return
+ }
+ s += "( "
+ for _, arg := range r {
+ if arg != nil {
+ s += fmt.Sprintf("%v ", arg)
+ }
+ }
+ s += ")"
+ return
+}
+
+// SyscallRules stores a map of OR'ed whitelist rules indexed by the syscall number.
+// If the 'Rules' is empty, we treat it as any argument is allowed.
+//
+// For example:
+// rules := SyscallRules{
+// syscall.SYS_FUTEX: []Rule{
+// {
+// AllowAny{},
+// AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
+// }, // OR
+// {
+// AllowAny{},
+// AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
+// },
+// },
+// syscall.SYS_GETPID: []Rule{},
+// }
+type SyscallRules map[uintptr][]Rule
+
+// NewSyscallRules returns a new SyscallRules.
+func NewSyscallRules() SyscallRules {
+ return make(map[uintptr][]Rule)
+}
+
+// AddRule adds the given rule. It will create a new entry for a new syscall, otherwise
+// it will append to the existing rules.
+func (sr SyscallRules) AddRule(sysno uintptr, r Rule) {
+ if cur, ok := sr[sysno]; ok {
+ // An empty rules means allow all. Honor it when more rules are added.
+ if len(cur) == 0 {
+ sr[sysno] = append(sr[sysno], Rule{})
+ }
+ sr[sysno] = append(sr[sysno], r)
+ } else {
+ sr[sysno] = []Rule{r}
+ }
+}
+
+// Merge merges the given SyscallRules.
+func (sr SyscallRules) Merge(rules SyscallRules) {
+ for sysno, rs := range rules {
+ if cur, ok := sr[sysno]; ok {
+ // An empty rules means allow all. Honor it when more rules are added.
+ if len(cur) == 0 {
+ sr[sysno] = append(sr[sysno], Rule{})
+ }
+ if len(rs) == 0 {
+ rs = []Rule{{}}
+ }
+ sr[sysno] = append(sr[sysno], rs...)
+ } else {
+ sr[sysno] = rs
+ }
+ }
+}
diff --git a/pkg/seccomp/seccomp_state_autogen.go b/pkg/seccomp/seccomp_state_autogen.go
new file mode 100755
index 000000000..0fc23d1a8
--- /dev/null
+++ b/pkg/seccomp/seccomp_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package seccomp
+
diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go
new file mode 100644
index 000000000..ebb6397e8
--- /dev/null
+++ b/pkg/seccomp/seccomp_unsafe.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccomp
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// sockFprog is sock_fprog taken from <linux/filter.h>.
+type sockFprog struct {
+ Len uint16
+ pad [6]byte
+ Filter *linux.BPFInstruction
+}
+
+// SetFilter installs the given BPF program.
+//
+// This is safe to call from an afterFork context.
+//
+//go:nosplit
+func SetFilter(instrs []linux.BPFInstruction) syscall.Errno {
+ // PR_SET_NO_NEW_PRIVS is required in order to enable seccomp. See seccomp(2) for details.
+ if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0); errno != 0 {
+ return errno
+ }
+
+ sockProg := sockFprog{
+ Len: uint16(len(instrs)),
+ Filter: (*linux.BPFInstruction)(unsafe.Pointer(&instrs[0])),
+ }
+ return seccomp(linux.SECCOMP_SET_MODE_FILTER, linux.SECCOMP_FILTER_FLAG_TSYNC, unsafe.Pointer(&sockProg))
+}
+
+func isKillProcessAvailable() (bool, error) {
+ action := uint32(linux.SECCOMP_RET_KILL_PROCESS)
+ if errno := seccomp(linux.SECCOMP_GET_ACTION_AVAIL, 0, unsafe.Pointer(&action)); errno != 0 {
+ // EINVAL: SECCOMP_GET_ACTION_AVAIL not in this kernel yet.
+ // EOPNOTSUPP: SECCOMP_RET_KILL_PROCESS not supported.
+ if errno == syscall.EINVAL || errno == syscall.EOPNOTSUPP {
+ return false, nil
+ }
+ return false, errno
+ }
+ return true, nil
+}
+
+// seccomp calls seccomp(2). This is safe to call from an afterFork context.
+//
+//go:nosplit
+func seccomp(op, flags uint32, ptr unsafe.Pointer) syscall.Errno {
+ if _, _, errno := syscall.RawSyscall(SYS_SECCOMP, uintptr(op), uintptr(flags), uintptr(ptr)); errno != 0 {
+ return errno
+ }
+ return 0
+}
diff --git a/pkg/secio/full_reader.go b/pkg/secio/full_reader.go
new file mode 100644
index 000000000..aed2564bd
--- /dev/null
+++ b/pkg/secio/full_reader.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package secio
+
+import (
+ "io"
+)
+
+// FullReader adapts an io.Reader to never return partial reads with a nil
+// error.
+type FullReader struct {
+ Reader io.Reader
+}
+
+// Read implements io.Reader.Read.
+func (r FullReader) Read(dst []byte) (int, error) {
+ n, err := io.ReadFull(r.Reader, dst)
+ if err == io.ErrUnexpectedEOF {
+ return n, io.EOF
+ }
+ return n, err
+}
diff --git a/pkg/secio/secio.go b/pkg/secio/secio.go
new file mode 100644
index 000000000..b43226035
--- /dev/null
+++ b/pkg/secio/secio.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package secio provides support for sectioned I/O.
+package secio
+
+import (
+ "errors"
+ "io"
+)
+
+// ErrReachedLimit is returned when SectionReader.Read or SectionWriter.Write
+// reaches its limit.
+var ErrReachedLimit = errors.New("reached limit")
+
+// SectionReader implements io.Reader on a section of an underlying io.ReaderAt.
+// It is similar to io.SectionReader, but:
+//
+// - Reading beyond the limit returns ErrReachedLimit, not io.EOF.
+//
+// - Limit overflow is handled correctly.
+type SectionReader struct {
+ r io.ReaderAt
+ off int64
+ limit int64
+}
+
+// Read implements io.Reader.Read.
+func (r *SectionReader) Read(dst []byte) (int, error) {
+ if r.limit >= 0 {
+ if max := r.limit - r.off; max < int64(len(dst)) {
+ dst = dst[:max]
+ }
+ }
+ n, err := r.r.ReadAt(dst, r.off)
+ r.off += int64(n)
+ if err == nil && r.off == r.limit {
+ err = ErrReachedLimit
+ }
+ return n, err
+}
+
+// NewOffsetReader returns an io.Reader that reads from r starting at offset
+// off.
+func NewOffsetReader(r io.ReaderAt, off int64) *SectionReader {
+ return &SectionReader{r, off, -1}
+}
+
+// NewSectionReader returns an io.Reader that reads from r starting at offset
+// off and stops with ErrReachedLimit after n bytes.
+func NewSectionReader(r io.ReaderAt, off int64, n int64) *SectionReader {
+ // If off + n overflows, it will be < 0 such that no limit applies, but
+ // this is the correct behavior as long as r prohibits reading at offsets
+ // beyond MaxInt64.
+ return &SectionReader{r, off, off + n}
+}
+
+// SectionWriter implements io.Writer on a section of an underlying
+// io.WriterAt. Writing beyond the limit returns ErrReachedLimit.
+type SectionWriter struct {
+ w io.WriterAt
+ off int64
+ limit int64
+}
+
+// Write implements io.Writer.Write.
+func (w *SectionWriter) Write(src []byte) (int, error) {
+ if w.limit >= 0 {
+ if max := w.limit - w.off; max < int64(len(src)) {
+ src = src[:max]
+ }
+ }
+ n, err := w.w.WriteAt(src, w.off)
+ w.off += int64(n)
+ if err == nil && w.off == w.limit {
+ err = ErrReachedLimit
+ }
+ return n, err
+}
+
+// NewOffsetWriter returns an io.Writer that writes to w starting at offset
+// off.
+func NewOffsetWriter(w io.WriterAt, off int64) *SectionWriter {
+ return &SectionWriter{w, off, -1}
+}
+
+// NewSectionWriter returns an io.Writer that writes to w starting at offset
+// off and stops with ErrReachedLimit after n bytes.
+func NewSectionWriter(w io.WriterAt, off int64, n int64) *SectionWriter {
+ // If off + n overflows, it will be < 0 such that no limit applies, but
+ // this is the correct behavior as long as w prohibits writing at offsets
+ // beyond MaxInt64.
+ return &SectionWriter{w, off, off + n}
+}
diff --git a/pkg/secio/secio_state_autogen.go b/pkg/secio/secio_state_autogen.go
new file mode 100755
index 000000000..ec559f264
--- /dev/null
+++ b/pkg/secio/secio_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package secio
+
diff --git a/pkg/sentry/arch/aligned.go b/pkg/sentry/arch/aligned.go
new file mode 100644
index 000000000..df01a903d
--- /dev/null
+++ b/pkg/sentry/arch/aligned.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "reflect"
+)
+
+// alignedBytes returns a slice of size bytes, aligned in memory to the given
+// alignment. This is used because we require certain structures to be aligned
+// in a specific way (for example, the X86 floating point data).
+func alignedBytes(size, alignment uint) []byte {
+ data := make([]byte, size+alignment-1)
+ offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment))
+ if offset == 0 {
+ return data[:size:size]
+ }
+ return data[alignment-offset:][:size:size]
+}
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
new file mode 100644
index 000000000..53f0c9018
--- /dev/null
+++ b/pkg/sentry/arch/arch.go
@@ -0,0 +1,359 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package arch provides abstractions around architecture-dependent details,
+// such as syscall calling conventions, native types, etc.
+package arch
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// Arch describes an architecture.
+type Arch int
+
+const (
+ // AMD64 is the x86-64 architecture.
+ AMD64 Arch = iota
+)
+
+// String implements fmt.Stringer.
+func (a Arch) String() string {
+ switch a {
+ case AMD64:
+ return "amd64"
+ default:
+ return fmt.Sprintf("Arch(%d)", a)
+ }
+}
+
+// FloatingPointData is a generic type, and will always be passed as a pointer.
+// We rely on the individual arch implementations to meet all the necessary
+// requirements. For example, on x86 the region must be 16-byte aligned and 512
+// bytes in size.
+type FloatingPointData byte
+
+// Context provides architecture-dependent information for a specific thread.
+//
+// NOTE(b/34169503): Currently we use uintptr here to refer to a generic native
+// register value. While this will work for the foreseeable future, it isn't
+// strictly correct. We may want to create some abstraction that makes this
+// more clear or enables us to store values of arbitrary widths. This is
+// particularly true for RegisterMap().
+type Context interface {
+ // Arch returns the architecture for this Context.
+ Arch() Arch
+
+ // Native converts a generic type to a native value.
+ //
+ // Because the architecture is not specified here, we may be dealing
+ // with return values of varying sizes (for example ARCH_GETFS). This
+ // is a simple utility function to convert to the native size in these
+ // cases, and then we can CopyOut.
+ Native(val uintptr) interface{}
+
+ // Value converts a native type back to a generic value.
+ // Once a value has been converted to native via the above call -- it
+ // can be converted back here.
+ Value(val interface{}) uintptr
+
+ // Width returns the number of bytes for a native value.
+ Width() uint
+
+ // Fork creates a clone of the context.
+ Fork() Context
+
+ // SyscallNo returns the syscall number.
+ SyscallNo() uintptr
+
+ // SyscallArgs returns the syscall arguments in an array.
+ SyscallArgs() SyscallArguments
+
+ // Return returns the return value for a system call.
+ Return() uintptr
+
+ // SetReturn sets the return value for a system call.
+ SetReturn(value uintptr)
+
+ // RestartSyscall reverses over the current syscall instruction, such that
+ // when the application resumes execution the syscall will be re-attempted.
+ RestartSyscall()
+
+ // RestartSyscallWithRestartBlock reverses over the current syscall
+ // instraction and overwrites the current syscall number with that of
+ // restart_syscall(2). This causes the application to restart the current
+ // syscall with a custom function when execution resumes.
+ RestartSyscallWithRestartBlock()
+
+ // IP returns the current instruction pointer.
+ IP() uintptr
+
+ // SetIP sets the current instruction pointer.
+ SetIP(value uintptr)
+
+ // Stack returns the current stack pointer.
+ Stack() uintptr
+
+ // SetStack sets the current stack pointer.
+ SetStack(value uintptr)
+
+ // TLS returns the current TLS pointer.
+ TLS() uintptr
+
+ // SetTLS sets the current TLS pointer. Returns false if value is invalid.
+ SetTLS(value uintptr) bool
+
+ // SetRSEQInterruptedIP sets the register that contains the old IP when a
+ // restartable sequence is interrupted.
+ SetRSEQInterruptedIP(value uintptr)
+
+ // StateData returns a pointer to underlying architecture state.
+ StateData() *State
+
+ // RegisterMap returns a map of all registers.
+ RegisterMap() (map[string]uintptr, error)
+
+ // NewSignalAct returns a new object that is equivalent to struct sigaction
+ // in the guest architecture.
+ NewSignalAct() NativeSignalAct
+
+ // NewSignalStack returns a new object that is equivalent to stack_t in the
+ // guest architecture.
+ NewSignalStack() NativeSignalStack
+
+ // SignalSetup modifies the context in preparation for handling the
+ // given signal.
+ //
+ // st is the stack where the signal handler frame should be
+ // constructed.
+ //
+ // act is the SignalAct that specifies how this signal is being
+ // handled.
+ //
+ // info is the SignalInfo of the signal being delivered.
+ //
+ // alt is the alternate signal stack (even if the alternate signal
+ // stack is not going to be used).
+ //
+ // sigset is the signal mask before entering the signal handler.
+ SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error
+
+ // SignalRestore restores context after returning from a signal
+ // handler.
+ //
+ // st is the current thread stack.
+ //
+ // rt is true if SignalRestore is being entered from rt_sigreturn and
+ // false if SignalRestore is being entered from sigreturn.
+ // SignalRestore returns the thread's new signal mask.
+ SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error)
+
+ // CPUIDEmulate emulates a CPUID instruction according to current register state.
+ CPUIDEmulate(l log.Logger)
+
+ // SingleStep returns true if single stepping is enabled.
+ SingleStep() bool
+
+ // SetSingleStep enables single stepping.
+ SetSingleStep()
+
+ // ClearSingleStep disables single stepping.
+ ClearSingleStep()
+
+ // FloatingPointData will be passed to underlying save routines.
+ FloatingPointData() *FloatingPointData
+
+ // NewMmapLayout returns a layout for a new MM, where MinAddr for the
+ // returned layout must be no lower than min, and MaxAddr for the returned
+ // layout must be no higher than max. Repeated calls to NewMmapLayout may
+ // return different layouts.
+ NewMmapLayout(min, max usermem.Addr, limits *limits.LimitSet) (MmapLayout, error)
+
+ // PIELoadAddress returns a preferred load address for a
+ // position-independent executable within l.
+ PIELoadAddress(l MmapLayout) usermem.Addr
+
+ // FeatureSet returns the FeatureSet in use in this context.
+ FeatureSet() *cpuid.FeatureSet
+
+ // Hack around our package dependences being too broken to support the
+ // equivalent of arch_ptrace():
+
+ // PtracePeekUser implements ptrace(PTRACE_PEEKUSR).
+ PtracePeekUser(addr uintptr) (interface{}, error)
+
+ // PtracePokeUser implements ptrace(PTRACE_POKEUSR).
+ PtracePokeUser(addr, data uintptr) error
+
+ // PtraceGetRegs implements ptrace(PTRACE_GETREGS) by writing the
+ // general-purpose registers represented by this Context to dst and
+ // returning the number of bytes written.
+ PtraceGetRegs(dst io.Writer) (int, error)
+
+ // PtraceSetRegs implements ptrace(PTRACE_SETREGS) by reading
+ // general-purpose registers from src into this Context and returning the
+ // number of bytes read.
+ PtraceSetRegs(src io.Reader) (int, error)
+
+ // PtraceGetFPRegs implements ptrace(PTRACE_GETFPREGS) by writing the
+ // floating-point registers represented by this Context to addr in dst and
+ // returning the number of bytes written.
+ PtraceGetFPRegs(dst io.Writer) (int, error)
+
+ // PtraceSetFPRegs implements ptrace(PTRACE_SETFPREGS) by reading
+ // floating-point registers from src into this Context and returning the
+ // number of bytes read.
+ PtraceSetFPRegs(src io.Reader) (int, error)
+
+ // PtraceGetRegSet implements ptrace(PTRACE_GETREGSET) by writing the
+ // register set given by architecture-defined value regset from this
+ // Context to dst and returning the number of bytes written, which must be
+ // less than or equal to maxlen.
+ PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error)
+
+ // PtraceSetRegSet implements ptrace(PTRACE_SETREGSET) by reading the
+ // register set given by architecture-defined value regset from src and
+ // returning the number of bytes read, which must be less than or equal to
+ // maxlen.
+ PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error)
+
+ // FullRestore returns 'true' if all CPU registers must be restored
+ // when switching to the untrusted application. Typically a task enters
+ // and leaves the kernel via a system call. Platform.Switch() may
+ // optimize for this by not saving/restoring all registers if allowed
+ // by the ABI. For e.g. the amd64 ABI specifies that syscall clobbers
+ // %rcx and %r11. If FullRestore returns true then these optimizations
+ // must be disabled and all registers restored.
+ FullRestore() bool
+}
+
+// MmapDirection is a search direction for mmaps.
+type MmapDirection int
+
+const (
+ // MmapBottomUp instructs mmap to prefer lower addresses.
+ MmapBottomUp MmapDirection = iota
+
+ // MmapTopDown instructs mmap to prefer higher addresses.
+ MmapTopDown
+)
+
+// MmapLayout defines the layout of the user address space for a particular
+// MemoryManager.
+//
+// Note that "highest address" below is always exclusive.
+//
+// +stateify savable
+type MmapLayout struct {
+ // MinAddr is the lowest mappable address.
+ MinAddr usermem.Addr
+
+ // MaxAddr is the highest mappable address.
+ MaxAddr usermem.Addr
+
+ // BottomUpBase is the lowest address that may be returned for a
+ // MmapBottomUp mmap.
+ BottomUpBase usermem.Addr
+
+ // TopDownBase is the highest address that may be returned for a
+ // MmapTopDown mmap.
+ TopDownBase usermem.Addr
+
+ // DefaultDirection is the direction for most non-fixed mmaps in this
+ // layout.
+ DefaultDirection MmapDirection
+
+ // MaxStackRand is the maximum randomization to apply to stack
+ // allocations to maintain a proper gap between the stack and
+ // TopDownBase.
+ MaxStackRand uint64
+}
+
+// Valid returns true if this layout is valid.
+func (m *MmapLayout) Valid() bool {
+ if m.MinAddr > m.MaxAddr {
+ return false
+ }
+ if m.BottomUpBase < m.MinAddr {
+ return false
+ }
+ if m.BottomUpBase > m.MaxAddr {
+ return false
+ }
+ if m.TopDownBase < m.MinAddr {
+ return false
+ }
+ if m.TopDownBase > m.MaxAddr {
+ return false
+ }
+ return true
+}
+
+// SyscallArgument is an argument supplied to a syscall implementation. The
+// methods used to access the arguments are named after the ***C type name*** and
+// they convert to the closest Go type available. For example, Int() refers to a
+// 32-bit signed integer argument represented in Go as an int32.
+//
+// Using the accessor methods guarantees that the conversion between types is
+// correct, taking into account size and signedness (i.e., zero-extension vs
+// signed-extension).
+type SyscallArgument struct {
+ // Prefer to use accessor methods instead of 'Value' directly.
+ Value uintptr
+}
+
+// SyscallArguments represents the set of arguments passed to a syscall.
+type SyscallArguments [6]SyscallArgument
+
+// Pointer returns the usermem.Addr representation of a pointer argument.
+func (a SyscallArgument) Pointer() usermem.Addr {
+ return usermem.Addr(a.Value)
+}
+
+// Int returns the int32 representation of a 32-bit signed integer argument.
+func (a SyscallArgument) Int() int32 {
+ return int32(a.Value)
+}
+
+// Uint returns the uint32 representation of a 32-bit unsigned integer argument.
+func (a SyscallArgument) Uint() uint32 {
+ return uint32(a.Value)
+}
+
+// Int64 returns the int64 representation of a 64-bit signed integer argument.
+func (a SyscallArgument) Int64() int64 {
+ return int64(a.Value)
+}
+
+// Uint64 returns the uint64 representation of a 64-bit unsigned integer argument.
+func (a SyscallArgument) Uint64() uint64 {
+ return uint64(a.Value)
+}
+
+// SizeT returns the uint representation of a size_t argument.
+func (a SyscallArgument) SizeT() uint {
+ return uint(a.Value)
+}
+
+// ModeT returns the int representation of a mode_t argument.
+func (a SyscallArgument) ModeT() uint {
+ return uint(uint16(a.Value))
+}
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
new file mode 100644
index 000000000..135c2ee1f
--- /dev/null
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -0,0 +1,325 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package arch
+
+import (
+ "bytes"
+ "fmt"
+ "math/rand"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// Host specifies the host architecture.
+const Host = AMD64
+
+// These constants come directly from Linux.
+const (
+ // maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux
+ // for a 64-bit process.
+ maxAddr64 usermem.Addr = (1 << 47) - usermem.PageSize
+
+ // maxStackRand64 is the maximum randomization to apply to the stack.
+ // It is defined by arch/x86/mm/mmap.c:stack_maxrandom_size in Linux.
+ maxStackRand64 = 16 << 30 // 16 GB
+
+ // maxMmapRand64 is the maximum randomization to apply to the mmap
+ // layout. It is defined by arch/x86/mm/mmap.c:arch_mmap_rnd in Linux.
+ maxMmapRand64 = (1 << 28) * usermem.PageSize
+
+ // minGap64 is the minimum gap to leave at the top of the address space
+ // for the stack. It is defined by arch/x86/mm/mmap.c:MIN_GAP in Linux.
+ minGap64 = (128 << 20) + maxStackRand64
+
+ // preferredPIELoadAddr is the standard Linux position-independent
+ // executable base load address. It is ELF_ET_DYN_BASE in Linux.
+ //
+ // The Platform {Min,Max}UserAddress() may preclude loading at this
+ // address. See other preferredFoo comments below.
+ preferredPIELoadAddr usermem.Addr = maxAddr64 / 3 * 2
+)
+
+// These constants are selected as heuristics to help make the Platform's
+// potentially limited address space conform as closely to Linux as possible.
+const (
+ // Select a preferred minimum TopDownBase address.
+ //
+ // Some applications (TSAN and other *SANs) are very particular about
+ // the way the Linux mmap allocator layouts out the address space.
+ //
+ // TSAN in particular expects top down allocations to be made in the
+ // range [0x7e8000000000, 0x800000000000).
+ //
+ // The minimum TopDownBase on Linux would be:
+ // 0x800000000000 - minGap64 - maxMmapRand64 = 0x7efbf8000000.
+ //
+ // (minGap64 because TSAN uses a small RLIMIT_STACK.)
+ //
+ // 0x7e8000000000 is selected arbitrarily by TSAN to leave room for
+ // allocations below TopDownBase.
+ //
+ // N.B. ASAN and MSAN are more forgiving; ASAN allows allocations all
+ // the way down to 0x10007fff8000, and MSAN down to 0x700000000000.
+ //
+ // Of course, there is no hard minimum to allocation; an allocator can
+ // search all the way from TopDownBase to Min. However, TSAN declared
+ // their range "good enough".
+ //
+ // We would like to pick a TopDownBase such that it is unlikely that an
+ // allocator will select an address below TSAN's minimum. We achieve
+ // this by trying to leave a sizable gap below TopDownBase.
+ //
+ // This is all "preferred" because the layout min/max address may not
+ // allow us to select such a TopDownBase, in which case we have to fall
+ // back to a layout that TSAN may not be happy with.
+ preferredTopDownAllocMin usermem.Addr = 0x7e8000000000
+ preferredAllocationGap = 128 << 30 // 128 GB
+ preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
+
+ // minMmapRand64 is the smallest we are willing to make the
+ // randomization to stay above preferredTopDownBaseMin.
+ minMmapRand64 = (1 << 26) * usermem.PageSize
+)
+
+// context64 represents an AMD64 context.
+//
+// +stateify savable
+type context64 struct {
+ State
+ sigFPState []x86FPState // fpstate to be restored on sigreturn.
+}
+
+// Arch implements Context.Arch.
+func (c *context64) Arch() Arch {
+ return AMD64
+}
+
+func (c *context64) copySigFPState() []x86FPState {
+ var sigfps []x86FPState
+ for _, s := range c.sigFPState {
+ sigfps = append(sigfps, s.fork())
+ }
+ return sigfps
+}
+
+// Fork returns an exact copy of this context.
+func (c *context64) Fork() Context {
+ return &context64{
+ State: c.State.Fork(),
+ sigFPState: c.copySigFPState(),
+ }
+}
+
+// Return returns the current syscall return value.
+func (c *context64) Return() uintptr {
+ return uintptr(c.Regs.Rax)
+}
+
+// SetReturn sets the syscall return value.
+func (c *context64) SetReturn(value uintptr) {
+ c.Regs.Rax = uint64(value)
+}
+
+// IP returns the current instruction pointer.
+func (c *context64) IP() uintptr {
+ return uintptr(c.Regs.Rip)
+}
+
+// SetIP sets the current instruction pointer.
+func (c *context64) SetIP(value uintptr) {
+ c.Regs.Rip = uint64(value)
+}
+
+// Stack returns the current stack pointer.
+func (c *context64) Stack() uintptr {
+ return uintptr(c.Regs.Rsp)
+}
+
+// SetStack sets the current stack pointer.
+func (c *context64) SetStack(value uintptr) {
+ c.Regs.Rsp = uint64(value)
+}
+
+// TLS returns the current TLS pointer.
+func (c *context64) TLS() uintptr {
+ return uintptr(c.Regs.Fs_base)
+}
+
+// SetTLS sets the current TLS pointer. Returns false if value is invalid.
+func (c *context64) SetTLS(value uintptr) bool {
+ if !isValidSegmentBase(uint64(value)) {
+ return false
+ }
+
+ c.Regs.Fs = 0
+ c.Regs.Fs_base = uint64(value)
+ return true
+}
+
+// SetRSEQInterruptedIP implements Context.SetRSEQInterruptedIP.
+func (c *context64) SetRSEQInterruptedIP(value uintptr) {
+ c.Regs.R10 = uint64(value)
+}
+
+// Native returns the native type for the given val.
+func (c *context64) Native(val uintptr) interface{} {
+ v := uint64(val)
+ return &v
+}
+
+// Value returns the generic val for the given native type.
+func (c *context64) Value(val interface{}) uintptr {
+ return uintptr(*val.(*uint64))
+}
+
+// Width returns the byte width of this architecture.
+func (c *context64) Width() uint {
+ return 8
+}
+
+// FeatureSet returns the FeatureSet in use.
+func (c *context64) FeatureSet() *cpuid.FeatureSet {
+ return c.State.FeatureSet
+}
+
+// mmapRand returns a random adjustment for randomizing an mmap layout.
+func mmapRand(max uint64) usermem.Addr {
+ return usermem.Addr(rand.Int63n(int64(max))).RoundDown()
+}
+
+// NewMmapLayout implements Context.NewMmapLayout consistently with Linux.
+func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) {
+ min, ok := min.RoundUp()
+ if !ok {
+ return MmapLayout{}, syscall.EINVAL
+ }
+ if max > maxAddr64 {
+ max = maxAddr64
+ }
+ max = max.RoundDown()
+
+ if min > max {
+ return MmapLayout{}, syscall.EINVAL
+ }
+
+ stackSize := r.Get(limits.Stack)
+
+ // MAX_GAP in Linux.
+ maxGap := (max / 6) * 5
+ gap := usermem.Addr(stackSize.Cur)
+ if gap < minGap64 {
+ gap = minGap64
+ }
+ if gap > maxGap {
+ gap = maxGap
+ }
+ defaultDir := MmapTopDown
+ if stackSize.Cur == limits.Infinity {
+ defaultDir = MmapBottomUp
+ }
+
+ topDownMin := max - gap - maxMmapRand64
+ maxRand := usermem.Addr(maxMmapRand64)
+ if topDownMin < preferredTopDownBaseMin {
+ // Try to keep TopDownBase above preferredTopDownBaseMin by
+ // shrinking maxRand.
+ maxAdjust := maxRand - minMmapRand64
+ needAdjust := preferredTopDownBaseMin - topDownMin
+ if needAdjust <= maxAdjust {
+ maxRand -= needAdjust
+ }
+ }
+
+ rnd := mmapRand(uint64(maxRand))
+ l := MmapLayout{
+ MinAddr: min,
+ MaxAddr: max,
+ // TASK_UNMAPPED_BASE in Linux.
+ BottomUpBase: (max/3 + rnd).RoundDown(),
+ TopDownBase: (max - gap - rnd).RoundDown(),
+ DefaultDirection: defaultDir,
+ // We may have reduced the maximum randomization to keep
+ // TopDownBase above preferredTopDownBaseMin while maintaining
+ // our stack gap. Stack allocations must use that max
+ // randomization to avoiding eating into the gap.
+ MaxStackRand: uint64(maxRand),
+ }
+
+ // Final sanity check on the layout.
+ if !l.Valid() {
+ panic(fmt.Sprintf("Invalid MmapLayout: %+v", l))
+ }
+
+ return l, nil
+}
+
+// PIELoadAddress implements Context.PIELoadAddress.
+func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
+ base := preferredPIELoadAddr
+ max, ok := base.AddLength(maxMmapRand64)
+ if !ok {
+ panic(fmt.Sprintf("preferredPIELoadAddr %#x too large", base))
+ }
+
+ if max > l.MaxAddr {
+ // preferredPIELoadAddr won't fit; fall back to the standard
+ // Linux behavior of 2/3 of TopDownBase. TSAN won't like this.
+ //
+ // Don't bother trying to shrink the randomization for now.
+ base = l.TopDownBase / 3 * 2
+ }
+
+ return base + mmapRand(maxMmapRand64)
+}
+
+// userStructSize is the size in bytes of Linux's struct user on amd64.
+const userStructSize = 928
+
+// PtracePeekUser implements Context.PtracePeekUser.
+func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+ if addr&7 != 0 || addr >= userStructSize {
+ return nil, syscall.EIO
+ }
+ // PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and
+ // u_debugreg, returning 0 or silently no-oping for other fields
+ // respectively.
+ if addr < uintptr(ptraceRegsSize) {
+ buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs())
+ return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil
+ }
+ // TODO(b/34088053): debug registers
+ return c.Native(0), nil
+}
+
+// PtracePokeUser implements Context.PtracePokeUser.
+func (c *context64) PtracePokeUser(addr, data uintptr) error {
+ if addr&7 != 0 || addr >= userStructSize {
+ return syscall.EIO
+ }
+ if addr < uintptr(ptraceRegsSize) {
+ buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs())
+ usermem.ByteOrder.PutUint64(buf[addr:], uint64(data))
+ _, err := c.PtraceSetRegs(bytes.NewBuffer(buf))
+ return err
+ }
+ // TODO(b/34088053): debug registers
+ return nil
+}
diff --git a/pkg/sentry/arch/arch_amd64.s b/pkg/sentry/arch/arch_amd64.s
new file mode 100644
index 000000000..bd61402cf
--- /dev/null
+++ b/pkg/sentry/arch/arch_amd64.s
@@ -0,0 +1,135 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// MXCSR_DEFAULT is the reset value of MXCSR (Intel SDM Vol. 2, Ch. 3.2
+// "LDMXCSR")
+#define MXCSR_DEFAULT 0x1f80
+
+// MXCSR_OFFSET is the offset in bytes of the MXCSR field from the start of the
+// FXSAVE/XSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE Area")
+#define MXCSR_OFFSET 24
+
+// initX86FPState initializes floating point state.
+//
+// func initX86FPState(data *FloatingPointData, useXsave bool)
+//
+// We need to clear out and initialize an empty fp state area since the sentry
+// may have left sensitive information in the floating point registers.
+//
+// Preconditions: data is zeroed
+TEXT ·initX86FPState(SB), $24-16
+ // Save MXCSR (callee-save)
+ STMXCSR mxcsr-8(SP)
+
+ // Save x87 CW (callee-save)
+ FSTCW cw-16(SP)
+
+ MOVQ fpState+0(FP), DI
+
+ // Do we use xsave?
+ MOVBQZX useXsave+8(FP), AX
+ TESTQ AX, AX
+ JZ no_xsave
+
+ // Use XRSTOR to clear all FP state to an initial state.
+ //
+ // The fpState XSAVE area is zeroed on function entry, meaning
+ // XSTATE_BV is zero.
+ //
+ // "If RFBM[i] = 1 and bit i is clear in the XSTATE_BV field in the
+ // XSAVE header, XRSTOR initializes state component i."
+ //
+ // Initialization is defined in SDM Vol 1, Chapter 13.3. It puts all
+ // the registers in a reasonable initial state, except MXCSR:
+ //
+ // "The MXCSR register is part of state component 1, SSE state (see
+ // Section 13.5.2). However, the standard form of XRSTOR loads the
+ // MXCSR register from memory whenever the RFBM[1] (SSE) or RFBM[2]
+ // (AVX) is set, regardless of the values of XSTATE_BV[1] and
+ // XSTATE_BV[2]."
+
+ // Set MXCSR to the default value.
+ MOVL $MXCSR_DEFAULT, MXCSR_OFFSET(DI)
+
+ // Initialize registers with XRSTOR.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f // XRSTOR64 0(DI)
+
+ // Now that all the state has been reset, write it back out to the
+ // XSAVE area.
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27 // XSAVE64 0(DI)
+
+ JMP out
+
+no_xsave:
+ // Clear out existing X values.
+ PXOR X0, X0
+ MOVO X0, X1
+ MOVO X0, X2
+ MOVO X0, X3
+ MOVO X0, X4
+ MOVO X0, X5
+ MOVO X0, X6
+ MOVO X0, X7
+ MOVO X0, X8
+ MOVO X0, X9
+ MOVO X0, X10
+ MOVO X0, X11
+ MOVO X0, X12
+ MOVO X0, X13
+ MOVO X0, X14
+ MOVO X0, X15
+
+ // Zero out %rax and store into MMX registers. MMX registers are
+ // an alias of 8x64 bits of the 8x80 bits used for the original
+ // x87 registers. Storing zero into them will reset the FPU registers
+ // to bits [63:0] = 0, [79:64] = 1. But the contents aren't too
+ // important, just the fact that we have reset them to a known value.
+ XORQ AX, AX
+ MOVQ AX, M0
+ MOVQ AX, M1
+ MOVQ AX, M2
+ MOVQ AX, M3
+ MOVQ AX, M4
+ MOVQ AX, M5
+ MOVQ AX, M6
+ MOVQ AX, M7
+
+ // The Go assembler doesn't support FNINIT, so we use BYTE.
+ // This will:
+ // - Reset FPU control word to 0x037f
+ // - Clear FPU status word
+ // - Reset FPU tag word to 0xffff
+ // - Clear FPU data pointer
+ // - Clear FPU instruction pointer
+ BYTE $0xDB; BYTE $0xE3; // FNINIT
+
+ // Reset MXCSR.
+ MOVL $MXCSR_DEFAULT, tmpmxcsr-24(SP)
+ LDMXCSR tmpmxcsr-24(SP)
+
+ // Save the floating point state with fxsave.
+ FXSAVE64 0(DI)
+
+out:
+ // Restore MXCSR.
+ LDMXCSR mxcsr-8(SP)
+
+ // Restore x87 CW.
+ FLDCW cw-16(SP)
+
+ RET
diff --git a/pkg/sentry/arch/arch_state_autogen.go b/pkg/sentry/arch/arch_state_autogen.go
new file mode 100755
index 000000000..0c3b11507
--- /dev/null
+++ b/pkg/sentry/arch/arch_state_autogen.go
@@ -0,0 +1,193 @@
+// automatically generated by stateify.
+
+package arch
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *MmapLayout) beforeSave() {}
+func (x *MmapLayout) save(m state.Map) {
+ x.beforeSave()
+ m.Save("MinAddr", &x.MinAddr)
+ m.Save("MaxAddr", &x.MaxAddr)
+ m.Save("BottomUpBase", &x.BottomUpBase)
+ m.Save("TopDownBase", &x.TopDownBase)
+ m.Save("DefaultDirection", &x.DefaultDirection)
+ m.Save("MaxStackRand", &x.MaxStackRand)
+}
+
+func (x *MmapLayout) afterLoad() {}
+func (x *MmapLayout) load(m state.Map) {
+ m.Load("MinAddr", &x.MinAddr)
+ m.Load("MaxAddr", &x.MaxAddr)
+ m.Load("BottomUpBase", &x.BottomUpBase)
+ m.Load("TopDownBase", &x.TopDownBase)
+ m.Load("DefaultDirection", &x.DefaultDirection)
+ m.Load("MaxStackRand", &x.MaxStackRand)
+}
+
+func (x *context64) beforeSave() {}
+func (x *context64) save(m state.Map) {
+ x.beforeSave()
+ m.Save("State", &x.State)
+ m.Save("sigFPState", &x.sigFPState)
+}
+
+func (x *context64) afterLoad() {}
+func (x *context64) load(m state.Map) {
+ m.Load("State", &x.State)
+ m.Load("sigFPState", &x.sigFPState)
+}
+
+func (x *syscallPtraceRegs) beforeSave() {}
+func (x *syscallPtraceRegs) save(m state.Map) {
+ x.beforeSave()
+ m.Save("R15", &x.R15)
+ m.Save("R14", &x.R14)
+ m.Save("R13", &x.R13)
+ m.Save("R12", &x.R12)
+ m.Save("Rbp", &x.Rbp)
+ m.Save("Rbx", &x.Rbx)
+ m.Save("R11", &x.R11)
+ m.Save("R10", &x.R10)
+ m.Save("R9", &x.R9)
+ m.Save("R8", &x.R8)
+ m.Save("Rax", &x.Rax)
+ m.Save("Rcx", &x.Rcx)
+ m.Save("Rdx", &x.Rdx)
+ m.Save("Rsi", &x.Rsi)
+ m.Save("Rdi", &x.Rdi)
+ m.Save("Orig_rax", &x.Orig_rax)
+ m.Save("Rip", &x.Rip)
+ m.Save("Cs", &x.Cs)
+ m.Save("Eflags", &x.Eflags)
+ m.Save("Rsp", &x.Rsp)
+ m.Save("Ss", &x.Ss)
+ m.Save("Fs_base", &x.Fs_base)
+ m.Save("Gs_base", &x.Gs_base)
+ m.Save("Ds", &x.Ds)
+ m.Save("Es", &x.Es)
+ m.Save("Fs", &x.Fs)
+ m.Save("Gs", &x.Gs)
+}
+
+func (x *syscallPtraceRegs) afterLoad() {}
+func (x *syscallPtraceRegs) load(m state.Map) {
+ m.Load("R15", &x.R15)
+ m.Load("R14", &x.R14)
+ m.Load("R13", &x.R13)
+ m.Load("R12", &x.R12)
+ m.Load("Rbp", &x.Rbp)
+ m.Load("Rbx", &x.Rbx)
+ m.Load("R11", &x.R11)
+ m.Load("R10", &x.R10)
+ m.Load("R9", &x.R9)
+ m.Load("R8", &x.R8)
+ m.Load("Rax", &x.Rax)
+ m.Load("Rcx", &x.Rcx)
+ m.Load("Rdx", &x.Rdx)
+ m.Load("Rsi", &x.Rsi)
+ m.Load("Rdi", &x.Rdi)
+ m.Load("Orig_rax", &x.Orig_rax)
+ m.Load("Rip", &x.Rip)
+ m.Load("Cs", &x.Cs)
+ m.Load("Eflags", &x.Eflags)
+ m.Load("Rsp", &x.Rsp)
+ m.Load("Ss", &x.Ss)
+ m.Load("Fs_base", &x.Fs_base)
+ m.Load("Gs_base", &x.Gs_base)
+ m.Load("Ds", &x.Ds)
+ m.Load("Es", &x.Es)
+ m.Load("Fs", &x.Fs)
+ m.Load("Gs", &x.Gs)
+}
+
+func (x *State) beforeSave() {}
+func (x *State) save(m state.Map) {
+ x.beforeSave()
+ var Regs syscallPtraceRegs = x.saveRegs()
+ m.SaveValue("Regs", Regs)
+ m.Save("x86FPState", &x.x86FPState)
+ m.Save("FeatureSet", &x.FeatureSet)
+}
+
+func (x *State) load(m state.Map) {
+ m.LoadWait("x86FPState", &x.x86FPState)
+ m.Load("FeatureSet", &x.FeatureSet)
+ m.LoadValue("Regs", new(syscallPtraceRegs), func(y interface{}) { x.loadRegs(y.(syscallPtraceRegs)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *AuxEntry) beforeSave() {}
+func (x *AuxEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Key", &x.Key)
+ m.Save("Value", &x.Value)
+}
+
+func (x *AuxEntry) afterLoad() {}
+func (x *AuxEntry) load(m state.Map) {
+ m.Load("Key", &x.Key)
+ m.Load("Value", &x.Value)
+}
+
+func (x *SignalAct) beforeSave() {}
+func (x *SignalAct) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Handler", &x.Handler)
+ m.Save("Flags", &x.Flags)
+ m.Save("Restorer", &x.Restorer)
+ m.Save("Mask", &x.Mask)
+}
+
+func (x *SignalAct) afterLoad() {}
+func (x *SignalAct) load(m state.Map) {
+ m.Load("Handler", &x.Handler)
+ m.Load("Flags", &x.Flags)
+ m.Load("Restorer", &x.Restorer)
+ m.Load("Mask", &x.Mask)
+}
+
+func (x *SignalStack) beforeSave() {}
+func (x *SignalStack) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Addr", &x.Addr)
+ m.Save("Flags", &x.Flags)
+ m.Save("Size", &x.Size)
+}
+
+func (x *SignalStack) afterLoad() {}
+func (x *SignalStack) load(m state.Map) {
+ m.Load("Addr", &x.Addr)
+ m.Load("Flags", &x.Flags)
+ m.Load("Size", &x.Size)
+}
+
+func (x *SignalInfo) beforeSave() {}
+func (x *SignalInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Signo", &x.Signo)
+ m.Save("Errno", &x.Errno)
+ m.Save("Code", &x.Code)
+ m.Save("Fields", &x.Fields)
+}
+
+func (x *SignalInfo) afterLoad() {}
+func (x *SignalInfo) load(m state.Map) {
+ m.Load("Signo", &x.Signo)
+ m.Load("Errno", &x.Errno)
+ m.Load("Code", &x.Code)
+ m.Load("Fields", &x.Fields)
+}
+
+func init() {
+ state.Register("arch.MmapLayout", (*MmapLayout)(nil), state.Fns{Save: (*MmapLayout).save, Load: (*MmapLayout).load})
+ state.Register("arch.context64", (*context64)(nil), state.Fns{Save: (*context64).save, Load: (*context64).load})
+ state.Register("arch.syscallPtraceRegs", (*syscallPtraceRegs)(nil), state.Fns{Save: (*syscallPtraceRegs).save, Load: (*syscallPtraceRegs).load})
+ state.Register("arch.State", (*State)(nil), state.Fns{Save: (*State).save, Load: (*State).load})
+ state.Register("arch.AuxEntry", (*AuxEntry)(nil), state.Fns{Save: (*AuxEntry).save, Load: (*AuxEntry).load})
+ state.Register("arch.SignalAct", (*SignalAct)(nil), state.Fns{Save: (*SignalAct).save, Load: (*SignalAct).load})
+ state.Register("arch.SignalStack", (*SignalStack)(nil), state.Fns{Save: (*SignalStack).save, Load: (*SignalStack).load})
+ state.Register("arch.SignalInfo", (*SignalInfo)(nil), state.Fns{Save: (*SignalInfo).save, Load: (*SignalInfo).load})
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
new file mode 100644
index 000000000..bb52d8db0
--- /dev/null
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -0,0 +1,131 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// ErrFloatingPoint indicates a failed restore due to unusable floating point
+// state.
+type ErrFloatingPoint struct {
+ // supported is the supported floating point state.
+ supported uint64
+
+ // saved is the saved floating point state.
+ saved uint64
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrFloatingPoint) Error() string {
+ return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
+}
+
+// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
+// and SSE state, so this is the equivalent XSTATE_BV value.
+const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
+
+// afterLoad is invoked by stateify.
+func (s *State) afterLoad() {
+ old := s.x86FPState
+
+ // Recreate the slice. This is done to ensure that it is aligned
+ // appropriately in memory, and large enough to accommodate any new
+ // state that may be saved by the new CPU. Even if extraneous new state
+ // is saved, the state we care about is guaranteed to be a subset of
+ // new state. Later optimizations can use less space when using a
+ // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has
+ // more info.
+ s.x86FPState = newX86FPState()
+
+ // x86FPState always contains all the FP state supported by the host.
+ // We may have come from a newer machine that supports additional state
+ // which we cannot restore.
+ //
+ // The x86 FP state areas are backwards compatible, so we can simply
+ // truncate the additional floating point state.
+ //
+ // Applications should not depend on the truncated state because it
+ // should relate only to features that were not exposed in the app
+ // FeatureSet. However, because we do not *prevent* them from using
+ // this state, we must verify here that there is no in-use state
+ // (according to XSTATE_BV) which we do not support.
+ if len(s.x86FPState) < len(old) {
+ // What do we support?
+ supportedBV := fxsaveBV
+ if fs := cpuid.HostFeatureSet(); fs.UseXsave() {
+ supportedBV = fs.ValidXCR0Mask()
+ }
+
+ // What was in use?
+ savedBV := fxsaveBV
+ if len(old) >= xstateBVOffset+8 {
+ savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:])
+ }
+
+ // Supported features must be a superset of saved features.
+ if savedBV&^supportedBV != 0 {
+ panic(ErrFloatingPoint{supported: supportedBV, saved: savedBV})
+ }
+ }
+
+ // Copy to the new, aligned location.
+ copy(s.x86FPState, old)
+}
+
+// +stateify savable
+type syscallPtraceRegs struct {
+ R15 uint64
+ R14 uint64
+ R13 uint64
+ R12 uint64
+ Rbp uint64
+ Rbx uint64
+ R11 uint64
+ R10 uint64
+ R9 uint64
+ R8 uint64
+ Rax uint64
+ Rcx uint64
+ Rdx uint64
+ Rsi uint64
+ Rdi uint64
+ Orig_rax uint64
+ Rip uint64
+ Cs uint64
+ Eflags uint64
+ Rsp uint64
+ Ss uint64
+ Fs_base uint64
+ Gs_base uint64
+ Ds uint64
+ Es uint64
+ Fs uint64
+ Gs uint64
+}
+
+// saveRegs is invoked by stateify.
+func (s *State) saveRegs() syscallPtraceRegs {
+ return syscallPtraceRegs(s.Regs)
+}
+
+// loadRegs is invoked by stateify.
+func (s *State) loadRegs(r syscallPtraceRegs) {
+ s.Regs = syscall.PtraceRegs(r)
+}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
new file mode 100644
index 000000000..4d167ce98
--- /dev/null
+++ b/pkg/sentry/arch/arch_x86.go
@@ -0,0 +1,621 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 i386
+
+package arch
+
+import (
+ "fmt"
+ "io"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ rpb "gvisor.googlesource.com/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// System-related constants for x86.
+const (
+ // SyscallWidth is the width of syscall, sysenter, and int 80 insturctions.
+ SyscallWidth = 2
+)
+
+// EFLAGS register bits.
+const (
+ // eflagsCF is the mask for the carry flag.
+ eflagsCF = uint64(1) << 0
+ // eflagsPF is the mask for the parity flag.
+ eflagsPF = uint64(1) << 2
+ // eflagsAF is the mask for the auxiliary carry flag.
+ eflagsAF = uint64(1) << 4
+ // eflagsZF is the mask for the zero flag.
+ eflagsZF = uint64(1) << 6
+ // eflagsSF is the mask for the sign flag.
+ eflagsSF = uint64(1) << 7
+ // eflagsTF is the mask for the trap flag.
+ eflagsTF = uint64(1) << 8
+ // eflagsIF is the mask for the interrupt flag.
+ eflagsIF = uint64(1) << 9
+ // eflagsDF is the mask for the direction flag.
+ eflagsDF = uint64(1) << 10
+ // eflagsOF is the mask for the overflow flag.
+ eflagsOF = uint64(1) << 11
+ // eflagsIOPL is the mask for the I/O privilege level.
+ eflagsIOPL = uint64(3) << 12
+ // eflagsNT is the mask for the nested task bit.
+ eflagsNT = uint64(1) << 14
+ // eflagsRF is the mask for the resume flag.
+ eflagsRF = uint64(1) << 16
+ // eflagsVM is the mask for the virtual mode bit.
+ eflagsVM = uint64(1) << 17
+ // eflagsAC is the mask for the alignment check / access control bit.
+ eflagsAC = uint64(1) << 18
+ // eflagsVIF is the mask for the virtual interrupt flag.
+ eflagsVIF = uint64(1) << 19
+ // eflagsVIP is the mask for the virtual interrupt pending bit.
+ eflagsVIP = uint64(1) << 20
+ // eflagsID is the mask for the CPUID detection bit.
+ eflagsID = uint64(1) << 21
+
+ // eflagsPtraceMutable is the mask for the set of EFLAGS that may be
+ // changed by ptrace(PTRACE_SETREGS). eflagsPtraceMutable is analogous to
+ // Linux's FLAG_MASK.
+ eflagsPtraceMutable = eflagsCF | eflagsPF | eflagsAF | eflagsZF | eflagsSF | eflagsTF | eflagsDF | eflagsOF | eflagsRF | eflagsAC | eflagsNT
+
+ // eflagsRestorable is the mask for the set of EFLAGS that may be changed by
+ // SignalReturn. eflagsRestorable is analogous to Linux's FIX_EFLAGS.
+ eflagsRestorable = eflagsAC | eflagsOF | eflagsDF | eflagsTF | eflagsSF | eflagsZF | eflagsAF | eflagsPF | eflagsCF | eflagsRF
+)
+
+// Segment selectors. See arch/x86/include/asm/segment.h.
+const (
+ userCS = 0x33 // guest ring 3 code selector
+ user32CS = 0x23 // guest ring 3 32 bit code selector
+ userDS = 0x2b // guest ring 3 data selector
+
+ _FS_TLS_SEL = 0x63 // Linux FS thread-local storage selector
+ _GS_TLS_SEL = 0x6b // Linux GS thread-local storage selector
+)
+
+var (
+ // TrapInstruction is the x86 trap instruction.
+ TrapInstruction = [1]byte{0xcc}
+
+ // CPUIDInstruction is the x86 CPUID instruction.
+ CPUIDInstruction = [2]byte{0xf, 0xa2}
+
+ // X86TrapFlag is an exported const for use by other packages.
+ X86TrapFlag uint64 = (1 << 8)
+)
+
+// x86FPState is x86 floating point state.
+type x86FPState []byte
+
+// initX86FPState (defined in asm files) sets up initial state.
+func initX86FPState(data *FloatingPointData, useXsave bool)
+
+func newX86FPStateSlice() []byte {
+ size, align := cpuid.HostFeatureSet().ExtendedStateSize()
+ capacity := size
+ // Always use at least 4096 bytes.
+ if capacity < 4096 {
+ capacity = 4096
+ }
+ return alignedBytes(capacity, align)[:size]
+}
+
+// newX86FPState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by host, even if the app won't use much of it due to a restricted
+// FeatureSet. Since they may still be able to see state not advertised by
+// CPUID we must ensure it does not contain any sentry state.
+func newX86FPState() x86FPState {
+ f := x86FPState(newX86FPStateSlice())
+ initX86FPState(f.FloatingPointData(), cpuid.HostFeatureSet().UseXsave())
+ return f
+}
+
+// fork creates and returns an identical copy of the x86 floating point state.
+func (f x86FPState) fork() x86FPState {
+ n := x86FPState(newX86FPStateSlice())
+ copy(n, f)
+ return n
+}
+
+// FloatingPointData returns the raw data pointer.
+func (f x86FPState) FloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&f[0])
+}
+
+// NewFloatingPointData returns a new floating point data blob.
+//
+// This is primarily for use in tests.
+func NewFloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&(newX86FPState()[0]))
+}
+
+// State contains the common architecture bits for X86 (the build tag of this
+// file ensures it's only built on x86).
+//
+// +stateify savable
+type State struct {
+ // The system registers.
+ Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"`
+
+ // Our floating point state.
+ x86FPState `state:"wait"`
+
+ // FeatureSet is a pointer to the currently active feature set.
+ FeatureSet *cpuid.FeatureSet
+}
+
+// Proto returns a protobuf representation of the system registers in State.
+func (s State) Proto() *rpb.Registers {
+ regs := &rpb.AMD64Registers{
+ Rax: s.Regs.Rax,
+ Rbx: s.Regs.Rbx,
+ Rcx: s.Regs.Rcx,
+ Rdx: s.Regs.Rdx,
+ Rsi: s.Regs.Rsi,
+ Rdi: s.Regs.Rdi,
+ Rsp: s.Regs.Rsp,
+ Rbp: s.Regs.Rbp,
+ R8: s.Regs.R8,
+ R9: s.Regs.R9,
+ R10: s.Regs.R10,
+ R11: s.Regs.R11,
+ R12: s.Regs.R12,
+ R13: s.Regs.R13,
+ R14: s.Regs.R14,
+ R15: s.Regs.R15,
+ Rip: s.Regs.Rip,
+ Rflags: s.Regs.Eflags,
+ OrigRax: s.Regs.Orig_rax,
+ Cs: s.Regs.Cs,
+ Ds: s.Regs.Ds,
+ Es: s.Regs.Es,
+ Fs: s.Regs.Fs,
+ Gs: s.Regs.Gs,
+ Ss: s.Regs.Ss,
+ FsBase: s.Regs.Fs_base,
+ GsBase: s.Regs.Gs_base,
+ }
+ return &rpb.Registers{Arch: &rpb.Registers_Amd64{Amd64: regs}}
+}
+
+// Fork creates and returns an identical copy of the state.
+func (s *State) Fork() State {
+ return State{
+ Regs: s.Regs,
+ x86FPState: s.x86FPState.fork(),
+ FeatureSet: s.FeatureSet,
+ }
+}
+
+// StateData implements Context.StateData.
+func (s *State) StateData() *State {
+ return s
+}
+
+// CPUIDEmulate emulates a cpuid instruction.
+func (s *State) CPUIDEmulate(l log.Logger) {
+ argax := uint32(s.Regs.Rax)
+ argcx := uint32(s.Regs.Rcx)
+ ax, bx, cx, dx := s.FeatureSet.EmulateID(argax, argcx)
+ s.Regs.Rax = uint64(ax)
+ s.Regs.Rbx = uint64(bx)
+ s.Regs.Rcx = uint64(cx)
+ s.Regs.Rdx = uint64(dx)
+ l.Debugf("CPUID(%x,%x): %x %x %x %x", argax, argcx, ax, bx, cx, dx)
+}
+
+// SingleStep implements Context.SingleStep.
+func (s *State) SingleStep() bool {
+ return s.Regs.Eflags&X86TrapFlag != 0
+}
+
+// SetSingleStep enables single stepping.
+func (s *State) SetSingleStep() {
+ // Set the trap flag.
+ s.Regs.Eflags |= X86TrapFlag
+}
+
+// ClearSingleStep enables single stepping.
+func (s *State) ClearSingleStep() {
+ // Clear the trap flag.
+ s.Regs.Eflags &= ^X86TrapFlag
+}
+
+// RegisterMap returns a map of all registers.
+func (s *State) RegisterMap() (map[string]uintptr, error) {
+ return map[string]uintptr{
+ "R15": uintptr(s.Regs.R15),
+ "R14": uintptr(s.Regs.R14),
+ "R13": uintptr(s.Regs.R13),
+ "R12": uintptr(s.Regs.R12),
+ "Rbp": uintptr(s.Regs.Rbp),
+ "Rbx": uintptr(s.Regs.Rbx),
+ "R11": uintptr(s.Regs.R11),
+ "R10": uintptr(s.Regs.R10),
+ "R9": uintptr(s.Regs.R9),
+ "R8": uintptr(s.Regs.R8),
+ "Rax": uintptr(s.Regs.Rax),
+ "Rcx": uintptr(s.Regs.Rcx),
+ "Rdx": uintptr(s.Regs.Rdx),
+ "Rsi": uintptr(s.Regs.Rsi),
+ "Rdi": uintptr(s.Regs.Rdi),
+ "Orig_rax": uintptr(s.Regs.Orig_rax),
+ "Rip": uintptr(s.Regs.Rip),
+ "Cs": uintptr(s.Regs.Cs),
+ "Eflags": uintptr(s.Regs.Eflags),
+ "Rsp": uintptr(s.Regs.Rsp),
+ "Ss": uintptr(s.Regs.Ss),
+ "Fs_base": uintptr(s.Regs.Fs_base),
+ "Gs_base": uintptr(s.Regs.Gs_base),
+ "Ds": uintptr(s.Regs.Ds),
+ "Es": uintptr(s.Regs.Es),
+ "Fs": uintptr(s.Regs.Fs),
+ "Gs": uintptr(s.Regs.Gs),
+ }, nil
+}
+
+// PtraceGetRegs implements Context.PtraceGetRegs.
+func (s *State) PtraceGetRegs(dst io.Writer) (int, error) {
+ return dst.Write(binary.Marshal(nil, usermem.ByteOrder, s.ptraceGetRegs()))
+}
+
+func (s *State) ptraceGetRegs() syscall.PtraceRegs {
+ regs := s.Regs
+ // These may not be initialized.
+ if regs.Cs == 0 || regs.Ss == 0 || regs.Eflags == 0 {
+ regs.Eflags = eflagsIF
+ regs.Cs = userCS
+ regs.Ss = userDS
+ }
+ // As an optimization, Linux <4.7 implements 32-bit fs_base/gs_base
+ // addresses using reserved descriptors in the GDT instead of the MSRs,
+ // with selector values FS_TLS_SEL and GS_TLS_SEL respectively. These
+ // values are actually visible in struct user_regs_struct::fs/gs;
+ // arch/x86/kernel/ptrace.c:getreg() doesn't attempt to sanitize struct
+ // thread_struct::fsindex/gsindex.
+ //
+ // We always use fs == gs == 0 when fs_base/gs_base is in use, for
+ // simplicity.
+ //
+ // Luckily, Linux <4.7 silently ignores setting fs/gs to 0 via
+ // arch/x86/kernel/ptrace.c:set_segment_reg() when fs_base/gs_base is a
+ // 32-bit value and fsindex/gsindex indicates that this optimization is
+ // in use, as well as the reverse case of setting fs/gs to
+ // FS/GS_TLS_SEL when fs_base/gs_base is a 64-bit value. (We do the
+ // same in PtraceSetRegs.)
+ //
+ // TODO(gvisor.dev/issue/168): Remove this fixup since newer Linux
+ // doesn't have this behavior anymore.
+ if regs.Fs == 0 && regs.Fs_base <= 0xffffffff {
+ regs.Fs = _FS_TLS_SEL
+ }
+ if regs.Gs == 0 && regs.Gs_base <= 0xffffffff {
+ regs.Gs = _GS_TLS_SEL
+ }
+ return regs
+}
+
+var ptraceRegsSize = int(binary.Size(syscall.PtraceRegs{}))
+
+// PtraceSetRegs implements Context.PtraceSetRegs.
+func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
+ var regs syscall.PtraceRegs
+ buf := make([]byte, ptraceRegsSize)
+ if _, err := io.ReadFull(src, buf); err != nil {
+ return 0, err
+ }
+ binary.Unmarshal(buf, usermem.ByteOrder, &regs)
+ // Truncate segment registers to 16 bits.
+ regs.Cs = uint64(uint16(regs.Cs))
+ regs.Ds = uint64(uint16(regs.Ds))
+ regs.Es = uint64(uint16(regs.Es))
+ regs.Fs = uint64(uint16(regs.Fs))
+ regs.Gs = uint64(uint16(regs.Gs))
+ regs.Ss = uint64(uint16(regs.Ss))
+ // In Linux this validation is via arch/x86/kernel/ptrace.c:putreg().
+ if !isUserSegmentSelector(regs.Cs) {
+ return 0, syscall.EIO
+ }
+ if regs.Ds != 0 && !isUserSegmentSelector(regs.Ds) {
+ return 0, syscall.EIO
+ }
+ if regs.Es != 0 && !isUserSegmentSelector(regs.Es) {
+ return 0, syscall.EIO
+ }
+ if regs.Fs != 0 && !isUserSegmentSelector(regs.Fs) {
+ return 0, syscall.EIO
+ }
+ if regs.Gs != 0 && !isUserSegmentSelector(regs.Gs) {
+ return 0, syscall.EIO
+ }
+ if !isUserSegmentSelector(regs.Ss) {
+ return 0, syscall.EIO
+ }
+ if !isValidSegmentBase(regs.Fs_base) {
+ return 0, syscall.EIO
+ }
+ if !isValidSegmentBase(regs.Gs_base) {
+ return 0, syscall.EIO
+ }
+ // CS and SS are validated, but changes to them are otherwise silently
+ // ignored on amd64.
+ regs.Cs = s.Regs.Cs
+ regs.Ss = s.Regs.Ss
+ // fs_base/gs_base changes reset fs/gs via do_arch_prctl() on Linux.
+ if regs.Fs_base != s.Regs.Fs_base {
+ regs.Fs = 0
+ }
+ if regs.Gs_base != s.Regs.Gs_base {
+ regs.Gs = 0
+ }
+ // Ignore "stale" TLS segment selectors for FS and GS. See comment in
+ // ptraceGetRegs.
+ if regs.Fs == _FS_TLS_SEL && regs.Fs_base != 0 {
+ regs.Fs = 0
+ }
+ if regs.Gs == _GS_TLS_SEL && regs.Gs_base != 0 {
+ regs.Gs = 0
+ }
+ regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable)
+ s.Regs = regs
+ return ptraceRegsSize, nil
+}
+
+// isUserSegmentSelector returns true if the given segment selector specifies a
+// privilege level of 3 (USER_RPL).
+func isUserSegmentSelector(reg uint64) bool {
+ return reg&3 == 3
+}
+
+// isValidSegmentBase returns true if the given segment base specifies a
+// canonical user address.
+func isValidSegmentBase(reg uint64) bool {
+ return reg < uint64(maxAddr64)
+}
+
+// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type
+// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently,
+// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area.
+const ptraceFPRegsSize = 512
+
+// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
+func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) {
+ return dst.Write(s.x86FPState[:ptraceFPRegsSize])
+}
+
+// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
+func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
+ var f [ptraceFPRegsSize]byte
+ n, err := io.ReadFull(src, f[:])
+ if err != nil {
+ return 0, err
+ }
+ // Force reserved bits in MXCSR to 0. This is consistent with Linux.
+ sanitizeMXCSR(x86FPState(f[:]))
+ // N.B. this only copies the beginning of the FP state, which
+ // corresponds to the FXSAVE area.
+ copy(s.x86FPState, f[:])
+ return n, nil
+}
+
+const (
+ // mxcsrOffset is the offset in bytes of the MXCSR field from the start of
+ // the FXSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE
+ // Area")
+ mxcsrOffset = 24
+
+ // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the
+ // start of the FXSAVE area.
+ mxcsrMaskOffset = 28
+)
+
+var (
+ mxcsrMask uint32
+ initMXCSRMask sync.Once
+)
+
+// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. ("FXRSTOR
+// generates a general-protection fault (#GP) in response to an attempt to set
+// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section
+// 10.5.1.2 "SSE State")
+func sanitizeMXCSR(f x86FPState) {
+ mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:])
+ initMXCSRMask.Do(func() {
+ temp := x86FPState(alignedBytes(uint(ptraceFPRegsSize), 16))
+ initX86FPState(temp.FloatingPointData(), false /* useXsave */)
+ mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:])
+ if mxcsrMask == 0 {
+ // "If the value of the MXCSR_MASK field is 00000000H, then the
+ // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM
+ // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR
+ // Register"
+ mxcsrMask = 0xffbf
+ }
+ })
+ mxcsr &= mxcsrMask
+ usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr)
+}
+
+const (
+ // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal
+ // to the size of the XSAVE legacy area (512 bytes) plus the size of the
+ // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's
+ // X86_XSTATE_SSE_SIZE.
+ minXstateBytes = 512 + 64
+
+ // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD
+ // field in Linux's struct user_xstateregs, which is the type manipulated
+ // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently,
+ // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET.
+ userXstateXCR0Offset = 464
+
+ // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86
+ // XSAVE area.
+ xstateBVOffset = 512
+
+ // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the
+ // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is
+ // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE
+ // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header".
+ // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP
+ // exceptions resulting from invalid values; we aren't. Linux also never
+ // uses the compacted format when doing XSAVE and doesn't even define the
+ // compaction extensions to XSAVE as a CPU feature, so for simplicity we
+ // assume no one is using them.
+ xsaveHeaderZeroedOffset = 512 + 8
+ xsaveHeaderZeroedBytes = 64 - 8
+)
+
+func (s *State) ptraceGetXstateRegs(dst io.Writer, maxlen int) (int, error) {
+ // N.B. s.x86FPState may contain more state than the application
+ // expects. We only copy the subset that would be in their XSAVE area.
+ ess, _ := s.FeatureSet.ExtendedStateSize()
+ f := make([]byte, ess)
+ copy(f, s.x86FPState)
+ // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are
+ // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE
+ // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE
+ // mask. GDB relies on this: see
+ // gdb/x86-linux-nat.c:x86_linux_read_description().
+ usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], s.FeatureSet.ValidXCR0Mask())
+ if len(f) > maxlen {
+ f = f[:maxlen]
+ }
+ return dst.Write(f)
+}
+
+func (s *State) ptraceSetXstateRegs(src io.Reader, maxlen int) (int, error) {
+ // Allow users to pass an xstate register set smaller than ours (they can
+ // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes.
+ // Also allow users to pass a register set larger than ours; anything after
+ // their ExtendedStateSize will be ignored. (I think Linux technically
+ // permits setting a register set smaller than minXstateBytes, but it has
+ // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().)
+ if maxlen < minXstateBytes {
+ return 0, syscall.EFAULT
+ }
+ ess, _ := s.FeatureSet.ExtendedStateSize()
+ if maxlen > int(ess) {
+ maxlen = int(ess)
+ }
+ f := make([]byte, maxlen)
+ if _, err := io.ReadFull(src, f); err != nil {
+ return 0, err
+ }
+ // Force reserved bits in MXCSR to 0. This is consistent with Linux.
+ sanitizeMXCSR(x86FPState(f))
+ // Users can't enable *more* XCR0 bits than what we, and the CPU, support.
+ xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:])
+ xstateBV &= s.FeatureSet.ValidXCR0Mask()
+ usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV)
+ // Force XCOMP_BV and reserved bytes in the XSAVE header to 0.
+ reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes]
+ for i := range reserved {
+ reserved[i] = 0
+ }
+ return copy(s.x86FPState, f), nil
+}
+
+// Register sets defined in include/uapi/linux/elf.h.
+const (
+ _NT_PRSTATUS = 1
+ _NT_PRFPREG = 2
+ _NT_X86_XSTATE = 0x202
+)
+
+// PtraceGetRegSet implements Context.PtraceGetRegSet.
+func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < ptraceRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceGetRegs(dst)
+ case _NT_PRFPREG:
+ if maxlen < ptraceFPRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceGetFPRegs(dst)
+ case _NT_X86_XSTATE:
+ return s.ptraceGetXstateRegs(dst, maxlen)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// PtraceSetRegSet implements Context.PtraceSetRegSet.
+func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < ptraceRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceSetRegs(src)
+ case _NT_PRFPREG:
+ if maxlen < ptraceFPRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceSetFPRegs(src)
+ case _NT_X86_XSTATE:
+ return s.ptraceSetXstateRegs(src, maxlen)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// FullRestore indicates whether a full restore is required.
+func (s *State) FullRestore() bool {
+ // A fast system call return is possible only if
+ //
+ // * RCX matches the instruction pointer.
+ // * R11 matches our flags value.
+ // * Usermode does not expect to set either the resume flag or the
+ // virtual mode flags (unlikely.)
+ // * CS and SS are set to the standard selectors.
+ //
+ // That is, SYSRET results in the correct final state.
+ fastRestore := s.Regs.Rcx == s.Regs.Rip &&
+ s.Regs.Eflags == s.Regs.R11 &&
+ (s.Regs.Eflags&eflagsRF == 0) &&
+ (s.Regs.Eflags&eflagsVM == 0) &&
+ s.Regs.Cs == userCS &&
+ s.Regs.Ss == userDS
+ return !fastRestore
+}
+
+// New returns a new architecture context.
+func New(arch Arch, fs *cpuid.FeatureSet) Context {
+ switch arch {
+ case AMD64:
+ return &context64{
+ State{
+ x86FPState: newX86FPState(),
+ FeatureSet: fs,
+ },
+ []x86FPState(nil),
+ }
+ }
+ panic(fmt.Sprintf("unknown architecture %v", arch))
+}
diff --git a/pkg/sentry/arch/auxv.go b/pkg/sentry/arch/auxv.go
new file mode 100644
index 000000000..80c923103
--- /dev/null
+++ b/pkg/sentry/arch/auxv.go
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// An AuxEntry represents an entry in an ELF auxiliary vector.
+//
+// +stateify savable
+type AuxEntry struct {
+ Key uint64
+ Value usermem.Addr
+}
+
+// An Auxv represents an ELF auxiliary vector.
+type Auxv []AuxEntry
diff --git a/pkg/sentry/arch/registers_go_proto/registers.pb.go b/pkg/sentry/arch/registers_go_proto/registers.pb.go
new file mode 100755
index 000000000..088209be7
--- /dev/null
+++ b/pkg/sentry/arch/registers_go_proto/registers.pb.go
@@ -0,0 +1,367 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: pkg/sentry/arch/registers.proto
+
+package gvisor
+
+import (
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type AMD64Registers struct {
+ Rax uint64 `protobuf:"varint,1,opt,name=rax,proto3" json:"rax,omitempty"`
+ Rbx uint64 `protobuf:"varint,2,opt,name=rbx,proto3" json:"rbx,omitempty"`
+ Rcx uint64 `protobuf:"varint,3,opt,name=rcx,proto3" json:"rcx,omitempty"`
+ Rdx uint64 `protobuf:"varint,4,opt,name=rdx,proto3" json:"rdx,omitempty"`
+ Rsi uint64 `protobuf:"varint,5,opt,name=rsi,proto3" json:"rsi,omitempty"`
+ Rdi uint64 `protobuf:"varint,6,opt,name=rdi,proto3" json:"rdi,omitempty"`
+ Rsp uint64 `protobuf:"varint,7,opt,name=rsp,proto3" json:"rsp,omitempty"`
+ Rbp uint64 `protobuf:"varint,8,opt,name=rbp,proto3" json:"rbp,omitempty"`
+ R8 uint64 `protobuf:"varint,9,opt,name=r8,proto3" json:"r8,omitempty"`
+ R9 uint64 `protobuf:"varint,10,opt,name=r9,proto3" json:"r9,omitempty"`
+ R10 uint64 `protobuf:"varint,11,opt,name=r10,proto3" json:"r10,omitempty"`
+ R11 uint64 `protobuf:"varint,12,opt,name=r11,proto3" json:"r11,omitempty"`
+ R12 uint64 `protobuf:"varint,13,opt,name=r12,proto3" json:"r12,omitempty"`
+ R13 uint64 `protobuf:"varint,14,opt,name=r13,proto3" json:"r13,omitempty"`
+ R14 uint64 `protobuf:"varint,15,opt,name=r14,proto3" json:"r14,omitempty"`
+ R15 uint64 `protobuf:"varint,16,opt,name=r15,proto3" json:"r15,omitempty"`
+ Rip uint64 `protobuf:"varint,17,opt,name=rip,proto3" json:"rip,omitempty"`
+ Rflags uint64 `protobuf:"varint,18,opt,name=rflags,proto3" json:"rflags,omitempty"`
+ OrigRax uint64 `protobuf:"varint,19,opt,name=orig_rax,json=origRax,proto3" json:"orig_rax,omitempty"`
+ Cs uint64 `protobuf:"varint,20,opt,name=cs,proto3" json:"cs,omitempty"`
+ Ds uint64 `protobuf:"varint,21,opt,name=ds,proto3" json:"ds,omitempty"`
+ Es uint64 `protobuf:"varint,22,opt,name=es,proto3" json:"es,omitempty"`
+ Fs uint64 `protobuf:"varint,23,opt,name=fs,proto3" json:"fs,omitempty"`
+ Gs uint64 `protobuf:"varint,24,opt,name=gs,proto3" json:"gs,omitempty"`
+ Ss uint64 `protobuf:"varint,25,opt,name=ss,proto3" json:"ss,omitempty"`
+ FsBase uint64 `protobuf:"varint,26,opt,name=fs_base,json=fsBase,proto3" json:"fs_base,omitempty"`
+ GsBase uint64 `protobuf:"varint,27,opt,name=gs_base,json=gsBase,proto3" json:"gs_base,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *AMD64Registers) Reset() { *m = AMD64Registers{} }
+func (m *AMD64Registers) String() string { return proto.CompactTextString(m) }
+func (*AMD64Registers) ProtoMessage() {}
+func (*AMD64Registers) Descriptor() ([]byte, []int) {
+ return fileDescriptor_082b7510610e0457, []int{0}
+}
+
+func (m *AMD64Registers) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_AMD64Registers.Unmarshal(m, b)
+}
+func (m *AMD64Registers) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_AMD64Registers.Marshal(b, m, deterministic)
+}
+func (m *AMD64Registers) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_AMD64Registers.Merge(m, src)
+}
+func (m *AMD64Registers) XXX_Size() int {
+ return xxx_messageInfo_AMD64Registers.Size(m)
+}
+func (m *AMD64Registers) XXX_DiscardUnknown() {
+ xxx_messageInfo_AMD64Registers.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_AMD64Registers proto.InternalMessageInfo
+
+func (m *AMD64Registers) GetRax() uint64 {
+ if m != nil {
+ return m.Rax
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetRbx() uint64 {
+ if m != nil {
+ return m.Rbx
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetRcx() uint64 {
+ if m != nil {
+ return m.Rcx
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetRdx() uint64 {
+ if m != nil {
+ return m.Rdx
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetRsi() uint64 {
+ if m != nil {
+ return m.Rsi
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetRdi() uint64 {
+ if m != nil {
+ return m.Rdi
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetRsp() uint64 {
+ if m != nil {
+ return m.Rsp
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetRbp() uint64 {
+ if m != nil {
+ return m.Rbp
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetR8() uint64 {
+ if m != nil {
+ return m.R8
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetR9() uint64 {
+ if m != nil {
+ return m.R9
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetR10() uint64 {
+ if m != nil {
+ return m.R10
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetR11() uint64 {
+ if m != nil {
+ return m.R11
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetR12() uint64 {
+ if m != nil {
+ return m.R12
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetR13() uint64 {
+ if m != nil {
+ return m.R13
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetR14() uint64 {
+ if m != nil {
+ return m.R14
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetR15() uint64 {
+ if m != nil {
+ return m.R15
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetRip() uint64 {
+ if m != nil {
+ return m.Rip
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetRflags() uint64 {
+ if m != nil {
+ return m.Rflags
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetOrigRax() uint64 {
+ if m != nil {
+ return m.OrigRax
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetCs() uint64 {
+ if m != nil {
+ return m.Cs
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetDs() uint64 {
+ if m != nil {
+ return m.Ds
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetEs() uint64 {
+ if m != nil {
+ return m.Es
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetFs() uint64 {
+ if m != nil {
+ return m.Fs
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetGs() uint64 {
+ if m != nil {
+ return m.Gs
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetSs() uint64 {
+ if m != nil {
+ return m.Ss
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetFsBase() uint64 {
+ if m != nil {
+ return m.FsBase
+ }
+ return 0
+}
+
+func (m *AMD64Registers) GetGsBase() uint64 {
+ if m != nil {
+ return m.GsBase
+ }
+ return 0
+}
+
+type Registers struct {
+ // Types that are valid to be assigned to Arch:
+ // *Registers_Amd64
+ Arch isRegisters_Arch `protobuf_oneof:"arch"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Registers) Reset() { *m = Registers{} }
+func (m *Registers) String() string { return proto.CompactTextString(m) }
+func (*Registers) ProtoMessage() {}
+func (*Registers) Descriptor() ([]byte, []int) {
+ return fileDescriptor_082b7510610e0457, []int{1}
+}
+
+func (m *Registers) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Registers.Unmarshal(m, b)
+}
+func (m *Registers) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Registers.Marshal(b, m, deterministic)
+}
+func (m *Registers) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Registers.Merge(m, src)
+}
+func (m *Registers) XXX_Size() int {
+ return xxx_messageInfo_Registers.Size(m)
+}
+func (m *Registers) XXX_DiscardUnknown() {
+ xxx_messageInfo_Registers.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Registers proto.InternalMessageInfo
+
+type isRegisters_Arch interface {
+ isRegisters_Arch()
+}
+
+type Registers_Amd64 struct {
+ Amd64 *AMD64Registers `protobuf:"bytes,1,opt,name=amd64,proto3,oneof"`
+}
+
+func (*Registers_Amd64) isRegisters_Arch() {}
+
+func (m *Registers) GetArch() isRegisters_Arch {
+ if m != nil {
+ return m.Arch
+ }
+ return nil
+}
+
+func (m *Registers) GetAmd64() *AMD64Registers {
+ if x, ok := m.GetArch().(*Registers_Amd64); ok {
+ return x.Amd64
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*Registers) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*Registers_Amd64)(nil),
+ }
+}
+
+func init() {
+ proto.RegisterType((*AMD64Registers)(nil), "gvisor.AMD64Registers")
+ proto.RegisterType((*Registers)(nil), "gvisor.Registers")
+}
+
+func init() { proto.RegisterFile("pkg/sentry/arch/registers.proto", fileDescriptor_082b7510610e0457) }
+
+var fileDescriptor_082b7510610e0457 = []byte{
+ // 354 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x54, 0x92, 0x4d, 0x6f, 0xe2, 0x30,
+ 0x10, 0x86, 0x17, 0x08, 0x01, 0xcc, 0x2e, 0xcb, 0x66, 0x5b, 0x18, 0xda, 0x43, 0x2b, 0x4e, 0x3d,
+ 0x05, 0x02, 0x01, 0xc1, 0xb1, 0xb4, 0x87, 0x5e, 0x7a, 0xc9, 0x1f, 0x40, 0xf9, 0x74, 0xad, 0x7e,
+ 0xc4, 0xf2, 0xa0, 0x2a, 0x3d, 0xf7, 0x8f, 0x57, 0xf6, 0xd8, 0xaa, 0x7a, 0xcb, 0xf3, 0xcc, 0x6b,
+ 0x39, 0x93, 0xbc, 0xec, 0x4a, 0x3e, 0xf3, 0x05, 0x96, 0x6f, 0x27, 0xf5, 0xb1, 0x48, 0x55, 0xfe,
+ 0xb4, 0x50, 0x25, 0x17, 0x78, 0x2a, 0x15, 0x86, 0x52, 0xd5, 0xa7, 0x3a, 0xf0, 0xf9, 0xbb, 0xc0,
+ 0x5a, 0xcd, 0x3f, 0x3d, 0x36, 0xba, 0x7d, 0xbc, 0xdf, 0xc6, 0x89, 0x0b, 0x04, 0x63, 0xd6, 0x51,
+ 0x69, 0x03, 0xad, 0xeb, 0xd6, 0x8d, 0x97, 0xe8, 0x47, 0x63, 0xb2, 0x06, 0xda, 0xd6, 0x64, 0x64,
+ 0xf2, 0x06, 0x3a, 0xd6, 0xe4, 0x64, 0x8a, 0x06, 0x3c, 0x6b, 0x0a, 0x32, 0x28, 0xa0, 0x6b, 0x0d,
+ 0x0a, 0xca, 0x08, 0xf0, 0x5d, 0x86, 0x0c, 0x4a, 0xe8, 0xb9, 0x8c, 0xa4, 0xbb, 0x24, 0xf4, 0xdd,
+ 0x5d, 0x32, 0x18, 0xb1, 0xb6, 0xda, 0xc1, 0xc0, 0x88, 0xb6, 0xda, 0x19, 0xde, 0x03, 0xb3, 0xbc,
+ 0x37, 0x27, 0xa2, 0x25, 0x0c, 0xed, 0x89, 0x68, 0x49, 0x26, 0x82, 0xdf, 0xce, 0x44, 0x64, 0x56,
+ 0xf0, 0xc7, 0x99, 0x15, 0x99, 0x35, 0x8c, 0x9c, 0x59, 0x93, 0x89, 0xe1, 0xaf, 0x33, 0x31, 0x99,
+ 0x0d, 0x8c, 0x9d, 0xd9, 0x18, 0x23, 0x24, 0xfc, 0xb3, 0x46, 0xc8, 0x60, 0xc2, 0x7c, 0x55, 0xbd,
+ 0xa4, 0x1c, 0x21, 0x30, 0xd2, 0x52, 0x30, 0x63, 0xfd, 0x5a, 0x09, 0x7e, 0xd4, 0x9f, 0xf2, 0xbf,
+ 0x99, 0xf4, 0x34, 0x27, 0x69, 0xa3, 0x17, 0xc8, 0x11, 0xce, 0x68, 0x81, 0x1c, 0x35, 0x17, 0x08,
+ 0xe7, 0xc4, 0x85, 0xe1, 0x12, 0x61, 0x42, 0x5c, 0x1a, 0xae, 0x10, 0xa6, 0xc4, 0x95, 0x61, 0x8e,
+ 0x00, 0xc4, 0xdc, 0x30, 0x22, 0xcc, 0x88, 0x11, 0x83, 0x29, 0xeb, 0x55, 0x78, 0xcc, 0x52, 0x2c,
+ 0xe1, 0x82, 0xde, 0xa9, 0xc2, 0x43, 0x8a, 0xa5, 0x1e, 0x70, 0x3b, 0xb8, 0xa4, 0x01, 0x37, 0x83,
+ 0xf9, 0x1d, 0x1b, 0x7c, 0xff, 0xff, 0x90, 0x75, 0xd3, 0xd7, 0x62, 0x1b, 0x9b, 0x06, 0x0c, 0x57,
+ 0x93, 0x90, 0xaa, 0x12, 0xfe, 0xac, 0xc9, 0xc3, 0xaf, 0x84, 0x62, 0x07, 0x9f, 0x79, 0xba, 0x62,
+ 0x99, 0x6f, 0x9a, 0xb5, 0xfe, 0x0a, 0x00, 0x00, 0xff, 0xff, 0xb4, 0xcc, 0x03, 0x27, 0x7c, 0x02,
+ 0x00, 0x00,
+}
diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go
new file mode 100644
index 000000000..f9ca2e74e
--- /dev/null
+++ b/pkg/sentry/arch/signal_act.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+// Special values for SignalAct.Handler.
+const (
+ // SignalActDefault is SIG_DFL and specifies that the default behavior for
+ // a signal should be taken.
+ SignalActDefault = 0
+
+ // SignalActIgnore is SIG_IGN and specifies that a signal should be
+ // ignored.
+ SignalActIgnore = 1
+)
+
+// Available signal flags.
+const (
+ SignalFlagNoCldStop = 0x00000001
+ SignalFlagNoCldWait = 0x00000002
+ SignalFlagSigInfo = 0x00000004
+ SignalFlagRestorer = 0x04000000
+ SignalFlagOnStack = 0x08000000
+ SignalFlagRestart = 0x10000000
+ SignalFlagInterrupt = 0x20000000
+ SignalFlagNoDefer = 0x40000000
+ SignalFlagResetHandler = 0x80000000
+)
+
+// IsSigInfo returns true iff this handle expects siginfo.
+func (s SignalAct) IsSigInfo() bool {
+ return s.Flags&SignalFlagSigInfo != 0
+}
+
+// IsNoDefer returns true iff this SignalAct has the NoDefer flag set.
+func (s SignalAct) IsNoDefer() bool {
+ return s.Flags&SignalFlagNoDefer != 0
+}
+
+// IsRestart returns true iff this SignalAct has the Restart flag set.
+func (s SignalAct) IsRestart() bool {
+ return s.Flags&SignalFlagRestart != 0
+}
+
+// IsResetHandler returns true iff this SignalAct has the ResetHandler flag set.
+func (s SignalAct) IsResetHandler() bool {
+ return s.Flags&SignalFlagResetHandler != 0
+}
+
+// IsOnStack returns true iff this SignalAct has the OnStack flag set.
+func (s SignalAct) IsOnStack() bool {
+ return s.Flags&SignalFlagOnStack != 0
+}
+
+// HasRestorer returns true iff this SignalAct has the Restorer flag set.
+func (s SignalAct) HasRestorer() bool {
+ return s.Flags&SignalFlagRestorer != 0
+}
+
+// NativeSignalAct is a type that is equivalent to struct sigaction in the
+// guest architecture.
+type NativeSignalAct interface {
+ // SerializeFrom copies the data in the host SignalAct s into this object.
+ SerializeFrom(s *SignalAct)
+
+ // DeserializeTo copies the data in this object into the host SignalAct s.
+ DeserializeTo(s *SignalAct)
+}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
new file mode 100644
index 000000000..aa030fd70
--- /dev/null
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -0,0 +1,521 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package arch
+
+import (
+ "encoding/binary"
+ "math"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// SignalAct represents the action that should be taken when a signal is
+// delivered, and is equivalent to struct sigaction on 64-bit x86.
+//
+// +stateify savable
+type SignalAct struct {
+ Handler uint64
+ Flags uint64
+ Restorer uint64
+ Mask linux.SignalSet
+}
+
+// SerializeFrom implements NativeSignalAct.SerializeFrom.
+func (s *SignalAct) SerializeFrom(other *SignalAct) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalAct.DeserializeTo.
+func (s *SignalAct) DeserializeTo(other *SignalAct) {
+ *other = *s
+}
+
+// SignalStack represents information about a user stack, and is equivalent to
+// stack_t on 64-bit x86.
+//
+// +stateify savable
+type SignalStack struct {
+ Addr uint64
+ Flags uint32
+ _ uint32
+ Size uint64
+}
+
+// SerializeFrom implements NativeSignalStack.SerializeFrom.
+func (s *SignalStack) SerializeFrom(other *SignalStack) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalStack.DeserializeTo.
+func (s *SignalStack) DeserializeTo(other *SignalStack) {
+ *other = *s
+}
+
+// SignalInfo represents information about a signal being delivered, and is
+// equivalent to struct siginfo on 64-bit x86.
+//
+// +stateify savable
+type SignalInfo struct {
+ Signo int32 // Signal number
+ Errno int32 // Errno value
+ Code int32 // Signal code
+ _ uint32
+
+ // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
+ // are accessed through methods.
+ //
+ // For reference, here is the definition of _sifields: (_sigfault._trapno,
+ // which does not exist on x86, omitted for clarity)
+ //
+ // union {
+ // int _pad[SI_PAD_SIZE];
+ //
+ // /* kill() */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // } _kill;
+ //
+ // /* POSIX.1b timers */
+ // struct {
+ // __kernel_timer_t _tid; /* timer id */
+ // int _overrun; /* overrun count */
+ // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
+ // sigval_t _sigval; /* same as below */
+ // int _sys_private; /* not to be passed to user */
+ // } _timer;
+ //
+ // /* POSIX.1b signals */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // sigval_t _sigval;
+ // } _rt;
+ //
+ // /* SIGCHLD */
+ // struct {
+ // __kernel_pid_t _pid; /* which child */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // int _status; /* exit code */
+ // __ARCH_SI_CLOCK_T _utime;
+ // __ARCH_SI_CLOCK_T _stime;
+ // } _sigchld;
+ //
+ // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
+ // struct {
+ // void *_addr; /* faulting insn/memory ref. */
+ // short _addr_lsb; /* LSB of the reported address */
+ // } _sigfault;
+ //
+ // /* SIGPOLL */
+ // struct {
+ // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
+ // int _fd;
+ // } _sigpoll;
+ //
+ // /* SIGSYS */
+ // struct {
+ // void *_call_addr; /* calling user insn */
+ // int _syscall; /* triggering system call number */
+ // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
+ // } _sigsys;
+ // } _sifields;
+ //
+ // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
+ // bytes.
+ Fields [128 - 16]byte
+}
+
+// FixSignalCodeForUser fixes up si_code.
+//
+// The si_code we get from Linux may contain the kernel-specific code in the
+// top 16 bits if it's positive (e.g., from ptrace). Linux's
+// copy_siginfo_to_user does
+// err |= __put_user((short)from->si_code, &to->si_code);
+// to mask out those bits and we need to do the same.
+func (s *SignalInfo) FixSignalCodeForUser() {
+ if s.Code > 0 {
+ s.Code &= 0x0000ffff
+ }
+}
+
+// Pid returns the si_pid field.
+func (s *SignalInfo) Pid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetPid mutates the si_pid field.
+func (s *SignalInfo) SetPid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Uid returns the si_uid field.
+func (s *SignalInfo) Uid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetUid mutates the si_uid field.
+func (s *SignalInfo) SetUid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
+func (s *SignalInfo) Sigval() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[8:16])
+}
+
+// SetSigval mutates the sigval field.
+func (s *SignalInfo) SetSigval(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[8:16], val)
+}
+
+// TimerID returns the si_timerid field.
+func (s *SignalInfo) TimerID() linux.TimerID {
+ return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetTimerID sets the si_timerid field.
+func (s *SignalInfo) SetTimerID(val linux.TimerID) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Overrun returns the si_overrun field.
+func (s *SignalInfo) Overrun() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetOverrun sets the si_overrun field.
+func (s *SignalInfo) SetOverrun(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Addr returns the si_addr field.
+func (s *SignalInfo) Addr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetAddr sets the si_addr field.
+func (s *SignalInfo) SetAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Status returns the si_status field.
+func (s *SignalInfo) Status() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetStatus mutates the si_status field.
+func (s *SignalInfo) SetStatus(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// CallAddr returns the si_call_addr field.
+func (s *SignalInfo) CallAddr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetCallAddr mutates the si_call_addr field.
+func (s *SignalInfo) SetCallAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Syscall returns the si_syscall field.
+func (s *SignalInfo) Syscall() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetSyscall mutates the si_syscall field.
+func (s *SignalInfo) SetSyscall(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// Arch returns the si_arch field.
+func (s *SignalInfo) Arch() uint32 {
+ return usermem.ByteOrder.Uint32(s.Fields[12:16])
+}
+
+// SetArch mutates the si_arch field.
+func (s *SignalInfo) SetArch(val uint32) {
+ usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
+}
+
+// SignalContext64 is equivalent to struct sigcontext, the type passed as the
+// second argument to signal handlers set by signal(2).
+type SignalContext64 struct {
+ R8 uint64
+ R9 uint64
+ R10 uint64
+ R11 uint64
+ R12 uint64
+ R13 uint64
+ R14 uint64
+ R15 uint64
+ Rdi uint64
+ Rsi uint64
+ Rbp uint64
+ Rbx uint64
+ Rdx uint64
+ Rax uint64
+ Rcx uint64
+ Rsp uint64
+ Rip uint64
+ Eflags uint64
+ Cs uint16
+ Gs uint16 // always 0 on amd64.
+ Fs uint16 // always 0 on amd64.
+ Ss uint16 // only restored if _UC_STRICT_RESTORE_SS (unsupported).
+ Err uint64
+ Trapno uint64
+ Oldmask linux.SignalSet
+ Cr2 uint64
+ // Pointer to a struct _fpstate.
+ Fpstate uint64
+ Reserved [8]uint64
+}
+
+// Flags for UContext64.Flags.
+const (
+ _UC_FP_XSTATE = 1
+ _UC_SIGCONTEXT_SS = 2
+ _UC_STRICT_RESTORE_SS = 4
+)
+
+// UContext64 is equivalent to ucontext_t on 64-bit x86.
+type UContext64 struct {
+ Flags uint64
+ Link uint64
+ Stack SignalStack
+ MContext SignalContext64
+ Sigset linux.SignalSet
+}
+
+// NewSignalAct implements Context.NewSignalAct.
+func (c *context64) NewSignalAct() NativeSignalAct {
+ return &SignalAct{}
+}
+
+// NewSignalStack implements Context.NewSignalStack.
+func (c *context64) NewSignalStack() NativeSignalStack {
+ return &SignalStack{}
+}
+
+// From Linux 'arch/x86/include/uapi/asm/sigcontext.h' the following is the
+// size of the magic cookie at the end of the xsave frame.
+//
+// NOTE(b/33003106#comment11): Currently we don't actually populate the fpstate
+// on the signal stack.
+const _FP_XSTATE_MAGIC2_SIZE = 4
+
+func (c *context64) fpuFrameSize() (size int, useXsave bool) {
+ size = len(c.x86FPState)
+ if size > 512 {
+ // Make room for the magic cookie at the end of the xsave frame.
+ size += _FP_XSTATE_MAGIC2_SIZE
+ useXsave = true
+ }
+ return size, useXsave
+}
+
+// SignalSetup implements Context.SignalSetup. (Compare to Linux's
+// arch/x86/kernel/signal.c:__setup_rt_frame().)
+func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+ sp := st.Bottom
+
+ // "The 128-byte area beyond the location pointed to by %rsp is considered
+ // to be reserved and shall not be modified by signal or interrupt
+ // handlers. ... leaf functions may use this area for their entire stack
+ // frame, rather than adjusting the stack pointer in the prologue and
+ // epilogue." - AMD64 ABI
+ //
+ // (But this doesn't apply if we're starting at the top of the signal
+ // stack, in which case there is no following stack frame.)
+ if !(alt.IsEnabled() && sp == alt.Top()) {
+ sp -= 128
+ }
+
+ // Allocate space for floating point state on the stack.
+ //
+ // This isn't strictly necessary because we don't actually populate
+ // the fpstate. However we do store the floating point state of the
+ // interrupted thread inside the sentry. Simply accounting for this
+ // space on the user stack naturally caps the amount of memory the
+ // sentry will allocate for this purpose.
+ fpSize, _ := c.fpuFrameSize()
+ sp = (sp - usermem.Addr(fpSize)) & ^usermem.Addr(63)
+
+ // Construct the UContext64 now since we need its size.
+ uc := &UContext64{
+ // No _UC_FP_XSTATE: see Fpstate above.
+ // No _UC_STRICT_RESTORE_SS: we don't allow SS changes.
+ Flags: _UC_SIGCONTEXT_SS,
+ Stack: *alt,
+ MContext: SignalContext64{
+ R8: c.Regs.R8,
+ R9: c.Regs.R9,
+ R10: c.Regs.R10,
+ R11: c.Regs.R11,
+ R12: c.Regs.R12,
+ R13: c.Regs.R13,
+ R14: c.Regs.R14,
+ R15: c.Regs.R15,
+ Rdi: c.Regs.Rdi,
+ Rsi: c.Regs.Rsi,
+ Rbp: c.Regs.Rbp,
+ Rbx: c.Regs.Rbx,
+ Rdx: c.Regs.Rdx,
+ Rax: c.Regs.Rax,
+ Rcx: c.Regs.Rcx,
+ Rsp: c.Regs.Rsp,
+ Rip: c.Regs.Rip,
+ Eflags: c.Regs.Eflags,
+ Cs: uint16(c.Regs.Cs),
+ Ss: uint16(c.Regs.Ss),
+ Oldmask: sigset,
+ },
+ Sigset: sigset,
+ }
+
+ // TODO(gvisor.dev/issue/159): Set SignalContext64.Err, Trapno, and Cr2
+ // based on the fault that caused the signal. For now, leave Err and
+ // Trapno unset and assume CR2 == info.Addr() for SIGSEGVs and
+ // SIGBUSes.
+ if linux.Signal(info.Signo) == linux.SIGSEGV || linux.Signal(info.Signo) == linux.SIGBUS {
+ uc.MContext.Cr2 = info.Addr()
+ }
+
+ // "... the value (%rsp+8) is always a multiple of 16 (...) when
+ // control is transferred to the function entry point." - AMD64 ABI
+ ucSize := binary.Size(uc)
+ if ucSize < 0 {
+ // This can only happen if we've screwed up the definition of
+ // UContext64.
+ panic("can't get size of UContext64")
+ }
+ // st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128.
+ frameSize := int(st.Arch.Width()) + ucSize + 128
+ frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8
+ sp = frameBottom + usermem.Addr(frameSize)
+ st.Bottom = sp
+
+ // Prior to proceeding, figure out if the frame will exhaust the range
+ // for the signal stack. This is not allowed, and should immediately
+ // force signal delivery (reverting to the default handler).
+ if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ return syscall.EFAULT
+ }
+
+ // Adjust the code.
+ info.FixSignalCodeForUser()
+
+ // Set up the stack frame.
+ infoAddr, err := st.Push(info)
+ if err != nil {
+ return err
+ }
+ ucAddr, err := st.Push(uc)
+ if err != nil {
+ return err
+ }
+ if act.HasRestorer() {
+ // Push the restorer return address.
+ // Note that this doesn't need to be popped.
+ if _, err := st.Push(usermem.Addr(act.Restorer)); err != nil {
+ return err
+ }
+ } else {
+ // amd64 requires a restorer.
+ return syscall.EFAULT
+ }
+
+ // Set up registers.
+ c.Regs.Rip = act.Handler
+ c.Regs.Rsp = uint64(st.Bottom)
+ c.Regs.Rdi = uint64(info.Signo)
+ c.Regs.Rsi = uint64(infoAddr)
+ c.Regs.Rdx = uint64(ucAddr)
+ c.Regs.Rax = 0
+ c.Regs.Ds = userDS
+ c.Regs.Es = userDS
+ c.Regs.Cs = userCS
+ c.Regs.Ss = userDS
+
+ // Save the thread's floating point state.
+ c.sigFPState = append(c.sigFPState, c.x86FPState)
+
+ // Signal handler gets a clean floating point state.
+ c.x86FPState = newX86FPState()
+
+ return nil
+}
+
+// SignalRestore implements Context.SignalRestore. (Compare to Linux's
+// arch/x86/kernel/signal.c:sys_rt_sigreturn().)
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+ // Copy out the stack frame.
+ var uc UContext64
+ if _, err := st.Pop(&uc); err != nil {
+ return 0, SignalStack{}, err
+ }
+ var info SignalInfo
+ if _, err := st.Pop(&info); err != nil {
+ return 0, SignalStack{}, err
+ }
+
+ // Restore registers.
+ c.Regs.R8 = uc.MContext.R8
+ c.Regs.R9 = uc.MContext.R9
+ c.Regs.R10 = uc.MContext.R10
+ c.Regs.R11 = uc.MContext.R11
+ c.Regs.R12 = uc.MContext.R12
+ c.Regs.R13 = uc.MContext.R13
+ c.Regs.R14 = uc.MContext.R14
+ c.Regs.R15 = uc.MContext.R15
+ c.Regs.Rdi = uc.MContext.Rdi
+ c.Regs.Rsi = uc.MContext.Rsi
+ c.Regs.Rbp = uc.MContext.Rbp
+ c.Regs.Rbx = uc.MContext.Rbx
+ c.Regs.Rdx = uc.MContext.Rdx
+ c.Regs.Rax = uc.MContext.Rax
+ c.Regs.Rcx = uc.MContext.Rcx
+ c.Regs.Rsp = uc.MContext.Rsp
+ c.Regs.Rip = uc.MContext.Rip
+ c.Regs.Eflags = (c.Regs.Eflags & ^eflagsRestorable) | (uc.MContext.Eflags & eflagsRestorable)
+ c.Regs.Cs = uint64(uc.MContext.Cs) | 3
+ // N.B. _UC_STRICT_RESTORE_SS not supported.
+ c.Regs.Orig_rax = math.MaxUint64
+
+ // Restore floating point state.
+ l := len(c.sigFPState)
+ if l > 0 {
+ c.x86FPState = c.sigFPState[l-1]
+ // NOTE(cl/133042258): State save requires that any slice
+ // elements from '[len:cap]' to be zero value.
+ c.sigFPState[l-1] = nil
+ c.sigFPState = c.sigFPState[0 : l-1]
+ } else {
+ // This might happen if sigreturn(2) calls are unbalanced with
+ // respect to signal handler entries. This is not expected so
+ // don't bother to do anything fancy with the floating point
+ // state.
+ log.Infof("sigreturn unable to restore application fpstate")
+ }
+
+ return uc.Sigset, uc.Stack, nil
+}
diff --git a/pkg/sentry/arch/signal_info.go b/pkg/sentry/arch/signal_info.go
new file mode 100644
index 000000000..f93ee8b46
--- /dev/null
+++ b/pkg/sentry/arch/signal_info.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+// Possible values for SignalInfo.Code. These values originate from the Linux
+// kernel's include/uapi/asm-generic/siginfo.h.
+const (
+ // SignalInfoUser (properly SI_USER) indicates that a signal was sent from
+ // a kill() or raise() syscall.
+ SignalInfoUser = 0
+
+ // SignalInfoKernel (properly SI_KERNEL) indicates that the signal was sent
+ // by the kernel.
+ SignalInfoKernel = 0x80
+
+ // SignalInfoTimer (properly SI_TIMER) indicates that the signal was sent
+ // by an expired timer.
+ SignalInfoTimer = -2
+
+ // SignalInfoTkill (properly SI_TKILL) indicates that the signal was sent
+ // from a tkill() or tgkill() syscall.
+ SignalInfoTkill = -6
+
+ // CLD_* codes are only meaningful for SIGCHLD.
+
+ // CLD_EXITED indicates that a task exited.
+ CLD_EXITED = 1
+
+ // CLD_KILLED indicates that a task was killed by a signal.
+ CLD_KILLED = 2
+
+ // CLD_DUMPED indicates that a task was killed by a signal and then dumped
+ // core.
+ CLD_DUMPED = 3
+
+ // CLD_TRAPPED indicates that a task was stopped by ptrace.
+ CLD_TRAPPED = 4
+
+ // CLD_STOPPED indicates that a thread group completed a group stop.
+ CLD_STOPPED = 5
+
+ // CLD_CONTINUED indicates that a group-stopped thread group was continued.
+ CLD_CONTINUED = 6
+
+ // SYS_* codes are only meaningful for SIGSYS.
+
+ // SYS_SECCOMP indicates that a signal originates from seccomp.
+ SYS_SECCOMP = 1
+
+ // TRAP_* codes are only meaningful for SIGTRAP.
+
+ // TRAP_BRKPT indicates a breakpoint trap.
+ TRAP_BRKPT = 1
+)
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
new file mode 100644
index 000000000..a442f9fdc
--- /dev/null
+++ b/pkg/sentry/arch/signal_stack.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build i386 amd64
+
+package arch
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+const (
+ // SignalStackFlagOnStack is possible set on return from getaltstack,
+ // in order to indicate that the thread is currently on the alt stack.
+ SignalStackFlagOnStack = 1
+
+ // SignalStackFlagDisable is a flag to indicate the stack is disabled.
+ SignalStackFlagDisable = 2
+)
+
+// IsEnabled returns true iff this signal stack is marked as enabled.
+func (s SignalStack) IsEnabled() bool {
+ return s.Flags&SignalStackFlagDisable == 0
+}
+
+// Top returns the stack's top address.
+func (s SignalStack) Top() usermem.Addr {
+ return usermem.Addr(s.Addr + s.Size)
+}
+
+// SetOnStack marks this signal stack as in use.
+//
+// Note that there is no corresponding ClearOnStack, and that this should only
+// be called on copies that are serialized to userspace.
+func (s *SignalStack) SetOnStack() {
+ s.Flags |= SignalStackFlagOnStack
+}
+
+// Contains checks if the stack pointer is within this stack.
+func (s *SignalStack) Contains(sp usermem.Addr) bool {
+ return usermem.Addr(s.Addr) < sp && sp <= usermem.Addr(s.Addr+s.Size)
+}
+
+// NativeSignalStack is a type that is equivalent to stack_t in the guest
+// architecture.
+type NativeSignalStack interface {
+ // SerializeFrom copies the data in the host SignalStack s into this
+ // object.
+ SerializeFrom(s *SignalStack)
+
+ // DeserializeTo copies the data in this object into the host SignalStack
+ // s.
+ DeserializeTo(s *SignalStack)
+}
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
new file mode 100644
index 000000000..7e6324e82
--- /dev/null
+++ b/pkg/sentry/arch/stack.go
@@ -0,0 +1,252 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "encoding/binary"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// Stack is a simple wrapper around a usermem.IO and an address.
+type Stack struct {
+ // Our arch info.
+ // We use this for automatic Native conversion of usermem.Addrs during
+ // Push() and Pop().
+ Arch Context
+
+ // The interface used to actually copy user memory.
+ IO usermem.IO
+
+ // Our current stack bottom.
+ Bottom usermem.Addr
+}
+
+// Push pushes the given values on to the stack.
+//
+// (This method supports Addrs and treats them as native types.)
+func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) {
+ for _, v := range vals {
+
+ // We convert some types to well-known serializable quanities.
+ var norm interface{}
+
+ // For array types, we will automatically add an appropriate
+ // terminal value. This is done simply to make the interface
+ // easier to use.
+ var term interface{}
+
+ switch v.(type) {
+ case string:
+ norm = []byte(v.(string))
+ term = byte(0)
+ case []int8, []uint8:
+ norm = v
+ term = byte(0)
+ case []int16, []uint16:
+ norm = v
+ term = uint16(0)
+ case []int32, []uint32:
+ norm = v
+ term = uint32(0)
+ case []int64, []uint64:
+ norm = v
+ term = uint64(0)
+ case []usermem.Addr:
+ // Special case: simply push recursively.
+ _, err := s.Push(s.Arch.Native(uintptr(0)))
+ if err != nil {
+ return 0, err
+ }
+ varr := v.([]usermem.Addr)
+ for i := len(varr) - 1; i >= 0; i-- {
+ _, err := s.Push(varr[i])
+ if err != nil {
+ return 0, err
+ }
+ }
+ continue
+ case usermem.Addr:
+ norm = s.Arch.Native(uintptr(v.(usermem.Addr)))
+ default:
+ norm = v
+ }
+
+ if term != nil {
+ _, err := s.Push(term)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ c := binary.Size(norm)
+ if c < 0 {
+ return 0, fmt.Errorf("bad binary.Size for %T", v)
+ }
+ // TODO(b/38173783): Use a real context.Context.
+ n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{})
+ if err != nil || c != n {
+ return 0, err
+ }
+
+ s.Bottom -= usermem.Addr(n)
+ }
+
+ return s.Bottom, nil
+}
+
+// Pop pops the given values off the stack.
+//
+// (This method supports Addrs and treats them as native types.)
+func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) {
+ for _, v := range vals {
+
+ vaddr, isVaddr := v.(*usermem.Addr)
+
+ var n int
+ var err error
+ if isVaddr {
+ value := s.Arch.Native(uintptr(0))
+ // TODO(b/38173783): Use a real context.Context.
+ n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{})
+ *vaddr = usermem.Addr(s.Arch.Value(value))
+ } else {
+ // TODO(b/38173783): Use a real context.Context.
+ n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{})
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ s.Bottom += usermem.Addr(n)
+ }
+
+ return s.Bottom, nil
+}
+
+// Align aligns the stack to the given offset.
+func (s *Stack) Align(offset int) {
+ if s.Bottom%usermem.Addr(offset) != 0 {
+ s.Bottom -= (s.Bottom % usermem.Addr(offset))
+ }
+}
+
+// StackLayout describes the location of the arguments and environment on the
+// stack.
+type StackLayout struct {
+ // ArgvStart is the beginning of the argument vector.
+ ArgvStart usermem.Addr
+
+ // ArgvEnd is the end of the argument vector.
+ ArgvEnd usermem.Addr
+
+ // EnvvStart is the beginning of the environment vector.
+ EnvvStart usermem.Addr
+
+ // EnvvEnd is the end of the environment vector.
+ EnvvEnd usermem.Addr
+}
+
+// Load pushes the given args, env and aux vector to the stack using the
+// well-known format for a new executable. It returns the start and end
+// of the argument and environment vectors.
+func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) {
+ l := StackLayout{}
+
+ // Make sure we start with a 16-byte alignment.
+ s.Align(16)
+
+ // Push the environment vector so the end of the argument vector is adjacent to
+ // the beginning of the environment vector.
+ // While the System V abi for x86_64 does not specify an ordering to the
+ // Information Block (the block holding the arg, env, and aux vectors),
+ // support features like setproctitle(3) naturally expect these segments
+ // to be in this order. See: https://www.uclibc.org/docs/psABI-x86_64.pdf
+ // page 29.
+ l.EnvvEnd = s.Bottom
+ envAddrs := make([]usermem.Addr, len(env))
+ for i := len(env) - 1; i >= 0; i-- {
+ addr, err := s.Push(env[i])
+ if err != nil {
+ return StackLayout{}, err
+ }
+ envAddrs[i] = addr
+ }
+ l.EnvvStart = s.Bottom
+
+ // Push our strings.
+ l.ArgvEnd = s.Bottom
+ argAddrs := make([]usermem.Addr, len(args))
+ for i := len(args) - 1; i >= 0; i-- {
+ addr, err := s.Push(args[i])
+ if err != nil {
+ return StackLayout{}, err
+ }
+ argAddrs[i] = addr
+ }
+ l.ArgvStart = s.Bottom
+
+ // We need to align the arguments appropriately.
+ //
+ // We must finish on a 16-byte alignment, but we'll play it
+ // conservatively and finish at 32-bytes. It would be nice to be able
+ // to call Align here, but unfortunately we need to align the stack
+ // with all the variable sized arrays pushed. So we just need to do
+ // some calculations.
+ argvSize := s.Arch.Width() * uint(len(args)+1)
+ envvSize := s.Arch.Width() * uint(len(env)+1)
+ auxvSize := s.Arch.Width() * 2 * uint(len(aux)+1)
+ total := usermem.Addr(argvSize) + usermem.Addr(envvSize) + usermem.Addr(auxvSize) + usermem.Addr(s.Arch.Width())
+ expectedBottom := s.Bottom - total
+ if expectedBottom%32 != 0 {
+ s.Bottom -= expectedBottom % 32
+ }
+
+ // Push our auxvec.
+ // NOTE: We need an extra zero here per spec.
+ // The Push function will automatically terminate
+ // strings and arrays with a single null value.
+ auxv := make([]usermem.Addr, 0, len(aux))
+ for _, a := range aux {
+ auxv = append(auxv, usermem.Addr(a.Key), a.Value)
+ }
+ auxv = append(auxv, usermem.Addr(0))
+ _, err := s.Push(auxv)
+ if err != nil {
+ return StackLayout{}, err
+ }
+
+ // Push environment.
+ _, err = s.Push(envAddrs)
+ if err != nil {
+ return StackLayout{}, err
+ }
+
+ // Push args.
+ _, err = s.Push(argAddrs)
+ if err != nil {
+ return StackLayout{}, err
+ }
+
+ // Push arg count.
+ _, err = s.Push(usermem.Addr(len(args)))
+ if err != nil {
+ return StackLayout{}, err
+ }
+
+ return l, nil
+}
diff --git a/pkg/sentry/arch/syscalls_amd64.go b/pkg/sentry/arch/syscalls_amd64.go
new file mode 100644
index 000000000..8b4f23007
--- /dev/null
+++ b/pkg/sentry/arch/syscalls_amd64.go
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package arch
+
+const restartSyscallNr = uintptr(219)
+
+// SyscallNo returns the syscall number according to the 64-bit convention.
+func (c *context64) SyscallNo() uintptr {
+ return uintptr(c.Regs.Orig_rax)
+}
+
+// SyscallArgs provides syscall arguments according to the 64-bit convention.
+//
+// Due to the way addresses are mapped for the sentry this binary *must* be
+// built in 64-bit mode. So we can just assume the syscall numbers that come
+// back match the expected host system call numbers.
+func (c *context64) SyscallArgs() SyscallArguments {
+ return SyscallArguments{
+ SyscallArgument{Value: uintptr(c.Regs.Rdi)},
+ SyscallArgument{Value: uintptr(c.Regs.Rsi)},
+ SyscallArgument{Value: uintptr(c.Regs.Rdx)},
+ SyscallArgument{Value: uintptr(c.Regs.R10)},
+ SyscallArgument{Value: uintptr(c.Regs.R8)},
+ SyscallArgument{Value: uintptr(c.Regs.R9)},
+ }
+}
+
+// RestartSyscall implements Context.RestartSyscall.
+func (c *context64) RestartSyscall() {
+ c.Regs.Rip -= SyscallWidth
+ c.Regs.Rax = c.Regs.Orig_rax
+}
+
+// RestartSyscallWithRestartBlock implements Context.RestartSyscallWithRestartBlock.
+func (c *context64) RestartSyscallWithRestartBlock() {
+ c.Regs.Rip -= SyscallWidth
+ c.Regs.Rax = uint64(restartSyscallNr)
+}
diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go
new file mode 100644
index 000000000..d70f3a5c3
--- /dev/null
+++ b/pkg/sentry/context/context.go
@@ -0,0 +1,126 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package context defines the sentry's Context type.
+package context
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/amutex"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+)
+
+type contextID int
+
+// Globally accessible values from a context. These keys are defined in the
+// context package to resolve dependency cycles by not requiring the caller to
+// import packages usually required to get these information.
+const (
+ // CtxThreadGroupID is the current thread group ID when a context represents
+ // a task context. The value is represented as an int32.
+ CtxThreadGroupID contextID = iota
+)
+
+// ThreadGroupIDFromContext returns the current thread group ID when ctx
+// represents a task context.
+func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) {
+ if tgid := ctx.Value(CtxThreadGroupID); tgid != nil {
+ return tgid.(int32), true
+ }
+ return 0, false
+}
+
+// A Context represents a thread of execution (hereafter "goroutine" to reflect
+// Go idiosyncrasy). It carries state associated with the goroutine across API
+// boundaries.
+//
+// While Context exists for essentially the same reasons as Go's standard
+// context.Context, the standard type represents the state of an operation
+// rather than that of a goroutine. This is a critical distinction:
+//
+// - Unlike context.Context, which "may be passed to functions running in
+// different goroutines", it is *not safe* to use the same Context in multiple
+// concurrent goroutines.
+//
+// - It is *not safe* to retain a Context passed to a function beyond the scope
+// of that function call.
+//
+// In both cases, values extracted from the Context should be used instead.
+type Context interface {
+ log.Logger
+ amutex.Sleeper
+
+ // UninterruptibleSleepStart indicates the beginning of an uninterruptible
+ // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate
+ // is true and the Context represents a Task, the Task's AddressSpace is
+ // deactivated.
+ UninterruptibleSleepStart(deactivate bool)
+
+ // UninterruptibleSleepFinish indicates the end of an uninterruptible sleep
+ // state that was begun by a previous call to UninterruptibleSleepStart. If
+ // activate is true and the Context represents a Task, the Task's
+ // AddressSpace is activated. Normally activate is the same value as the
+ // deactivate parameter passed to UninterruptibleSleepStart.
+ UninterruptibleSleepFinish(activate bool)
+
+ // Value returns the value associated with this Context for key, or nil if
+ // no value is associated with key. Successive calls to Value with the same
+ // key returns the same result.
+ //
+ // A key identifies a specific value in a Context. Functions that wish to
+ // retrieve values from Context typically allocate a key in a global
+ // variable then use that key as the argument to Context.Value. A key can
+ // be any type that supports equality; packages should define keys as an
+ // unexported type to avoid collisions.
+ Value(key interface{}) interface{}
+}
+
+type logContext struct {
+ log.Logger
+ NoopSleeper
+}
+
+// Value implements Context.Value.
+func (logContext) Value(key interface{}) interface{} {
+ return nil
+}
+
+// NoopSleeper is a noop implementation of amutex.Sleeper and
+// Context.UninterruptibleSleep* methods for anonymous embedding in other types
+// that do not want to notify kernel.Task about sleeps.
+type NoopSleeper struct {
+ amutex.NoopSleeper
+}
+
+// UninterruptibleSleepStart does nothing.
+func (NoopSleeper) UninterruptibleSleepStart(bool) {}
+
+// UninterruptibleSleepFinish does nothing.
+func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
+
+// bgContext is the context returned by context.Background.
+var bgContext = &logContext{Logger: log.Log()}
+
+// Background returns an empty context using the default logger.
+//
+// Users should be wary of using a Background context. Please tag any use with
+// FIXME(b/38173783) and a note to remove this use.
+//
+// Generally, one should use the Task as their context when available, or avoid
+// having to use a context in places where a Task is unavailable.
+//
+// Using a Background context for tests is fine, as long as no values are
+// needed from the context in the tested code paths.
+func Background() Context {
+ return bgContext
+}
diff --git a/pkg/sentry/context/context_state_autogen.go b/pkg/sentry/context/context_state_autogen.go
new file mode 100755
index 000000000..7dd55bfea
--- /dev/null
+++ b/pkg/sentry/context/context_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package context
+
diff --git a/pkg/sentry/control/control.go b/pkg/sentry/control/control.go
new file mode 100644
index 000000000..6060b9b4f
--- /dev/null
+++ b/pkg/sentry/control/control.go
@@ -0,0 +1,17 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package control contains types that expose control server methods, and can
+// be used to configure and interact with a running sandbox process.
+package control
diff --git a/pkg/sentry/control/control_state_autogen.go b/pkg/sentry/control/control_state_autogen.go
new file mode 100755
index 000000000..a1de4bc6d
--- /dev/null
+++ b/pkg/sentry/control/control_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package control
+
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
new file mode 100644
index 000000000..d63916600
--- /dev/null
+++ b/pkg/sentry/control/pprof.go
@@ -0,0 +1,168 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "errors"
+ "runtime"
+ "runtime/pprof"
+ "runtime/trace"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/urpc"
+)
+
+var errNoOutput = errors.New("no output writer provided")
+
+// ProfileOpts contains options for the StartCPUProfile/Goroutine RPC call.
+type ProfileOpts struct {
+ // File is the filesystem path for the profile.
+ File string `json:"path"`
+
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
+}
+
+// Profile includes profile-related RPC stubs. It provides a way to
+// control the built-in pprof facility in sentry via sentryctl.
+//
+// The following options to sentryctl are added:
+//
+// - collect CPU profile on-demand.
+// sentryctl -pid <pid> pprof-cpu-start
+// sentryctl -pid <pid> pprof-cpu-stop
+//
+// - dump out the stack trace of current go routines.
+// sentryctl -pid <pid> pprof-goroutine
+type Profile struct {
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // cpuFile is the current CPU profile output file.
+ cpuFile *fd.FD
+
+ // traceFile is the current execution trace output file.
+ traceFile *fd.FD
+}
+
+// StartCPUProfile is an RPC stub which starts recording the CPU profile in a
+// file.
+func (p *Profile) StartCPUProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+
+ output, err := fd.NewFromFile(o.FilePayload.Files[0])
+ if err != nil {
+ return err
+ }
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Returns an error if profiling is already started.
+ if err := pprof.StartCPUProfile(output); err != nil {
+ output.Close()
+ return err
+ }
+
+ p.cpuFile = output
+ return nil
+}
+
+// StopCPUProfile is an RPC stub which stops the CPU profiling and flush out the
+// profile data. It takes no argument.
+func (p *Profile) StopCPUProfile(_, _ *struct{}) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.cpuFile == nil {
+ return errors.New("CPU profiling not started")
+ }
+
+ pprof.StopCPUProfile()
+ p.cpuFile.Close()
+ p.cpuFile = nil
+ return nil
+}
+
+// HeapProfile generates a heap profile for the sentry.
+func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ runtime.GC() // Get up-to-date statistics.
+ if err := pprof.WriteHeapProfile(output); err != nil {
+ return err
+ }
+ return nil
+}
+
+// Goroutine is an RPC stub which dumps out the stack trace for all running
+// goroutines.
+func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("goroutine").WriteTo(output, 2); err != nil {
+ return err
+ }
+ return nil
+}
+
+// StartTrace is an RPC stub which starts collection of an execution trace.
+func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+
+ output, err := fd.NewFromFile(o.FilePayload.Files[0])
+ if err != nil {
+ return err
+ }
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Returns an error if profiling is already started.
+ if err := trace.Start(output); err != nil {
+ output.Close()
+ return err
+ }
+
+ p.traceFile = output
+ return nil
+}
+
+// StopTrace is an RPC stub which stops collection of an ongoing execution
+// trace and flushes the trace data. It takes no argument.
+func (p *Profile) StopTrace(_, _ *struct{}) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.traceFile == nil {
+ return errors.New("Execution tracing not start")
+ }
+
+ trace.Stop()
+ p.traceFile.Close()
+ p.traceFile = nil
+ return nil
+}
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
new file mode 100644
index 000000000..f7f02a3e1
--- /dev/null
+++ b/pkg/sentry/control/proc.go
@@ -0,0 +1,390 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "sort"
+ "strings"
+ "text/tabwriter"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/host"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/urpc"
+)
+
+// Proc includes task-related functions.
+//
+// At the moment, this is limited to exec support.
+type Proc struct {
+ Kernel *kernel.Kernel
+}
+
+// ExecArgs is the set of arguments to exec.
+type ExecArgs struct {
+ // Filename is the filename to load.
+ //
+ // If this is provided as "", then the file will be guessed via Argv[0].
+ Filename string `json:"filename"`
+
+ // Argv is a list of arguments.
+ Argv []string `json:"argv"`
+
+ // Envv is a list of environment variables.
+ Envv []string `json:"envv"`
+
+ // Root defines the root directory for the new process. A reference on
+ // Root must be held for the lifetime of the ExecArgs. If Root is nil,
+ // it will default to the VFS root.
+ Root *fs.Dirent
+
+ // WorkingDirectory defines the working directory for the new process.
+ WorkingDirectory string `json:"wd"`
+
+ // KUID is the UID to run with in the root user namespace. Defaults to
+ // root if not set explicitly.
+ KUID auth.KUID
+
+ // KGID is the GID to run with in the root user namespace. Defaults to
+ // the root group if not set explicitly.
+ KGID auth.KGID
+
+ // ExtraKGIDs is the list of additional groups to which the user
+ // belongs.
+ ExtraKGIDs []auth.KGID
+
+ // Capabilities is the list of capabilities to give to the process.
+ Capabilities *auth.TaskCapabilities
+
+ // StdioIsPty indicates that FDs 0, 1, and 2 are connected to a host
+ // pty FD.
+ StdioIsPty bool
+
+ // FilePayload determines the files to give to the new process.
+ urpc.FilePayload
+
+ // ContainerID is the container for the process being executed.
+ ContainerID string
+}
+
+// String prints the arguments as a string.
+func (args ExecArgs) String() string {
+ a := make([]string, len(args.Argv))
+ copy(a, args.Argv)
+ if args.Filename != "" {
+ a[0] = args.Filename
+ }
+ return strings.Join(a, " ")
+}
+
+// Exec runs a new task.
+func (proc *Proc) Exec(args *ExecArgs, waitStatus *uint32) error {
+ newTG, _, _, err := proc.execAsync(args)
+ if err != nil {
+ return err
+ }
+
+ // Wait for completion.
+ newTG.WaitExited()
+ *waitStatus = newTG.ExitStatus().Status()
+ return nil
+}
+
+// ExecAsync runs a new task, but doesn't wait for it to finish. It is defined
+// as a function rather than a method to avoid exposing execAsync as an RPC.
+func ExecAsync(proc *Proc, args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, error) {
+ return proc.execAsync(args)
+}
+
+// execAsync runs a new task, but doesn't wait for it to finish. It returns the
+// newly created thread group and its PID. If the stdio FDs are TTYs, then a
+// TTYFileOperations that wraps the TTY is also returned.
+func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, error) {
+ // Import file descriptors.
+ l := limits.NewLimitSet()
+ fdm := proc.Kernel.NewFDMap()
+ defer fdm.DecRef()
+
+ // No matter what happens, we should close all files in the FilePayload
+ // before returning. Any files that are imported will be duped.
+ defer func() {
+ for _, f := range args.FilePayload.Files {
+ f.Close()
+ }
+ }()
+
+ creds := auth.NewUserCredentials(
+ args.KUID,
+ args.KGID,
+ args.ExtraKGIDs,
+ args.Capabilities,
+ proc.Kernel.RootUserNamespace())
+
+ initArgs := kernel.CreateProcessArgs{
+ Filename: args.Filename,
+ Argv: args.Argv,
+ Envv: args.Envv,
+ WorkingDirectory: args.WorkingDirectory,
+ Root: args.Root,
+ Credentials: creds,
+ FDMap: fdm,
+ Umask: 0022,
+ Limits: l,
+ MaxSymlinkTraversals: linux.MaxSymlinkTraversals,
+ UTSNamespace: proc.Kernel.RootUTSNamespace(),
+ IPCNamespace: proc.Kernel.RootIPCNamespace(),
+ AbstractSocketNamespace: proc.Kernel.RootAbstractSocketNamespace(),
+ ContainerID: args.ContainerID,
+ }
+ if initArgs.Root != nil {
+ // initArgs must hold a reference on Root. This ref is dropped
+ // in CreateProcess.
+ initArgs.Root.IncRef()
+ }
+ ctx := initArgs.NewContext(proc.Kernel)
+
+ if initArgs.Filename == "" {
+ // Get the full path to the filename from the PATH env variable.
+ paths := fs.GetPath(initArgs.Envv)
+ f, err := proc.Kernel.RootMountNamespace().ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
+ if err != nil {
+ return nil, 0, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
+ }
+ initArgs.Filename = f
+ }
+
+ mounter := fs.FileOwnerFromContext(ctx)
+
+ var ttyFile *fs.File
+ for appFD, hostFile := range args.FilePayload.Files {
+ var appFile *fs.File
+
+ if args.StdioIsPty && appFD < 3 {
+ // Import the file as a host TTY file.
+ if ttyFile == nil {
+ var err error
+ appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), mounter, true /* isTTY */)
+ if err != nil {
+ return nil, 0, nil, err
+ }
+ defer appFile.DecRef()
+
+ // Remember this in the TTY file, as we will
+ // use it for the other stdio FDs.
+ ttyFile = appFile
+ } else {
+ // Re-use the existing TTY file, as all three
+ // stdio FDs must point to the same fs.File in
+ // order to share TTY state, specifically the
+ // foreground process group id.
+ appFile = ttyFile
+ }
+ } else {
+ // Import the file as a regular host file.
+ var err error
+ appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), mounter, false /* isTTY */)
+ if err != nil {
+ return nil, 0, nil, err
+ }
+ defer appFile.DecRef()
+ }
+
+ // Add the file to the FD map.
+ if err := fdm.NewFDAt(kdefs.FD(appFD), appFile, kernel.FDFlags{}, l); err != nil {
+ return nil, 0, nil, err
+ }
+ }
+
+ tg, tid, err := proc.Kernel.CreateProcess(initArgs)
+ if err != nil {
+ return nil, 0, nil, err
+ }
+
+ var ttyFileOps *host.TTYFileOperations
+ if ttyFile != nil {
+ // Set the foreground process group on the TTY before starting
+ // the process.
+ ttyFileOps = ttyFile.FileOperations.(*host.TTYFileOperations)
+ ttyFileOps.InitForegroundProcessGroup(tg.ProcessGroup())
+ }
+
+ // Start the newly created process.
+ proc.Kernel.StartProcess(tg)
+
+ return tg, tid, ttyFileOps, nil
+}
+
+// PsArgs is the set of arguments to ps.
+type PsArgs struct {
+ // JSON will force calls to Ps to return the result as a JSON payload.
+ JSON bool
+}
+
+// Ps provides a process listing for the running kernel.
+func (proc *Proc) Ps(args *PsArgs, out *string) error {
+ var p []*Process
+ if e := Processes(proc.Kernel, "", &p); e != nil {
+ return e
+ }
+ if !args.JSON {
+ *out = ProcessListToTable(p)
+ } else {
+ s, e := ProcessListToJSON(p)
+ if e != nil {
+ return e
+ }
+ *out = s
+ }
+ return nil
+}
+
+// Process contains information about a single process in a Sandbox.
+// TODO(b/117881927): Implement TTY field.
+type Process struct {
+ UID auth.KUID `json:"uid"`
+ PID kernel.ThreadID `json:"pid"`
+ // Parent PID
+ PPID kernel.ThreadID `json:"ppid"`
+ // Processor utilization
+ C int32 `json:"c"`
+ // Start time
+ STime string `json:"stime"`
+ // CPU time
+ Time string `json:"time"`
+ // Executable shortname (e.g. "sh" for /bin/sh)
+ Cmd string `json:"cmd"`
+}
+
+// ProcessListToTable prints a table with the following format:
+// UID PID PPID C STIME TIME CMD
+// 0 1 0 0 14:04 505262ns tail
+func ProcessListToTable(pl []*Process) string {
+ var buf bytes.Buffer
+ tw := tabwriter.NewWriter(&buf, 10, 1, 3, ' ', 0)
+ fmt.Fprint(tw, "UID\tPID\tPPID\tC\tSTIME\tTIME\tCMD")
+ for _, d := range pl {
+ fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s",
+ d.UID,
+ d.PID,
+ d.PPID,
+ d.C,
+ d.STime,
+ d.Time,
+ d.Cmd)
+ }
+ tw.Flush()
+ return buf.String()
+}
+
+// ProcessListToJSON will return the JSON representation of ps.
+func ProcessListToJSON(pl []*Process) (string, error) {
+ b, err := json.Marshal(pl)
+ if err != nil {
+ return "", fmt.Errorf("couldn't marshal process list %v: %v", pl, err)
+ }
+ return string(b), nil
+}
+
+// PrintPIDsJSON prints a JSON object containing only the PIDs in pl. This
+// behavior is the same as runc's.
+func PrintPIDsJSON(pl []*Process) (string, error) {
+ pids := make([]kernel.ThreadID, 0, len(pl))
+ for _, d := range pl {
+ pids = append(pids, d.PID)
+ }
+ b, err := json.Marshal(pids)
+ if err != nil {
+ return "", fmt.Errorf("couldn't marshal PIDs %v: %v", pids, err)
+ }
+ return string(b), nil
+}
+
+// Processes retrieves information about processes running in the sandbox with
+// the given container id. All processes are returned if 'containerID' is empty.
+func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error {
+ ts := k.TaskSet()
+ now := k.RealtimeClock().Now()
+ for _, tg := range ts.Root.ThreadGroups() {
+ pid := ts.Root.IDOfThreadGroup(tg)
+ // If tg has already been reaped ignore it.
+ if pid == 0 {
+ continue
+ }
+ if containerID != "" && containerID != tg.Leader().ContainerID() {
+ continue
+ }
+
+ ppid := kernel.ThreadID(0)
+ if p := tg.Leader().Parent(); p != nil {
+ ppid = ts.Root.IDOfThreadGroup(p.ThreadGroup())
+ }
+ *out = append(*out, &Process{
+ UID: tg.Leader().Credentials().EffectiveKUID,
+ PID: pid,
+ PPID: ppid,
+ STime: formatStartTime(now, tg.Leader().StartTime()),
+ C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now),
+ Time: tg.CPUStats().SysTime.String(),
+ Cmd: tg.Leader().Name(),
+ })
+ }
+ sort.Slice(*out, func(i, j int) bool { return (*out)[i].PID < (*out)[j].PID })
+ return nil
+}
+
+// formatStartTime formats startTime depending on the current time:
+// - If startTime was today, HH:MM is used.
+// - If startTime was not today but was this year, MonDD is used (e.g. Jan02)
+// - If startTime was not this year, the year is used.
+func formatStartTime(now, startTime ktime.Time) string {
+ nowS, nowNs := now.Unix()
+ n := time.Unix(nowS, nowNs)
+ startTimeS, startTimeNs := startTime.Unix()
+ st := time.Unix(startTimeS, startTimeNs)
+ format := "15:04"
+ if st.YearDay() != n.YearDay() {
+ format = "Jan02"
+ }
+ if st.Year() != n.Year() {
+ format = "2006"
+ }
+ return st.Format(format)
+}
+
+func percentCPU(stats usage.CPUStats, startTime, now ktime.Time) int32 {
+ // Note: In procps, there is an option to include child CPU stats. As
+ // it is disabled by default, we do not include them.
+ total := stats.UserTime + stats.SysTime
+ lifetime := now.Sub(startTime)
+ if lifetime <= 0 {
+ return 0
+ }
+ percentCPU := total * 100 / lifetime
+ // Cap at 99% since procps does the same.
+ if percentCPU > 99 {
+ percentCPU = 99
+ }
+ return int32(percentCPU)
+}
diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go
new file mode 100644
index 000000000..11efcaba1
--- /dev/null
+++ b/pkg/sentry/control/state.go
@@ -0,0 +1,73 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "errors"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/state"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/watchdog"
+ "gvisor.googlesource.com/gvisor/pkg/urpc"
+)
+
+// ErrInvalidFiles is returned when the urpc call to Save does not include an
+// appropriate file payload (e.g. there is no output file!).
+var ErrInvalidFiles = errors.New("exactly one file must be provided")
+
+// State includes state-related functions.
+type State struct {
+ Kernel *kernel.Kernel
+ Watchdog *watchdog.Watchdog
+}
+
+// SaveOpts contains options for the Save RPC call.
+type SaveOpts struct {
+ // Key is used for state integrity check.
+ Key []byte `json:"key"`
+
+ // Metadata is the set of metadata to prepend to the state file.
+ Metadata map[string]string `json:"metadata"`
+
+ // FilePayload contains the destination for the state.
+ urpc.FilePayload
+}
+
+// Save saves the running system.
+func (s *State) Save(o *SaveOpts, _ *struct{}) error {
+ // Create an output stream.
+ if len(o.FilePayload.Files) != 1 {
+ return ErrInvalidFiles
+ }
+ defer o.FilePayload.Files[0].Close()
+
+ // Save to the first provided stream.
+ saveOpts := state.SaveOpts{
+ Destination: o.FilePayload.Files[0],
+ Key: o.Key,
+ Metadata: o.Metadata,
+ Callback: func(err error) {
+ if err == nil {
+ log.Infof("Save succeeded: exiting...")
+ } else {
+ log.Warningf("Save failed: exiting...")
+ s.Kernel.SetSaveError(err)
+ }
+ s.Kernel.Kill(kernel.ExitStatus{})
+ },
+ }
+ return saveOpts.Save(s.Kernel, s.Watchdog)
+}
diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go
new file mode 100644
index 000000000..458d03b30
--- /dev/null
+++ b/pkg/sentry/device/device.go
@@ -0,0 +1,266 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package device defines reserved virtual kernel devices and structures
+// for managing them.
+package device
+
+import (
+ "bytes"
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// Registry tracks all simple devices and related state on the system for
+// save/restore.
+//
+// The set of devices across save/restore must remain consistent. That is, no
+// devices may be created or removed on restore relative to the saved
+// system. Practically, this means do not create new devices specifically as
+// part of restore.
+//
+// +stateify savable
+type Registry struct {
+ // lastAnonDeviceMinor is the last minor device number used for an anonymous
+ // device. Must be accessed atomically.
+ lastAnonDeviceMinor uint64
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ devices map[ID]*Device
+}
+
+// SimpleDevices is the system-wide simple device registry. This is
+// saved/restored by kernel.Kernel, but defined here to allow access without
+// depending on the kernel package. See kernel.Kernel.deviceRegistry.
+var SimpleDevices = newRegistry()
+
+func newRegistry() *Registry {
+ return &Registry{
+ devices: make(map[ID]*Device),
+ }
+}
+
+// newAnonID assigns a major and minor number to an anonymous device ID.
+func (r *Registry) newAnonID() ID {
+ return ID{
+ // Anon devices always have a major number of 0.
+ Major: 0,
+ // Use the next minor number.
+ Minor: atomic.AddUint64(&r.lastAnonDeviceMinor, 1),
+ }
+}
+
+// newAnonDevice allocates a new anonymous device with a unique minor device
+// number, and registers it with r.
+func (r *Registry) newAnonDevice() *Device {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ d := &Device{
+ ID: r.newAnonID(),
+ }
+ r.devices[d.ID] = d
+ return d
+}
+
+// LoadFrom initializes the internal state of all devices in r from other. The
+// set of devices in both registries must match. Devices may not be created or
+// destroyed across save/restore.
+func (r *Registry) LoadFrom(other *Registry) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ other.mu.Lock()
+ defer other.mu.Unlock()
+ if len(r.devices) != len(other.devices) {
+ panic(fmt.Sprintf("Devices were added or removed when restoring the registry:\nnew:\n%+v\nold:\n%+v", r.devices, other.devices))
+ }
+ for id, otherD := range other.devices {
+ ourD, ok := r.devices[id]
+ if !ok {
+ panic(fmt.Sprintf("Device %+v could not be restored as it wasn't defined in the new registry", otherD))
+ }
+ ourD.loadFrom(otherD)
+ }
+ atomic.StoreUint64(&r.lastAnonDeviceMinor, atomic.LoadUint64(&other.lastAnonDeviceMinor))
+}
+
+// ID identifies a device.
+//
+// +stateify savable
+type ID struct {
+ Major uint64
+ Minor uint64
+}
+
+// DeviceID formats a major and minor device number into a standard device number.
+func (i *ID) DeviceID() uint64 {
+ return uint64(linux.MakeDeviceID(uint16(i.Major), uint32(i.Minor)))
+}
+
+// NewAnonDevice creates a new anonymous device. Packages that require an anonymous
+// device should initialize the device in a global variable in a file called device.go:
+//
+// var myDevice = device.NewAnonDevice()
+func NewAnonDevice() *Device {
+ return SimpleDevices.newAnonDevice()
+}
+
+// NewAnonMultiDevice creates a new multi-keyed anonymous device. Packages that require
+// a multi-key anonymous device should initialize the device in a global variable in a
+// file called device.go:
+//
+// var myDevice = device.NewAnonMultiDevice()
+func NewAnonMultiDevice() *MultiDevice {
+ return &MultiDevice{
+ ID: SimpleDevices.newAnonID(),
+ }
+}
+
+// Device is a simple virtual kernel device.
+//
+// +stateify savable
+type Device struct {
+ ID
+
+ // last is the last generated inode.
+ last uint64
+}
+
+// loadFrom initializes d from other. The IDs of both devices must match.
+func (d *Device) loadFrom(other *Device) {
+ if d.ID != other.ID {
+ panic(fmt.Sprintf("Attempting to initialize a device %+v from %+v, but device IDs don't match", d, other))
+ }
+ atomic.StoreUint64(&d.last, atomic.LoadUint64(&other.last))
+}
+
+// NextIno generates a new inode number
+func (d *Device) NextIno() uint64 {
+ return atomic.AddUint64(&d.last, 1)
+}
+
+// MultiDeviceKey provides a hashable key for a MultiDevice. The key consists
+// of a raw device and inode for a resource, which must consistently identify
+// the unique resource. It may optionally include a secondary device if
+// appropriate.
+//
+// Note that using the path is not enough, because filesystems may rename a file
+// to a different backing resource, at which point the path points to a different
+// entity. Using only the inode is also not enough because the inode is assumed
+// to be unique only within the device on which the resource exists.
+type MultiDeviceKey struct {
+ Device uint64
+ SecondaryDevice string
+ Inode uint64
+}
+
+// String stringifies the key.
+func (m MultiDeviceKey) String() string {
+ return fmt.Sprintf("key{device: %d, sdevice: %s, inode: %d}", m.Device, m.SecondaryDevice, m.Inode)
+}
+
+// MultiDevice allows for remapping resources that come from a variety of raw
+// devices into a single device. The device ID should be one of the static
+// Device IDs above and cannot be reused.
+type MultiDevice struct {
+ ID
+
+ mu sync.Mutex
+ last uint64
+ cache map[MultiDeviceKey]uint64
+ rcache map[uint64]MultiDeviceKey
+}
+
+// String stringifies MultiDevice.
+func (m *MultiDevice) String() string {
+ buf := bytes.NewBuffer(nil)
+ buf.WriteString("cache{")
+ for k, v := range m.cache {
+ buf.WriteString(fmt.Sprintf("%s -> %d, ", k, v))
+ }
+ buf.WriteString("}")
+ return buf.String()
+}
+
+// Map maps a raw device and inode into the inode space of MultiDevice,
+// returning a virtualized inode. Raw devices and inodes can be reused;
+// in this case, the same virtual inode will be returned.
+func (m *MultiDevice) Map(key MultiDeviceKey) uint64 {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.cache == nil {
+ m.cache = make(map[MultiDeviceKey]uint64)
+ m.rcache = make(map[uint64]MultiDeviceKey)
+ }
+
+ id, ok := m.cache[key]
+ if ok {
+ return id
+ }
+ // Step over reserved entries that may have been loaded.
+ idx := m.last + 1
+ for {
+ if _, ok := m.rcache[idx]; !ok {
+ break
+ }
+ idx++
+ }
+ // We found a non-reserved entry, use it.
+ m.last = idx
+ m.cache[key] = m.last
+ m.rcache[m.last] = key
+ return m.last
+}
+
+// Load loads a raw device and inode into MultiDevice inode mappings
+// with value as the virtual inode.
+//
+// By design, inodes start from 1 and continue until max uint64. This means
+// that the zero value, which is often the uninitialized value, can be rejected
+// as invalid.
+func (m *MultiDevice) Load(key MultiDeviceKey, value uint64) bool {
+ // Reject the uninitialized value; see comment above.
+ if value == 0 {
+ return false
+ }
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.cache == nil {
+ m.cache = make(map[MultiDeviceKey]uint64)
+ m.rcache = make(map[uint64]MultiDeviceKey)
+ }
+
+ if val, exists := m.cache[key]; exists && val != value {
+ return false
+ }
+ if k, exists := m.rcache[value]; exists && k != key {
+ // Should never happen.
+ panic("MultiDevice's caches are inconsistent")
+ }
+
+ // Cache value at key.
+ m.cache[key] = value
+
+ // Prevent value from being used by new inode mappings.
+ m.rcache[value] = key
+
+ return true
+}
diff --git a/pkg/sentry/device/device_state_autogen.go b/pkg/sentry/device/device_state_autogen.go
new file mode 100755
index 000000000..33cc93f3f
--- /dev/null
+++ b/pkg/sentry/device/device_state_autogen.go
@@ -0,0 +1,52 @@
+// automatically generated by stateify.
+
+package device
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Registry) beforeSave() {}
+func (x *Registry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("lastAnonDeviceMinor", &x.lastAnonDeviceMinor)
+ m.Save("devices", &x.devices)
+}
+
+func (x *Registry) afterLoad() {}
+func (x *Registry) load(m state.Map) {
+ m.Load("lastAnonDeviceMinor", &x.lastAnonDeviceMinor)
+ m.Load("devices", &x.devices)
+}
+
+func (x *ID) beforeSave() {}
+func (x *ID) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Major", &x.Major)
+ m.Save("Minor", &x.Minor)
+}
+
+func (x *ID) afterLoad() {}
+func (x *ID) load(m state.Map) {
+ m.Load("Major", &x.Major)
+ m.Load("Minor", &x.Minor)
+}
+
+func (x *Device) beforeSave() {}
+func (x *Device) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ID", &x.ID)
+ m.Save("last", &x.last)
+}
+
+func (x *Device) afterLoad() {}
+func (x *Device) load(m state.Map) {
+ m.Load("ID", &x.ID)
+ m.Load("last", &x.last)
+}
+
+func init() {
+ state.Register("device.Registry", (*Registry)(nil), state.Fns{Save: (*Registry).save, Load: (*Registry).load})
+ state.Register("device.ID", (*ID)(nil), state.Fns{Save: (*ID).save, Load: (*ID).load})
+ state.Register("device.Device", (*Device)(nil), state.Fns{Save: (*Device).save, Load: (*Device).load})
+}
diff --git a/pkg/sentry/fs/anon/anon.go b/pkg/sentry/fs/anon/anon.go
new file mode 100644
index 000000000..a6ea8b9e7
--- /dev/null
+++ b/pkg/sentry/fs/anon/anon.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package anon implements an anonymous inode, useful for implementing
+// inodes for pseudo filesystems.
+package anon
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// NewInode constructs an anonymous Inode that is not associated
+// with any real filesystem. Some types depend on completely pseudo
+// "anon" inodes (eventfds, epollfds, etc).
+func NewInode(ctx context.Context) *fs.Inode {
+ iops := &fsutil.SimpleFileInode{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+ }, linux.ANON_INODE_FS_MAGIC),
+ }
+ return fs.NewInode(iops, fs.NewPseudoMountSource(), fs.StableAttr{
+ Type: fs.Anonymous,
+ DeviceID: PseudoDevice.DeviceID(),
+ InodeID: PseudoDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ })
+}
diff --git a/pkg/sentry/fs/anon/anon_state_autogen.go b/pkg/sentry/fs/anon/anon_state_autogen.go
new file mode 100755
index 000000000..fcb914212
--- /dev/null
+++ b/pkg/sentry/fs/anon/anon_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package anon
+
diff --git a/pkg/sentry/fs/anon/device.go b/pkg/sentry/fs/anon/device.go
new file mode 100644
index 000000000..5927bd11e
--- /dev/null
+++ b/pkg/sentry/fs/anon/device.go
@@ -0,0 +1,22 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package anon
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+)
+
+// PseudoDevice is the device on which all anonymous inodes reside.
+var PseudoDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/ashmem/area.go b/pkg/sentry/fs/ashmem/area.go
new file mode 100644
index 000000000..b4b0cc08b
--- /dev/null
+++ b/pkg/sentry/fs/ashmem/area.go
@@ -0,0 +1,308 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ashmem
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // namePrefix is the name prefix assumed and forced by the Linux implementation.
+ namePrefix = "dev/ashmem"
+
+ // nameLen is the maximum name length.
+ nameLen = 256
+)
+
+// Area implements fs.FileOperations.
+//
+// +stateify savable
+type Area struct {
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ ad *Device
+
+ // mu protects fields below.
+ mu sync.Mutex `state:"nosave"`
+ tmpfsFile *fs.File
+ name string
+ size uint64
+ perms usermem.AccessType
+ pb *PinBoard
+}
+
+// Release implements fs.FileOperations.Release.
+func (a *Area) Release() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if a.tmpfsFile != nil {
+ a.tmpfsFile.DecRef()
+ a.tmpfsFile = nil
+ }
+}
+
+// Seek implements fs.FileOperations.Seek.
+func (a *Area) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if a.size == 0 {
+ return 0, syserror.EINVAL
+ }
+ if a.tmpfsFile == nil {
+ return 0, syserror.EBADF
+ }
+ return a.tmpfsFile.FileOperations.Seek(ctx, file, whence, offset)
+}
+
+// Read implements fs.FileOperations.Read.
+func (a *Area) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if a.size == 0 {
+ return 0, nil
+ }
+ if a.tmpfsFile == nil {
+ return 0, syserror.EBADF
+ }
+ return a.tmpfsFile.FileOperations.Read(ctx, file, dst, offset)
+}
+
+// Write implements fs.FileOperations.Write.
+func (a *Area) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (a *Area) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if a.size == 0 {
+ return syserror.EINVAL
+ }
+
+ if !a.perms.SupersetOf(opts.Perms) {
+ return syserror.EPERM
+ }
+ opts.MaxPerms = opts.MaxPerms.Intersect(a.perms)
+
+ if a.tmpfsFile == nil {
+ tmpfsInodeOps := tmpfs.NewInMemoryFile(ctx, usage.Tmpfs, fs.UnstableAttr{})
+ tmpfsInode := fs.NewInode(tmpfsInodeOps, fs.NewPseudoMountSource(), fs.StableAttr{})
+ dirent := fs.NewDirent(tmpfsInode, namePrefix+"/"+a.name)
+ tmpfsFile, err := tmpfsInode.GetFile(ctx, dirent, fs.FileFlags{Read: true, Write: true})
+ // Drop the extra reference on the Dirent.
+ dirent.DecRef()
+
+ if err != nil {
+ return err
+ }
+
+ // Truncate to the size set by ASHMEM_SET_SIZE ioctl.
+ err = tmpfsInodeOps.Truncate(ctx, tmpfsInode, int64(a.size))
+ if err != nil {
+ return err
+ }
+ a.tmpfsFile = tmpfsFile
+ a.pb = NewPinBoard()
+ }
+
+ return a.tmpfsFile.ConfigureMMap(ctx, opts)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (a *Area) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Switch on ioctl request.
+ switch args[1].Uint() {
+ case linux.AshmemSetNameIoctl:
+ name, err := usermem.CopyStringIn(ctx, io, args[2].Pointer(), nameLen-1, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ // Cannot set name for already mapped ashmem.
+ if a.tmpfsFile != nil {
+ return 0, syserror.EINVAL
+ }
+ a.name = name
+ return 0, nil
+
+ case linux.AshmemGetNameIoctl:
+ a.mu.Lock()
+ var local []byte
+ if a.name != "" {
+ nameLen := len([]byte(a.name))
+ local = make([]byte, nameLen, nameLen+1)
+ copy(local, []byte(a.name))
+ local = append(local, 0)
+ } else {
+ nameLen := len([]byte(namePrefix))
+ local = make([]byte, nameLen, nameLen+1)
+ copy(local, []byte(namePrefix))
+ local = append(local, 0)
+ }
+ a.mu.Unlock()
+
+ if _, err := io.CopyOut(ctx, args[2].Pointer(), local, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, syserror.EFAULT
+ }
+ return 0, nil
+
+ case linux.AshmemSetSizeIoctl:
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ // Cannot set size for already mapped ashmem.
+ if a.tmpfsFile != nil {
+ return 0, syserror.EINVAL
+ }
+ a.size = uint64(args[2].SizeT())
+ return 0, nil
+
+ case linux.AshmemGetSizeIoctl:
+ return uintptr(a.size), nil
+
+ case linux.AshmemPinIoctl, linux.AshmemUnpinIoctl, linux.AshmemGetPinStatusIoctl:
+ // Locking and unlocking is ok since once tmpfsFile is set, it won't be nil again
+ // even after unmapping! Unlocking is needed in order to avoid a deadlock on
+ // usermem.CopyObjectIn.
+
+ // Cannot execute pin-related ioctls before mapping.
+ a.mu.Lock()
+ if a.tmpfsFile == nil {
+ a.mu.Unlock()
+ return 0, syserror.EINVAL
+ }
+ a.mu.Unlock()
+
+ var pin linux.AshmemPin
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pin, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, syserror.EFAULT
+ }
+
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ return a.pinOperation(pin, args[1].Uint())
+
+ case linux.AshmemPurgeAllCachesIoctl:
+ return 0, nil
+
+ case linux.AshmemSetProtMaskIoctl:
+ prot := uint64(args[2].ModeT())
+ perms := usermem.AccessType{
+ Read: prot&linux.PROT_READ != 0,
+ Write: prot&linux.PROT_WRITE != 0,
+ Execute: prot&linux.PROT_EXEC != 0,
+ }
+
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ // Can only narrow prot mask.
+ if !a.perms.SupersetOf(perms) {
+ return 0, syserror.EINVAL
+ }
+
+ // TODO(b/30946773,gvisor.dev/issue/153): If personality flag
+ // READ_IMPLIES_EXEC is set, set PROT_EXEC if PORT_READ is set.
+
+ a.perms = perms
+ return 0, nil
+
+ case linux.AshmemGetProtMaskIoctl:
+ return uintptr(a.perms.Prot()), nil
+ default:
+ // Ioctls irrelevant to Ashmem.
+ return 0, syserror.EINVAL
+ }
+}
+
+// pinOperation should only be called while holding a.mu.
+func (a *Area) pinOperation(pin linux.AshmemPin, op uint32) (uintptr, error) {
+ // Page-align a.size for checks.
+ pageAlignedSize, ok := usermem.Addr(a.size).RoundUp()
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+ // Len 0 means everything onward.
+ if pin.Len == 0 {
+ pin.Len = uint32(pageAlignedSize) - pin.Offset
+ }
+ // Both Offset and Len have to be page-aligned.
+ if pin.Offset%uint32(usermem.PageSize) != 0 {
+ return 0, syserror.EINVAL
+ }
+ if pin.Len%uint32(usermem.PageSize) != 0 {
+ return 0, syserror.EINVAL
+ }
+ // Adding Offset and Len must not cause an uint32 overflow.
+ if end := pin.Offset + pin.Len; end < pin.Offset {
+ return 0, syserror.EINVAL
+ }
+ // Pin range must not exceed a's size.
+ if uint32(pageAlignedSize) < pin.Offset+pin.Len {
+ return 0, syserror.EINVAL
+ }
+ // Handle each operation.
+ r := RangeFromAshmemPin(pin)
+ switch op {
+ case linux.AshmemPinIoctl:
+ if a.pb.PinRange(r) {
+ return linux.AshmemWasPurged, nil
+ }
+ return linux.AshmemNotPurged, nil
+
+ case linux.AshmemUnpinIoctl:
+ // TODO(b/30946773): Implement purge on unpin.
+ a.pb.UnpinRange(r)
+ return 0, nil
+
+ case linux.AshmemGetPinStatusIoctl:
+ if a.pb.RangePinnedStatus(r) {
+ return linux.AshmemIsPinned, nil
+ }
+ return linux.AshmemIsUnpinned, nil
+
+ default:
+ panic("unreachable")
+ }
+
+}
diff --git a/pkg/sentry/fs/ashmem/ashmem_state_autogen.go b/pkg/sentry/fs/ashmem/ashmem_state_autogen.go
new file mode 100755
index 000000000..c4469b13a
--- /dev/null
+++ b/pkg/sentry/fs/ashmem/ashmem_state_autogen.go
@@ -0,0 +1,123 @@
+// automatically generated by stateify.
+
+package ashmem
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Area) beforeSave() {}
+func (x *Area) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ad", &x.ad)
+ m.Save("tmpfsFile", &x.tmpfsFile)
+ m.Save("name", &x.name)
+ m.Save("size", &x.size)
+ m.Save("perms", &x.perms)
+ m.Save("pb", &x.pb)
+}
+
+func (x *Area) afterLoad() {}
+func (x *Area) load(m state.Map) {
+ m.Load("ad", &x.ad)
+ m.Load("tmpfsFile", &x.tmpfsFile)
+ m.Load("name", &x.name)
+ m.Load("size", &x.size)
+ m.Load("perms", &x.perms)
+ m.Load("pb", &x.pb)
+}
+
+func (x *Device) beforeSave() {}
+func (x *Device) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *Device) afterLoad() {}
+func (x *Device) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *PinBoard) beforeSave() {}
+func (x *PinBoard) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Set", &x.Set)
+}
+
+func (x *PinBoard) afterLoad() {}
+func (x *PinBoard) load(m state.Map) {
+ m.Load("Set", &x.Set)
+}
+
+func (x *Range) beforeSave() {}
+func (x *Range) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *Range) afterLoad() {}
+func (x *Range) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func (x *Set) beforeSave() {}
+func (x *Set) save(m state.Map) {
+ x.beforeSave()
+ var root *SegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *Set) afterLoad() {}
+func (x *Set) load(m state.Map) {
+ m.LoadValue("root", new(*SegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*SegmentDataSlices)) })
+}
+
+func (x *node) beforeSave() {}
+func (x *node) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *node) afterLoad() {}
+func (x *node) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *SegmentDataSlices) beforeSave() {}
+func (x *SegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *SegmentDataSlices) afterLoad() {}
+func (x *SegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func init() {
+ state.Register("ashmem.Area", (*Area)(nil), state.Fns{Save: (*Area).save, Load: (*Area).load})
+ state.Register("ashmem.Device", (*Device)(nil), state.Fns{Save: (*Device).save, Load: (*Device).load})
+ state.Register("ashmem.PinBoard", (*PinBoard)(nil), state.Fns{Save: (*PinBoard).save, Load: (*PinBoard).load})
+ state.Register("ashmem.Range", (*Range)(nil), state.Fns{Save: (*Range).save, Load: (*Range).load})
+ state.Register("ashmem.Set", (*Set)(nil), state.Fns{Save: (*Set).save, Load: (*Set).load})
+ state.Register("ashmem.node", (*node)(nil), state.Fns{Save: (*node).save, Load: (*node).load})
+ state.Register("ashmem.SegmentDataSlices", (*SegmentDataSlices)(nil), state.Fns{Save: (*SegmentDataSlices).save, Load: (*SegmentDataSlices).load})
+}
diff --git a/pkg/sentry/fs/ashmem/device.go b/pkg/sentry/fs/ashmem/device.go
new file mode 100644
index 000000000..22e1530e9
--- /dev/null
+++ b/pkg/sentry/fs/ashmem/device.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ashmem implements Android ashmem module (Anonymus Shared Memory).
+package ashmem
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// Device implements fs.InodeOperations.
+//
+// +stateify savable
+type Device struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*Device)(nil)
+
+// NewDevice creates and intializes a Device structure.
+func NewDevice(ctx context.Context, owner fs.FileOwner, fp fs.FilePermissions) *Device {
+ return &Device{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fp, linux.ANON_INODE_FS_MAGIC),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (ad *Device) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &Area{
+ ad: ad,
+ tmpfsFile: nil,
+ perms: usermem.AnyAccess,
+ }), nil
+}
diff --git a/pkg/sentry/fs/ashmem/pin_board.go b/pkg/sentry/fs/ashmem/pin_board.go
new file mode 100644
index 000000000..bdf23b371
--- /dev/null
+++ b/pkg/sentry/fs/ashmem/pin_board.go
@@ -0,0 +1,127 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ashmem
+
+import "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+
+const maxUint64 = ^uint64(0)
+
+// setFunctions implements segment.Functions generated from segment.Functions for
+// uint64 Key and noValue Value. For more information, see the build file and
+// segment set implementation at pkg/segment/set.go.
+type setFunctions struct{}
+
+// noValue is a type of range attached value, which is irrelevant here.
+type noValue struct{}
+
+// MinKey implements segment.Functions.MinKey.
+func (setFunctions) MinKey() uint64 {
+ return 0
+}
+
+// MaxKey implements segment.Functions.MaxKey.
+func (setFunctions) MaxKey() uint64 {
+ return maxUint64
+}
+
+// ClearValue implements segment.Functions.ClearValue.
+func (setFunctions) ClearValue(*noValue) {
+ return
+}
+
+// Merge implements segment.Functions.Merge.
+func (setFunctions) Merge(Range, noValue, Range, noValue) (noValue, bool) {
+ return noValue{}, true
+}
+
+// Split implements segment.Functions.Split.
+func (setFunctions) Split(Range, noValue, uint64) (noValue, noValue) {
+ return noValue{}, noValue{}
+}
+
+// PinBoard represents a set of pinned ranges in ashmem.
+//
+// segment.Set is used for implementation where segments represent
+// ranges of pinned bytes, while gaps represent ranges of unpinned
+// bytes. All ranges are page-aligned.
+//
+// +stateify savable
+type PinBoard struct {
+ Set
+}
+
+// NewPinBoard creates a new pin board with all pages pinned.
+func NewPinBoard() *PinBoard {
+ var pb PinBoard
+ pb.PinRange(Range{0, maxUint64})
+ return &pb
+}
+
+// PinRange pins all pages in the specified range and returns true
+// if there are any newly pinned pages.
+func (pb *PinBoard) PinRange(r Range) bool {
+ pinnedPages := false
+ for gap := pb.LowerBoundGap(r.Start); gap.Ok() && gap.Start() < r.End; {
+ common := gap.Range().Intersect(r)
+ if common.Length() == 0 {
+ gap = gap.NextGap()
+ continue
+ }
+ pinnedPages = true
+ gap = pb.Insert(gap, common, noValue{}).NextGap()
+ }
+ return pinnedPages
+}
+
+// UnpinRange unpins all pages in the specified range.
+func (pb *PinBoard) UnpinRange(r Range) {
+ for seg := pb.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; {
+ common := seg.Range().Intersect(r)
+ if common.Length() == 0 {
+ seg = seg.NextSegment()
+ continue
+ }
+ seg = pb.RemoveRange(common).NextSegment()
+ }
+}
+
+// RangePinnedStatus returns false if there's at least one unpinned page in the
+// specified range.
+func (pb *PinBoard) RangePinnedStatus(r Range) bool {
+ for gap := pb.LowerBoundGap(r.Start); gap.Ok() && gap.Start() < r.End; {
+ common := gap.Range().Intersect(r)
+ if common.Length() == 0 {
+ gap = gap.NextGap()
+ continue
+ }
+ return false
+ }
+ return true
+}
+
+// RangeFromAshmemPin converts ashmem's original pin structure
+// to Range.
+func RangeFromAshmemPin(ap linux.AshmemPin) Range {
+ if ap.Len == 0 {
+ return Range{
+ uint64(ap.Offset),
+ maxUint64,
+ }
+ }
+ return Range{
+ uint64(ap.Offset),
+ uint64(ap.Offset) + uint64(ap.Len),
+ }
+}
diff --git a/pkg/sentry/fs/ashmem/uint64_range.go b/pkg/sentry/fs/ashmem/uint64_range.go
new file mode 100755
index 000000000..d71a10b16
--- /dev/null
+++ b/pkg/sentry/fs/ashmem/uint64_range.go
@@ -0,0 +1,62 @@
+package ashmem
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type Range struct {
+ // Start is the inclusive start of the range.
+ Start uint64
+
+ // End is the exclusive end of the range.
+ End uint64
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r Range) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r Range) Length() uint64 {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r Range) Contains(x uint64) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r Range) Overlaps(r2 Range) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r Range) IsSupersetOf(r2 Range) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+func (r Range) Intersect(r2 Range) Range {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+func (r Range) CanSplitAt(x uint64) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/sentry/fs/ashmem/uint64_set.go b/pkg/sentry/fs/ashmem/uint64_set.go
new file mode 100755
index 000000000..6e435325f
--- /dev/null
+++ b/pkg/sentry/fs/ashmem/uint64_set.go
@@ -0,0 +1,1270 @@
+package ashmem
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ minDegree = 3
+
+ maxDegree = 2 * minDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type Set struct {
+ root node `state:".(*SegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *Set) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *Set) IsEmptyRange(r Range) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *Set) Span() uint64 {
+ var sz uint64
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *Set) SpanRange(r Range) uint64 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint64
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *Set) FirstSegment() Iterator {
+ if s.root.nrSegments == 0 {
+ return Iterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *Set) LastSegment() Iterator {
+ if s.root.nrSegments == 0 {
+ return Iterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *Set) FirstGap() GapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return GapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *Set) LastGap() GapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return GapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *Set) Find(key uint64) (Iterator, GapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return Iterator{n, i}, GapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return Iterator{}, GapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *Set) FindSegment(key uint64) Iterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *Set) LowerBoundSegment(min uint64) Iterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *Set) UpperBoundSegment(max uint64) Iterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *Set) FindGap(key uint64) GapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *Set) LowerBoundGap(min uint64) GapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *Set) UpperBoundGap(max uint64) GapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *Set) Add(r Range, val noValue) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *Set) AddWithoutMerging(r Range, val noValue) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *Set) Insert(gap GapIterator, r Range, val noValue) Iterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (setFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (setFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (setFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val noValue) Iterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val noValue) Iterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return Iterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *Set) Remove(seg Iterator) GapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ setFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *Set) RemoveAll() {
+ s.root = node{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *Set) RemoveRange(r Range) GapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *Set) Merge(first, second Iterator) Iterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *Set) MergeUnchecked(first, second Iterator) Iterator {
+ if first.End() == second.Start() {
+ if mval, ok := (setFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return Iterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *Set) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *Set) MergeRange(r Range) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *Set) MergeAdjacent(r Range) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *Set) Split(seg Iterator, split uint64) (Iterator, Iterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *Set) SplitUnchecked(seg Iterator, split uint64) (Iterator, Iterator) {
+ val1, val2 := (setFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), Range{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *Set) SplitAt(split uint64) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *Set) Isolate(seg Iterator, r Range) Iterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *Set) ApplyContiguous(r Range, fn func(seg Iterator)) GapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return GapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return GapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type node struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *node
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [maxDegree - 1]Range
+ values [maxDegree - 1]noValue
+ children [maxDegree]*node
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *node) firstSegment() Iterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return Iterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *node) lastSegment() Iterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return Iterator{n, n.nrSegments - 1}
+}
+
+func (n *node) prevSibling() *node {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *node) nextSibling() *node {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < maxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &node{
+ nrSegments: minDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &node{
+ nrSegments: minDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:minDegree-1], n.keys[:minDegree-1])
+ copy(left.values[:minDegree-1], n.values[:minDegree-1])
+ copy(right.keys[:minDegree-1], n.keys[minDegree:])
+ copy(right.values[:minDegree-1], n.values[minDegree:])
+ n.keys[0], n.values[0] = n.keys[minDegree-1], n.values[minDegree-1]
+ zeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:minDegree], n.children[:minDegree])
+ copy(right.children[:minDegree], n.children[minDegree:])
+ zeroNodeSlice(n.children[2:])
+ for i := 0; i < minDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < minDegree {
+ return GapIterator{left, gap.index}
+ }
+ return GapIterator{right, gap.index - minDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[minDegree-1], n.values[minDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &node{
+ nrSegments: minDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:minDegree-1], n.keys[minDegree:])
+ copy(sibling.values[:minDegree-1], n.values[minDegree:])
+ zeroValueSlice(n.values[minDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:minDegree], n.children[minDegree:])
+ zeroNodeSlice(n.children[minDegree:])
+ for i := 0; i < minDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = minDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < minDegree {
+ return gap
+ }
+ return GapIterator{sibling, gap.index - minDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
+ for {
+ if n.nrSegments >= minDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= minDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ setFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return GapIterator{n, 0}
+ }
+ if gap.node == n {
+ return GapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= minDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ setFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return GapIterator{n, n.nrSegments}
+ }
+ return GapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return GapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return GapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *node
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = GapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ setFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type Iterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *node
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg Iterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg Iterator) Range() Range {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg Iterator) Start() uint64 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg Iterator) End() uint64 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg Iterator) SetRangeUnchecked(r Range) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg Iterator) SetRange(r Range) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg Iterator) SetStartUnchecked(start uint64) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg Iterator) SetStart(start uint64) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg Iterator) SetEndUnchecked(end uint64) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg Iterator) SetEnd(end uint64) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg Iterator) Value() noValue {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg Iterator) ValuePtr() *noValue {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg Iterator) SetValue(val noValue) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg Iterator) PrevSegment() Iterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return Iterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return Iterator{}
+ }
+ return segmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg Iterator) NextSegment() Iterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return Iterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return Iterator{}
+ }
+ return segmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg Iterator) PrevGap() GapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return GapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg Iterator) NextGap() GapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return GapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg Iterator) PrevNonEmpty() (Iterator, GapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return Iterator{}, gap
+ }
+ return gap.PrevSegment(), GapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg Iterator) NextNonEmpty() (Iterator, GapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return Iterator{}, gap
+ }
+ return gap.NextSegment(), GapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type GapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *node
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap GapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap GapIterator) Range() Range {
+ return Range{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap GapIterator) Start() uint64 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return setFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap GapIterator) End() uint64 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return setFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap GapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap GapIterator) PrevSegment() Iterator {
+ return segmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap GapIterator) NextSegment() Iterator {
+ return segmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap GapIterator) PrevGap() GapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return GapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap GapIterator) NextGap() GapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return GapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func segmentBeforePosition(n *node, i int) Iterator {
+ for i == 0 {
+ if n.parent == nil {
+ return Iterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return Iterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func segmentAfterPosition(n *node, i int) Iterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return Iterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return Iterator{n, i}
+}
+
+func zeroValueSlice(slice []noValue) {
+
+ for i := range slice {
+ setFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func zeroNodeSlice(slice []*node) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *Set) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *node) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type SegmentDataSlices struct {
+ Start []uint64
+ End []uint64
+ Values []noValue
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *Set) ExportSortedSlices() *SegmentDataSlices {
+ var sds SegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := Range{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *Set) saveRoot() *SegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *Set) loadRoot(sds *SegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go
new file mode 100644
index 000000000..591e35e6a
--- /dev/null
+++ b/pkg/sentry/fs/attr.go
@@ -0,0 +1,422 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "os"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+)
+
+// InodeType enumerates types of Inodes.
+type InodeType int
+
+const (
+ // RegularFile is a regular file.
+ RegularFile InodeType = iota
+
+ // SpecialFile is a file that doesn't support SeekEnd. It is used for
+ // things like proc files.
+ SpecialFile
+
+ // Directory is a directory.
+ Directory
+
+ // SpecialDirectory is a directory that *does* support SeekEnd. It's
+ // the opposite of the SpecialFile scenario above. It similarly
+ // supports proc files.
+ SpecialDirectory
+
+ // Symlink is a symbolic link.
+ Symlink
+
+ // Pipe is a pipe (named or regular).
+ Pipe
+
+ // Socket is a socket.
+ Socket
+
+ // CharacterDevice is a character device.
+ CharacterDevice
+
+ // BlockDevice is a block device.
+ BlockDevice
+
+ // Anonymous is an anonymous type when none of the above apply.
+ // Epoll fds and event-driven fds fit this category.
+ Anonymous
+)
+
+// String returns a human-readable representation of the InodeType.
+func (n InodeType) String() string {
+ switch n {
+ case RegularFile, SpecialFile:
+ return "file"
+ case Directory, SpecialDirectory:
+ return "directory"
+ case Symlink:
+ return "symlink"
+ case Pipe:
+ return "pipe"
+ case Socket:
+ return "socket"
+ case CharacterDevice:
+ return "character-device"
+ case BlockDevice:
+ return "block-device"
+ case Anonymous:
+ return "anonymous"
+ default:
+ return "unknown"
+ }
+}
+
+// StableAttr contains Inode attributes that will be stable throughout the
+// lifetime of the Inode.
+//
+// +stateify savable
+type StableAttr struct {
+ // Type is the InodeType of a InodeOperations.
+ Type InodeType
+
+ // DeviceID is the device on which a InodeOperations resides.
+ DeviceID uint64
+
+ // InodeID uniquely identifies InodeOperations on its device.
+ InodeID uint64
+
+ // BlockSize is the block size of data backing this InodeOperations.
+ BlockSize int64
+
+ // DeviceFileMajor is the major device number of this Node, if it is a
+ // device file.
+ DeviceFileMajor uint16
+
+ // DeviceFileMinor is the minor device number of this Node, if it is a
+ // device file.
+ DeviceFileMinor uint32
+}
+
+// IsRegular returns true if StableAttr.Type matches a regular file.
+func IsRegular(s StableAttr) bool {
+ return s.Type == RegularFile
+}
+
+// IsFile returns true if StableAttr.Type matches any type of file.
+func IsFile(s StableAttr) bool {
+ return s.Type == RegularFile || s.Type == SpecialFile
+}
+
+// IsDir returns true if StableAttr.Type matches any type of directory.
+func IsDir(s StableAttr) bool {
+ return s.Type == Directory || s.Type == SpecialDirectory
+}
+
+// IsSymlink returns true if StableAttr.Type matches a symlink.
+func IsSymlink(s StableAttr) bool {
+ return s.Type == Symlink
+}
+
+// IsPipe returns true if StableAttr.Type matches any type of pipe.
+func IsPipe(s StableAttr) bool {
+ return s.Type == Pipe
+}
+
+// IsSocket returns true if StableAttr.Type matches any type of socket.
+func IsSocket(s StableAttr) bool {
+ return s.Type == Socket
+}
+
+// IsCharDevice returns true if StableAttr.Type matches a character device.
+func IsCharDevice(s StableAttr) bool {
+ return s.Type == CharacterDevice
+}
+
+// UnstableAttr contains Inode attributes that may change over the lifetime
+// of the Inode.
+//
+// +stateify savable
+type UnstableAttr struct {
+ // Size is the file size in bytes.
+ Size int64
+
+ // Usage is the actual data usage in bytes.
+ Usage int64
+
+ // Perms is the protection (read/write/execute for user/group/other).
+ Perms FilePermissions
+
+ // Owner describes the ownership of this file.
+ Owner FileOwner
+
+ // AccessTime is the time of last access
+ AccessTime ktime.Time
+
+ // ModificationTime is the time of last modification.
+ ModificationTime ktime.Time
+
+ // StatusChangeTime is the time of last attribute modification.
+ StatusChangeTime ktime.Time
+
+ // Links is the number of hard links.
+ Links uint64
+}
+
+// SetOwner sets the owner and group if they are valid.
+//
+// This method is NOT thread-safe. Callers must prevent concurrent calls.
+func (ua *UnstableAttr) SetOwner(ctx context.Context, owner FileOwner) {
+ if owner.UID.Ok() {
+ ua.Owner.UID = owner.UID
+ }
+ if owner.GID.Ok() {
+ ua.Owner.GID = owner.GID
+ }
+ ua.StatusChangeTime = ktime.NowFromContext(ctx)
+}
+
+// SetPermissions sets the permissions.
+//
+// This method is NOT thread-safe. Callers must prevent concurrent calls.
+func (ua *UnstableAttr) SetPermissions(ctx context.Context, p FilePermissions) {
+ ua.Perms = p
+ ua.StatusChangeTime = ktime.NowFromContext(ctx)
+}
+
+// SetTimestamps sets the timestamps according to the TimeSpec.
+//
+// This method is NOT thread-safe. Callers must prevent concurrent calls.
+func (ua *UnstableAttr) SetTimestamps(ctx context.Context, ts TimeSpec) {
+ if ts.ATimeOmit && ts.MTimeOmit {
+ return
+ }
+
+ now := ktime.NowFromContext(ctx)
+ if !ts.ATimeOmit {
+ if ts.ATimeSetSystemTime {
+ ua.AccessTime = now
+ } else {
+ ua.AccessTime = ts.ATime
+ }
+ }
+ if !ts.MTimeOmit {
+ if ts.MTimeSetSystemTime {
+ ua.ModificationTime = now
+ } else {
+ ua.ModificationTime = ts.MTime
+ }
+ }
+ ua.StatusChangeTime = now
+}
+
+// WithCurrentTime returns u with AccessTime == ModificationTime == current time.
+func WithCurrentTime(ctx context.Context, u UnstableAttr) UnstableAttr {
+ t := ktime.NowFromContext(ctx)
+ u.AccessTime = t
+ u.ModificationTime = t
+ u.StatusChangeTime = t
+ return u
+}
+
+// AttrMask contains fields to mask StableAttr and UnstableAttr.
+//
+// +stateify savable
+type AttrMask struct {
+ Type bool
+ DeviceID bool
+ InodeID bool
+ BlockSize bool
+ Size bool
+ Usage bool
+ Perms bool
+ UID bool
+ GID bool
+ AccessTime bool
+ ModificationTime bool
+ StatusChangeTime bool
+ Links bool
+}
+
+// Empty returns true if all fields in AttrMask are false.
+func (a AttrMask) Empty() bool {
+ return a == AttrMask{}
+}
+
+// PermMask are file access permissions.
+//
+// +stateify savable
+type PermMask struct {
+ // Read indicates reading is permitted.
+ Read bool
+
+ // Write indicates writing is permitted.
+ Write bool
+
+ // Execute indicates execution is permitted.
+ Execute bool
+}
+
+// OnlyRead returns true when only the read bit is set.
+func (p PermMask) OnlyRead() bool {
+ return p.Read && !p.Write && !p.Execute
+}
+
+// String implements the fmt.Stringer interface for PermMask.
+func (p PermMask) String() string {
+ return fmt.Sprintf("PermMask{Read: %v, Write: %v, Execute: %v}", p.Read, p.Write, p.Execute)
+}
+
+// Mode returns the system mode (syscall.S_IXOTH, etc.) for these permissions
+// in the "other" bits.
+func (p PermMask) Mode() (mode os.FileMode) {
+ if p.Read {
+ mode |= syscall.S_IROTH
+ }
+ if p.Write {
+ mode |= syscall.S_IWOTH
+ }
+ if p.Execute {
+ mode |= syscall.S_IXOTH
+ }
+ return
+}
+
+// SupersetOf returns true iff the permissions in p are a superset of the
+// permissions in other.
+func (p PermMask) SupersetOf(other PermMask) bool {
+ if !p.Read && other.Read {
+ return false
+ }
+ if !p.Write && other.Write {
+ return false
+ }
+ if !p.Execute && other.Execute {
+ return false
+ }
+ return true
+}
+
+// FilePermissions represents the permissions of a file, with
+// Read/Write/Execute bits for user, group, and other.
+//
+// +stateify savable
+type FilePermissions struct {
+ User PermMask
+ Group PermMask
+ Other PermMask
+
+ // Sticky, if set on directories, restricts renaming and deletion of
+ // files in those directories to the directory owner, file owner, or
+ // CAP_FOWNER. The sticky bit is ignored when set on other files.
+ Sticky bool
+
+ // SetUID executables can call UID-setting syscalls without CAP_SETUID.
+ SetUID bool
+
+ // SetGID executables can call GID-setting syscalls without CAP_SETGID.
+ SetGID bool
+}
+
+// PermsFromMode takes the Other permissions (last 3 bits) of a FileMode and
+// returns a set of PermMask.
+func PermsFromMode(mode linux.FileMode) (perms PermMask) {
+ perms.Read = mode&linux.ModeOtherRead != 0
+ perms.Write = mode&linux.ModeOtherWrite != 0
+ perms.Execute = mode&linux.ModeOtherExec != 0
+ return
+}
+
+// FilePermsFromP9 converts a p9.FileMode to a FilePermissions struct.
+func FilePermsFromP9(mode p9.FileMode) FilePermissions {
+ return FilePermsFromMode(linux.FileMode(mode))
+}
+
+// FilePermsFromMode converts a system file mode to a FilePermissions struct.
+func FilePermsFromMode(mode linux.FileMode) (fp FilePermissions) {
+ perm := mode.Permissions()
+ fp.Other = PermsFromMode(perm)
+ fp.Group = PermsFromMode(perm >> 3)
+ fp.User = PermsFromMode(perm >> 6)
+ fp.Sticky = mode&linux.ModeSticky == linux.ModeSticky
+ fp.SetUID = mode&linux.ModeSetUID == linux.ModeSetUID
+ fp.SetGID = mode&linux.ModeSetGID == linux.ModeSetGID
+ return
+}
+
+// LinuxMode returns the linux mode_t representation of these permissions.
+func (f FilePermissions) LinuxMode() linux.FileMode {
+ m := linux.FileMode(f.User.Mode()<<6 | f.Group.Mode()<<3 | f.Other.Mode())
+ if f.SetUID {
+ m |= linux.ModeSetUID
+ }
+ if f.SetGID {
+ m |= linux.ModeSetGID
+ }
+ if f.Sticky {
+ m |= linux.ModeSticky
+ }
+ return m
+}
+
+// OSMode returns the Go runtime's OS independent os.FileMode representation of
+// these permissions.
+func (f FilePermissions) OSMode() os.FileMode {
+ m := os.FileMode(f.User.Mode()<<6 | f.Group.Mode()<<3 | f.Other.Mode())
+ if f.SetUID {
+ m |= os.ModeSetuid
+ }
+ if f.SetGID {
+ m |= os.ModeSetgid
+ }
+ if f.Sticky {
+ m |= os.ModeSticky
+ }
+ return m
+}
+
+// AnyExecute returns true if any of U/G/O have the execute bit set.
+func (f FilePermissions) AnyExecute() bool {
+ return f.User.Execute || f.Group.Execute || f.Other.Execute
+}
+
+// AnyWrite returns true if any of U/G/O have the write bit set.
+func (f FilePermissions) AnyWrite() bool {
+ return f.User.Write || f.Group.Write || f.Other.Write
+}
+
+// AnyRead returns true if any of U/G/O have the read bit set.
+func (f FilePermissions) AnyRead() bool {
+ return f.User.Read || f.Group.Read || f.Other.Read
+}
+
+// FileOwner represents ownership of a file.
+//
+// +stateify savable
+type FileOwner struct {
+ UID auth.KUID
+ GID auth.KGID
+}
+
+// RootOwner corresponds to KUID/KGID 0/0.
+var RootOwner = FileOwner{
+ UID: auth.RootKUID,
+ GID: auth.RootKGID,
+}
diff --git a/pkg/sentry/fs/binder/binder.go b/pkg/sentry/fs/binder/binder.go
new file mode 100644
index 000000000..c78f1fc40
--- /dev/null
+++ b/pkg/sentry/fs/binder/binder.go
@@ -0,0 +1,260 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package binder implements Android Binder IPC module.
+package binder
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ currentProtocolVersion = 8
+
+ // mmapSizeLimit is the upper limit for mapped memory size in Binder.
+ mmapSizeLimit = 4 * 1024 * 1024 // 4MB
+)
+
+// Device implements fs.InodeOperations.
+//
+// +stateify savable
+type Device struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*Device)(nil)
+
+// NewDevice creates and intializes a Device structure.
+func NewDevice(ctx context.Context, owner fs.FileOwner, fp fs.FilePermissions) *Device {
+ return &Device{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fp, 0),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+//
+// TODO(b/30946773): Add functionality to GetFile: Additional fields will be
+// needed in the Device structure, initialize them here. Also, Device will need
+// to keep track of the created Procs in order to implement BINDER_READ_WRITE
+// ioctl.
+func (bd *Device) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &Proc{
+ bd: bd,
+ task: kernel.TaskFromContext(ctx),
+ mfp: pgalloc.MemoryFileProviderFromContext(ctx),
+ }), nil
+}
+
+// Proc implements fs.FileOperations and fs.IoctlGetter.
+//
+// +stateify savable
+type Proc struct {
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ bd *Device
+ task *kernel.Task
+ mfp pgalloc.MemoryFileProvider
+
+ // mu protects fr.
+ mu sync.Mutex `state:"nosave"`
+
+ // mapped is memory allocated from mfp.MemoryFile() by AddMapping.
+ mapped platform.FileRange
+}
+
+// Release implements fs.FileOperations.Release.
+func (bp *Proc) Release() {
+ bp.mu.Lock()
+ defer bp.mu.Unlock()
+ if bp.mapped.Length() != 0 {
+ bp.mfp.MemoryFile().DecRef(bp.mapped)
+ }
+}
+
+// Seek implements fs.FileOperations.Seek.
+//
+// Binder doesn't support seek operation (unless in debug mode).
+func (bp *Proc) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ return offset, syserror.EOPNOTSUPP
+}
+
+// Read implements fs.FileOperations.Read.
+//
+// Binder doesn't support read operation (unless in debug mode).
+func (bp *Proc) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ return 0, syserror.EOPNOTSUPP
+}
+
+// Write implements fs.FileOperations.Write.
+//
+// Binder doesn't support write operation.
+func (bp *Proc) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ return 0, syserror.EOPNOTSUPP
+}
+
+// Flush implements fs.FileOperations.Flush.
+//
+// TODO(b/30946773): Implement.
+func (bp *Proc) Flush(ctx context.Context, file *fs.File) error {
+ return nil
+}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (bp *Proc) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ // Compare drivers/android/binder.c:binder_mmap().
+ if caller := kernel.TaskFromContext(ctx); caller != bp.task {
+ return syserror.EINVAL
+ }
+ if opts.Length > mmapSizeLimit {
+ opts.Length = mmapSizeLimit
+ }
+ opts.MaxPerms.Write = false
+
+ // TODO(b/30946773): Binder sets VM_DONTCOPY, preventing the created vma
+ // from being copied across fork(), but we don't support this yet. As
+ // a result, MMs containing a Binder mapping cannot be forked (MM.Fork will
+ // fail when AddMapping returns EBUSY).
+
+ return fsutil.GenericConfigureMMap(file, bp, opts)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+//
+// TODO(b/30946773): Implement.
+func (bp *Proc) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Switch on ioctl request.
+ switch uint32(args[1].Int()) {
+ case linux.BinderVersionIoctl:
+ ver := &linux.BinderVersion{
+ ProtocolVersion: currentProtocolVersion,
+ }
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ver, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.BinderWriteReadIoctl:
+ // TODO(b/30946773): Implement.
+ fallthrough
+ case linux.BinderSetIdleTimeoutIoctl:
+ // TODO(b/30946773): Implement.
+ fallthrough
+ case linux.BinderSetMaxThreadsIoctl:
+ // TODO(b/30946773): Implement.
+ fallthrough
+ case linux.BinderSetIdlePriorityIoctl:
+ // TODO(b/30946773): Implement.
+ fallthrough
+ case linux.BinderSetContextMgrIoctl:
+ // TODO(b/30946773): Implement.
+ fallthrough
+ case linux.BinderThreadExitIoctl:
+ // TODO(b/30946773): Implement.
+ return 0, syserror.ENOSYS
+ default:
+ // Ioctls irrelevant to Binder.
+ return 0, syserror.EINVAL
+ }
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (bp *Proc) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, _ bool) error {
+ bp.mu.Lock()
+ defer bp.mu.Unlock()
+ if bp.mapped.Length() != 0 {
+ // mmap has been called before, which binder_mmap() doesn't like.
+ return syserror.EBUSY
+ }
+ // Binder only allocates and maps a single page up-front
+ // (drivers/android/binder.c:binder_mmap() => binder_update_page_range()).
+ fr, err := bp.mfp.MemoryFile().Allocate(usermem.PageSize, usage.Anonymous)
+ if err != nil {
+ return err
+ }
+ bp.mapped = fr
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (*Proc) RemoveMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) {
+ // Nothing to do. Notably, we don't free bp.mapped to allow another mmap.
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (bp *Proc) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, _ bool) error {
+ // Nothing to do. Notably, this is one case where CopyMapping isn't
+ // equivalent to AddMapping, as AddMapping would return EBUSY.
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (bp *Proc) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ // TODO(b/30946773): In addition to the page initially allocated and mapped
+ // in AddMapping (Linux: binder_mmap), Binder allocates and maps pages for
+ // each transaction (Linux: binder_ioctl => binder_ioctl_write_read =>
+ // binder_thread_write => binder_transaction => binder_alloc_buf =>
+ // binder_update_page_range). Since we don't actually implement
+ // BinderWriteReadIoctl (Linux: BINDER_WRITE_READ), we only ever have the
+ // first page.
+ var err error
+ if required.End > usermem.PageSize {
+ err = &memmap.BusError{syserror.EFAULT}
+ }
+ if required.Start == 0 {
+ return []memmap.Translation{
+ {
+ Source: memmap.MappableRange{0, usermem.PageSize},
+ File: bp.mfp.MemoryFile(),
+ Offset: bp.mapped.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, err
+ }
+ return nil, err
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (bp *Proc) InvalidateUnsavable(ctx context.Context) error {
+ return nil
+}
diff --git a/pkg/sentry/fs/binder/binder_state_autogen.go b/pkg/sentry/fs/binder/binder_state_autogen.go
new file mode 100755
index 000000000..195d9e00b
--- /dev/null
+++ b/pkg/sentry/fs/binder/binder_state_autogen.go
@@ -0,0 +1,40 @@
+// automatically generated by stateify.
+
+package binder
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Device) beforeSave() {}
+func (x *Device) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *Device) afterLoad() {}
+func (x *Device) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *Proc) beforeSave() {}
+func (x *Proc) save(m state.Map) {
+ x.beforeSave()
+ m.Save("bd", &x.bd)
+ m.Save("task", &x.task)
+ m.Save("mfp", &x.mfp)
+ m.Save("mapped", &x.mapped)
+}
+
+func (x *Proc) afterLoad() {}
+func (x *Proc) load(m state.Map) {
+ m.Load("bd", &x.bd)
+ m.Load("task", &x.task)
+ m.Load("mfp", &x.mfp)
+ m.Load("mapped", &x.mapped)
+}
+
+func init() {
+ state.Register("binder.Device", (*Device)(nil), state.Fns{Save: (*Device).save, Load: (*Device).load})
+ state.Register("binder.Proc", (*Proc)(nil), state.Fns{Save: (*Proc).save, Load: (*Proc).load})
+}
diff --git a/pkg/sentry/fs/context.go b/pkg/sentry/fs/context.go
new file mode 100644
index 000000000..c80ea0175
--- /dev/null
+++ b/pkg/sentry/fs/context.go
@@ -0,0 +1,114 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+)
+
+// contextID is the fs package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxRoot is a Context.Value key for a Dirent.
+ CtxRoot contextID = iota
+
+ // CtxDirentCacheLimiter is a Context.Value key for DirentCacheLimiter.
+ CtxDirentCacheLimiter
+)
+
+// ContextCanAccessFile determines whether `file` can be accessed in the requested way
+// (for reading, writing, or execution) using the caller's credentials and user
+// namespace, as does Linux's fs/namei.c:generic_permission.
+func ContextCanAccessFile(ctx context.Context, inode *Inode, reqPerms PermMask) bool {
+ creds := auth.CredentialsFromContext(ctx)
+ uattr, err := inode.UnstableAttr(ctx)
+ if err != nil {
+ return false
+ }
+
+ p := uattr.Perms.Other
+ // Are we owner or in group?
+ if uattr.Owner.UID == creds.EffectiveKUID {
+ p = uattr.Perms.User
+ } else if creds.InGroup(uattr.Owner.GID) {
+ p = uattr.Perms.Group
+ }
+
+ // Do not allow programs to be executed if MS_NOEXEC is set.
+ if IsFile(inode.StableAttr) && reqPerms.Execute && inode.MountSource.Flags.NoExec {
+ return false
+ }
+
+ // Are permissions satisfied without capability checks?
+ if p.SupersetOf(reqPerms) {
+ return true
+ }
+
+ if IsDir(inode.StableAttr) {
+ // CAP_DAC_OVERRIDE can override any perms on directories.
+ if inode.CheckCapability(ctx, linux.CAP_DAC_OVERRIDE) {
+ return true
+ }
+
+ // CAP_DAC_READ_SEARCH can normally only override Read perms,
+ // but for directories it can also override execution.
+ if !reqPerms.Write && inode.CheckCapability(ctx, linux.CAP_DAC_READ_SEARCH) {
+ return true
+ }
+ }
+
+ // CAP_DAC_OVERRIDE can always override Read/Write.
+ // Can override executable only when at least one execute bit is set.
+ if !reqPerms.Execute || uattr.Perms.AnyExecute() {
+ if inode.CheckCapability(ctx, linux.CAP_DAC_OVERRIDE) {
+ return true
+ }
+ }
+
+ // Read perms can be overridden by CAP_DAC_READ_SEARCH.
+ if reqPerms.OnlyRead() && inode.CheckCapability(ctx, linux.CAP_DAC_READ_SEARCH) {
+ return true
+ }
+ return false
+}
+
+// FileOwnerFromContext returns a FileOwner using the effective user and group
+// IDs used by ctx.
+func FileOwnerFromContext(ctx context.Context) FileOwner {
+ creds := auth.CredentialsFromContext(ctx)
+ return FileOwner{creds.EffectiveKUID, creds.EffectiveKGID}
+}
+
+// RootFromContext returns the root of the virtual filesystem observed by ctx,
+// or nil if ctx is not associated with a virtual filesystem. If
+// RootFromContext returns a non-nil fs.Dirent, a reference is taken on it.
+func RootFromContext(ctx context.Context) *Dirent {
+ if v := ctx.Value(CtxRoot); v != nil {
+ return v.(*Dirent)
+ }
+ return nil
+}
+
+// DirentCacheLimiterFromContext returns the DirentCacheLimiter used by ctx, or
+// nil if ctx does not have a dirent cache limiter.
+func DirentCacheLimiterFromContext(ctx context.Context) *DirentCacheLimiter {
+ if v := ctx.Value(CtxDirentCacheLimiter); v != nil {
+ return v.(*DirentCacheLimiter)
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
new file mode 100644
index 000000000..41265704c
--- /dev/null
+++ b/pkg/sentry/fs/copy_up.go
@@ -0,0 +1,433 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "io"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// copyUp copies a file in an overlay from a lower filesystem to an
+// upper filesytem so that the file can be modified in the upper
+// filesystem. Copying a file involves several steps:
+//
+// - All parent directories of the file are created in the upper
+// filesystem if they don't exist there. For instance:
+//
+// upper /dir0
+// lower /dir0/dir1/file
+//
+// copyUp of /dir0/dir1/file creates /dir0/dir1 in order to create
+// /dir0/dir1/file.
+//
+// - The file content is copied from the lower file to the upper
+// file. For symlinks this is the symlink target. For directories,
+// upper directory entries are merged with lower directory entries
+// so there is no need to copy any entries.
+//
+// - A subset of file attributes of the lower file are set on the
+// upper file. These are the file owner, the file timestamps,
+// and all non-overlay extended attributes. copyUp will fail if
+// the upper filesystem does not support the setting of these
+// attributes.
+//
+// The file's permissions are set when the file is created and its
+// size will be brought up to date when its contents are copied.
+// Notably no attempt is made to bring link count up to date because
+// hard links are currently not preserved across overlay filesystems.
+//
+// - Memory mappings of the lower file are invalidated and memory
+// references are transferred to the upper file. From this point on,
+// memory mappings of the file will be backed by content in the upper
+// filesystem.
+//
+// Synchronization:
+//
+// copyUp synchronizes with rename(2) using renameMu to ensure that
+// parentage does not change while a file is being copied. In the context
+// of rename(2), copyUpLockedForRename should be used to avoid deadlock on
+// renameMu.
+//
+// The following operations synchronize with copyUp using copyMu:
+//
+// - InodeOperations, i.e. to ensure that looking up a directory takes
+// into account new upper filesystem directories created by copy up,
+// which subsequently can be modified.
+//
+// - FileOperations, i.e. to ensure that reading from a file does not
+// continue using a stale, lower filesystem handle when the file is
+// written to.
+//
+// Lock ordering: Dirent.mu -> Inode.overlay.copyMu -> Inode.mu.
+//
+// Caveats:
+//
+// If any step in copying up a file fails, copyUp cleans the upper
+// filesystem of any partially up-to-date file. If this cleanup fails,
+// the overlay may be in an unacceptable, inconsistent state, so copyUp
+// panics. If copyUp fails because any step (above) fails, a generic
+// error is returned.
+//
+// copyUp currently makes no attempt to optimize copying up file content.
+// For large files, this means that copyUp blocks until the entire file
+// is copied synchronously.
+func copyUp(ctx context.Context, d *Dirent) error {
+ renameMu.RLock()
+ defer renameMu.RUnlock()
+ return copyUpLockedForRename(ctx, d)
+}
+
+// copyUpLockedForRename is the same as copyUp except that it does not lock
+// renameMu.
+//
+// It copies each component of d that does not yet exist in the upper
+// filesystem. If d already exists in the upper filesystem, it is a no-op.
+//
+// Any error returned indicates a failure to copy all of d. This may
+// leave the upper filesystem filled with any number of parent directories
+// but the upper filesystem will never be in an inconsistent state.
+//
+// Preconditions:
+// - d.Inode.overlay is non-nil.
+func copyUpLockedForRename(ctx context.Context, d *Dirent) error {
+ for {
+ // Did we race with another copy up or does there
+ // already exist something in the upper filesystem
+ // for d?
+ d.Inode.overlay.copyMu.RLock()
+ if d.Inode.overlay.upper != nil {
+ d.Inode.overlay.copyMu.RUnlock()
+ // Done, d is in the upper filesystem.
+ return nil
+ }
+ d.Inode.overlay.copyMu.RUnlock()
+
+ // Find the next component to copy up. We will work our way
+ // down to the last component of d and finally copy it.
+ next := findNextCopyUp(ctx, d)
+
+ // Attempt to copy.
+ if err := doCopyUp(ctx, next); err != nil {
+ return err
+ }
+ }
+}
+
+// findNextCopyUp finds the next component of d from root that does not
+// yet exist in the upper filesystem. The parent of this component is
+// also returned, which is the root of the overlay in the worst case.
+func findNextCopyUp(ctx context.Context, d *Dirent) *Dirent {
+ next := d
+ for parent := next.parent; ; /* checked in-loop */ /* updated in-loop */ {
+ // Does this parent have a non-nil upper Inode?
+ parent.Inode.overlay.copyMu.RLock()
+ if parent.Inode.overlay.upper != nil {
+ parent.Inode.overlay.copyMu.RUnlock()
+ // Note that since we found an upper, it is stable.
+ return next
+ }
+ parent.Inode.overlay.copyMu.RUnlock()
+
+ // Continue searching for a parent with a non-nil
+ // upper Inode.
+ next = parent
+ parent = next.parent
+ }
+}
+
+func doCopyUp(ctx context.Context, d *Dirent) error {
+ // Fail fast on Inode types we won't be able to copy up anyways. These
+ // Inodes may block in GetFile while holding copyMu for reading. If we
+ // then try to take copyMu for writing here, we'd deadlock.
+ t := d.Inode.overlay.lower.StableAttr.Type
+ if t != RegularFile && t != Directory && t != Symlink {
+ return syserror.EINVAL
+ }
+
+ // Wait to get exclusive access to the upper Inode.
+ d.Inode.overlay.copyMu.Lock()
+ defer d.Inode.overlay.copyMu.Unlock()
+ if d.Inode.overlay.upper != nil {
+ // We raced with another doCopyUp, no problem.
+ return nil
+ }
+
+ // Perform the copy.
+ return copyUpLocked(ctx, d.parent, d)
+}
+
+// copyUpLocked creates a copy of next in the upper filesystem of parent.
+//
+// copyUpLocked must be called with d.Inode.overlay.copyMu locked.
+//
+// Returns a generic error on failure.
+//
+// Preconditions:
+// - parent.Inode.overlay.upper must be non-nil.
+// - next.Inode.overlay.copyMu must be locked writable.
+// - next.Inode.overlay.lower must be non-nil.
+// - next.Inode.overlay.lower.StableAttr.Type must be RegularFile, Directory,
+// or Symlink.
+// - upper filesystem must support setting file ownership and timestamps.
+func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
+ // Extract the attributes of the file we wish to copy.
+ attrs, err := next.Inode.overlay.lower.UnstableAttr(ctx)
+ if err != nil {
+ log.Warningf("copy up failed to get lower attributes: %v", err)
+ return syserror.EIO
+ }
+
+ var childUpperInode *Inode
+ parentUpper := parent.Inode.overlay.upper
+ root := RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+
+ // Create the file in the upper filesystem and get an Inode for it.
+ switch next.Inode.StableAttr.Type {
+ case RegularFile:
+ childFile, err := parentUpper.Create(ctx, root, next.name, FileFlags{Read: true, Write: true}, attrs.Perms)
+ if err != nil {
+ log.Warningf("copy up failed to create file: %v", err)
+ return syserror.EIO
+ }
+ defer childFile.DecRef()
+ childUpperInode = childFile.Dirent.Inode
+
+ case Directory:
+ if err := parentUpper.CreateDirectory(ctx, root, next.name, attrs.Perms); err != nil {
+ log.Warningf("copy up failed to create directory: %v", err)
+ return syserror.EIO
+ }
+ childUpper, err := parentUpper.Lookup(ctx, next.name)
+ if err != nil {
+ log.Warningf("copy up failed to lookup directory: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name)
+ return syserror.EIO
+ }
+ defer childUpper.DecRef()
+ childUpperInode = childUpper.Inode
+
+ case Symlink:
+ childLower := next.Inode.overlay.lower
+ link, err := childLower.Readlink(ctx)
+ if err != nil {
+ log.Warningf("copy up failed to read symlink value: %v", err)
+ return syserror.EIO
+ }
+ if err := parentUpper.CreateLink(ctx, root, link, next.name); err != nil {
+ log.Warningf("copy up failed to create symlink: %v", err)
+ return syserror.EIO
+ }
+ childUpper, err := parentUpper.Lookup(ctx, next.name)
+ if err != nil {
+ log.Warningf("copy up failed to lookup symlink: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name)
+ return syserror.EIO
+ }
+ defer childUpper.DecRef()
+ childUpperInode = childUpper.Inode
+
+ default:
+ panic(fmt.Sprintf("copy up of invalid type %v on %+v", next.Inode.StableAttr.Type, next))
+ }
+
+ // Bring file attributes up to date. This does not include size, which will be
+ // brought up to date with copyContentsLocked.
+ if err := copyAttributesLocked(ctx, childUpperInode, next.Inode.overlay.lower); err != nil {
+ log.Warningf("copy up failed to copy up attributes: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name)
+ return syserror.EIO
+ }
+
+ // Copy the entire file.
+ if err := copyContentsLocked(ctx, childUpperInode, next.Inode.overlay.lower, attrs.Size); err != nil {
+ log.Warningf("copy up failed to copy up contents: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name)
+ return syserror.EIO
+ }
+
+ lowerMappable := next.Inode.overlay.lower.Mappable()
+ upperMappable := childUpperInode.Mappable()
+ if lowerMappable != nil && upperMappable == nil {
+ log.Warningf("copy up failed: cannot ensure memory mapping coherence")
+ cleanupUpper(ctx, parentUpper, next.name)
+ return syserror.EIO
+ }
+
+ // Propagate memory mappings to the upper Inode.
+ next.Inode.overlay.mapsMu.Lock()
+ defer next.Inode.overlay.mapsMu.Unlock()
+ if upperMappable != nil {
+ // Remember which mappings we added so we can remove them on failure.
+ allAdded := make(map[memmap.MappableRange]memmap.MappingsOfRange)
+ for seg := next.Inode.overlay.mappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ added := make(memmap.MappingsOfRange)
+ for m := range seg.Value() {
+ if err := upperMappable.AddMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable); err != nil {
+ for m := range added {
+ upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable)
+ }
+ for mr, mappings := range allAdded {
+ for m := range mappings {
+ upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, mr.Start, m.Writable)
+ }
+ }
+ return err
+ }
+ added[m] = struct{}{}
+ }
+ allAdded[seg.Range()] = added
+ }
+ }
+
+ // Take a reference on the upper Inode (transferred to
+ // next.Inode.overlay.upper) and make new translations use it.
+ next.Inode.overlay.dataMu.Lock()
+ childUpperInode.IncRef()
+ next.Inode.overlay.upper = childUpperInode
+ next.Inode.overlay.dataMu.Unlock()
+
+ // Invalidate existing translations through the lower Inode.
+ next.Inode.overlay.mappings.InvalidateAll(memmap.InvalidateOpts{})
+
+ // Remove existing memory mappings from the lower Inode.
+ if lowerMappable != nil {
+ for seg := next.Inode.overlay.mappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ for m := range seg.Value() {
+ lowerMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable)
+ }
+ }
+ }
+
+ return nil
+}
+
+// cleanupUpper removes name from parent, and panics if it is unsuccessful.
+func cleanupUpper(ctx context.Context, parent *Inode, name string) {
+ if err := parent.InodeOperations.Remove(ctx, parent, name); err != nil {
+ // Unfortunately we don't have much choice. We shouldn't
+ // willingly give the caller access to a nonsense filesystem.
+ panic(fmt.Sprintf("overlay filesystem is in an inconsistent state: failed to remove %q from upper filesystem: %v", name, err))
+ }
+}
+
+// copyUpBuffers is a buffer pool for copying file content. The buffer
+// size is the same used by io.Copy.
+var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }}
+
+// copyContentsLocked copies the contents of lower to upper. It panics if
+// less than size bytes can be copied.
+func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size int64) error {
+ // We don't support copying up for anything other than regular files.
+ if lower.StableAttr.Type != RegularFile {
+ return nil
+ }
+
+ // Get a handle to the upper filesystem, which we will write to.
+ upperFile, err := overlayFile(ctx, upper, FileFlags{Write: true})
+ if err != nil {
+ return err
+ }
+ defer upperFile.DecRef()
+
+ // Get a handle to the lower filesystem, which we will read from.
+ lowerFile, err := overlayFile(ctx, lower, FileFlags{Read: true})
+ if err != nil {
+ return err
+ }
+ defer lowerFile.DecRef()
+
+ // Use a buffer pool to minimize allocations.
+ buf := copyUpBuffers.Get().([]byte)
+ defer copyUpBuffers.Put(buf)
+
+ // Transfer the contents.
+ //
+ // One might be able to optimize this by doing parallel reads, parallel writes and reads, larger
+ // buffers, etc. But we really don't know anything about the underlying implementation, so these
+ // optimizations could be self-defeating. So we leave this as simple as possible.
+ var offset int64
+ for {
+ nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset)
+ if err != nil && err != io.EOF {
+ return err
+ }
+ if nr == 0 {
+ if offset != size {
+ // Same as in cleanupUpper, we cannot live
+ // with ourselves if we do anything less.
+ panic(fmt.Sprintf("filesystem is in an inconsistent state: wrote only %d bytes of %d sized file", offset, size))
+ }
+ return nil
+ }
+ nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset)
+ if err != nil {
+ return err
+ }
+ offset += nw
+ }
+}
+
+// copyAttributesLocked copies a subset of lower's attributes to upper,
+// specifically owner, timestamps (except of status change time), and
+// extended attributes. Notably no attempt is made to copy link count.
+// Size and permissions are set on upper when the file content is copied
+// and when the file is created respectively.
+func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error {
+ // Extract attributes fro the lower filesystem.
+ lowerAttr, err := lower.UnstableAttr(ctx)
+ if err != nil {
+ return err
+ }
+ lowerXattr, err := lower.Listxattr()
+ if err != nil && err != syserror.EOPNOTSUPP {
+ return err
+ }
+
+ // Set the attributes on the upper filesystem.
+ if err := upper.InodeOperations.SetOwner(ctx, upper, lowerAttr.Owner); err != nil {
+ return err
+ }
+ if err := upper.InodeOperations.SetTimestamps(ctx, upper, TimeSpec{
+ ATime: lowerAttr.AccessTime,
+ MTime: lowerAttr.ModificationTime,
+ }); err != nil {
+ return err
+ }
+ for name := range lowerXattr {
+ // Don't copy-up attributes that configure an overlay in the
+ // lower.
+ if isXattrOverlay(name) {
+ continue
+ }
+ value, err := lower.Getxattr(name)
+ if err != nil {
+ return err
+ }
+ if err := upper.InodeOperations.Setxattr(upper, name, value); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/dentry.go b/pkg/sentry/fs/dentry.go
new file mode 100644
index 000000000..7a2d4b180
--- /dev/null
+++ b/pkg/sentry/fs/dentry.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "sort"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+)
+
+// DentAttr is the metadata of a directory entry. It is a subset of StableAttr.
+//
+// +stateify savable
+type DentAttr struct {
+ // Type is the InodeType of an Inode.
+ Type InodeType
+
+ // InodeID uniquely identifies an Inode on a device.
+ InodeID uint64
+}
+
+// GenericDentAttr returns a generic DentAttr where:
+//
+// Type == nt
+// InodeID == the inode id of a new inode on device.
+func GenericDentAttr(nt InodeType, device *device.Device) DentAttr {
+ return DentAttr{
+ Type: nt,
+ InodeID: device.NextIno(),
+ }
+}
+
+// DentrySerializer serializes a directory entry.
+type DentrySerializer interface {
+ // CopyOut serializes a directory entry based on its name and attributes.
+ CopyOut(name string, attributes DentAttr) error
+
+ // Written returns the number of bytes written.
+ Written() int
+}
+
+// CollectEntriesSerializer copies DentAttrs to Entries. The order in
+// which entries are encountered is preserved in Order.
+type CollectEntriesSerializer struct {
+ Entries map[string]DentAttr
+ Order []string
+}
+
+// CopyOut implements DentrySerializer.CopyOut.
+func (c *CollectEntriesSerializer) CopyOut(name string, attr DentAttr) error {
+ if c.Entries == nil {
+ c.Entries = make(map[string]DentAttr)
+ }
+ c.Entries[name] = attr
+ c.Order = append(c.Order, name)
+ return nil
+}
+
+// Written implements DentrySerializer.Written.
+func (c *CollectEntriesSerializer) Written() int {
+ return len(c.Entries)
+}
+
+// DirCtx is used in FileOperations.IterateDir to emit directory entries. It is
+// not thread-safe.
+type DirCtx struct {
+ // Serializer is used to serialize the node attributes.
+ Serializer DentrySerializer
+
+ // attrs are DentAttrs
+ attrs map[string]DentAttr
+
+ // DirCursor is the directory cursor.
+ DirCursor *string
+}
+
+// DirEmit is called for each directory entry.
+func (c *DirCtx) DirEmit(name string, attr DentAttr) error {
+ if c.Serializer != nil {
+ if err := c.Serializer.CopyOut(name, attr); err != nil {
+ return err
+ }
+ }
+ if c.attrs == nil {
+ c.attrs = make(map[string]DentAttr)
+ }
+ c.attrs[name] = attr
+ return nil
+}
+
+// DentAttrs returns a map of DentAttrs corresponding to the emitted directory
+// entries.
+func (c *DirCtx) DentAttrs() map[string]DentAttr {
+ if c.attrs == nil {
+ c.attrs = make(map[string]DentAttr)
+ }
+ return c.attrs
+}
+
+// GenericReaddir serializes DentAttrs based on a SortedDentryMap that must
+// contain _all_ up-to-date DentAttrs under a directory. If ctx.DirCursor is
+// not nil, it is updated to the name of the last DentAttr that was
+// successfully serialized.
+//
+// Returns the number of entries serialized.
+func GenericReaddir(ctx *DirCtx, s *SortedDentryMap) (int, error) {
+ // Retrieve the next directory entries.
+ var names []string
+ var entries map[string]DentAttr
+ if ctx.DirCursor != nil {
+ names, entries = s.GetNext(*ctx.DirCursor)
+ } else {
+ names, entries = s.GetAll()
+ }
+
+ // Try to serialize each entry.
+ var serialized int
+ for _, name := range names {
+ // Skip "" per POSIX. Skip "." and ".." which will be added by Dirent.Readdir.
+ if name == "" || name == "." || name == ".." {
+ continue
+ }
+
+ // Emit the directory entry.
+ if err := ctx.DirEmit(name, entries[name]); err != nil {
+ // Return potentially a partial serialized count.
+ return serialized, err
+ }
+
+ // We successfully serialized this entry.
+ serialized++
+
+ // Update the cursor with the name of the entry last serialized.
+ if ctx.DirCursor != nil {
+ *ctx.DirCursor = name
+ }
+ }
+
+ // Everything was serialized.
+ return serialized, nil
+}
+
+// SortedDentryMap is a sorted map of names and fs.DentAttr entries.
+//
+// +stateify savable
+type SortedDentryMap struct {
+ // names is always kept in sorted-order.
+ names []string
+
+ // entries maps names to fs.DentAttrs.
+ entries map[string]DentAttr
+}
+
+// NewSortedDentryMap maintains entries in name sorted order.
+func NewSortedDentryMap(entries map[string]DentAttr) *SortedDentryMap {
+ s := &SortedDentryMap{
+ names: make([]string, 0, len(entries)),
+ entries: entries,
+ }
+ // Don't allow s.entries to be nil, because nil maps arn't Saveable.
+ if s.entries == nil {
+ s.entries = make(map[string]DentAttr)
+ }
+
+ // Collect names from entries and sort them.
+ for name := range s.entries {
+ s.names = append(s.names, name)
+ }
+ sort.Strings(s.names)
+ return s
+}
+
+// GetAll returns all names and entries in s. Callers should not modify the
+// returned values.
+func (s *SortedDentryMap) GetAll() ([]string, map[string]DentAttr) {
+ return s.names, s.entries
+}
+
+// GetNext returns names after cursor in s and all entries.
+func (s *SortedDentryMap) GetNext(cursor string) ([]string, map[string]DentAttr) {
+ i := sort.SearchStrings(s.names, cursor)
+ if i == len(s.names) {
+ return nil, s.entries
+ }
+
+ // Return everything strictly after the cursor.
+ if s.names[i] == cursor {
+ i++
+ }
+ return s.names[i:], s.entries
+}
+
+// Add adds an entry with the given name to the map, preserving sort order. If
+// name already exists in the map, its entry will be overwritten.
+func (s *SortedDentryMap) Add(name string, entry DentAttr) {
+ if _, ok := s.entries[name]; !ok {
+ // Map does not yet contain an entry with this name. We must
+ // insert it in s.names at the appropriate spot.
+ i := sort.SearchStrings(s.names, name)
+ s.names = append(s.names, "")
+ copy(s.names[i+1:], s.names[i:])
+ s.names[i] = name
+ }
+ s.entries[name] = entry
+}
+
+// Remove removes an entry with the given name from the map, preserving sort order.
+func (s *SortedDentryMap) Remove(name string) {
+ if _, ok := s.entries[name]; !ok {
+ return
+ }
+ i := sort.SearchStrings(s.names, name)
+ copy(s.names[i:], s.names[i+1:])
+ s.names = s.names[:len(s.names)-1]
+ delete(s.entries, name)
+}
+
+// Contains reports whether the map contains an entry with the given name.
+func (s *SortedDentryMap) Contains(name string) bool {
+ _, ok := s.entries[name]
+ return ok
+}
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
new file mode 100644
index 000000000..34ac01173
--- /dev/null
+++ b/pkg/sentry/fs/dev/dev.go
@@ -0,0 +1,146 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package dev provides a filesystem with simple devices.
+package dev
+
+import (
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ashmem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/binder"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// Memory device numbers are from Linux's drivers/char/mem.c
+const (
+ // Mem device major.
+ memDevMajor uint16 = 1
+
+ // Mem device minors.
+ nullDevMinor uint32 = 3
+ zeroDevMinor uint32 = 5
+ fullDevMinor uint32 = 7
+ randomDevMinor uint32 = 8
+ urandomDevMinor uint32 = 9
+)
+
+func newCharacterDevice(iops fs.InodeOperations, msrc *fs.MountSource) *fs.Inode {
+ return fs.NewInode(iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.CharacterDevice,
+ })
+}
+
+func newMemDevice(iops fs.InodeOperations, msrc *fs.MountSource, minor uint32) *fs.Inode {
+ return fs.NewInode(iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.CharacterDevice,
+ DeviceFileMajor: memDevMajor,
+ DeviceFileMinor: minor,
+ })
+}
+
+func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return fs.NewInode(iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
+
+func newSymlink(ctx context.Context, target string, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewSymlink(ctx, fs.RootOwner, target)
+ return fs.NewInode(iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Symlink,
+ })
+}
+
+// New returns the root node of a device filesystem.
+func New(ctx context.Context, msrc *fs.MountSource, binderEnabled bool, ashmemEnabled bool) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "fd": newSymlink(ctx, "/proc/self/fd", msrc),
+ "stdin": newSymlink(ctx, "/proc/self/fd/0", msrc),
+ "stdout": newSymlink(ctx, "/proc/self/fd/1", msrc),
+ "stderr": newSymlink(ctx, "/proc/self/fd/2", msrc),
+
+ "null": newMemDevice(newNullDevice(ctx, fs.RootOwner, 0666), msrc, nullDevMinor),
+ "zero": newMemDevice(newZeroDevice(ctx, fs.RootOwner, 0666), msrc, zeroDevMinor),
+ "full": newMemDevice(newFullDevice(ctx, fs.RootOwner, 0666), msrc, fullDevMinor),
+
+ // This is not as good as /dev/random in linux because go
+ // runtime uses sys_random and /dev/urandom internally.
+ // According to 'man 4 random', this will be sufficient unless
+ // application uses this to generate long-lived GPG/SSL/SSH
+ // keys.
+ "random": newMemDevice(newRandomDevice(ctx, fs.RootOwner, 0444), msrc, randomDevMinor),
+ "urandom": newMemDevice(newRandomDevice(ctx, fs.RootOwner, 0444), msrc, urandomDevMinor),
+
+ "shm": tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc),
+
+ // A devpts is typically mounted at /dev/pts to provide
+ // pseudoterminal support. Place an empty directory there for
+ // the devpts to be mounted over.
+ "pts": newDirectory(ctx, msrc),
+ // Similarly, applications expect a ptmx device at /dev/ptmx
+ // connected to the terminals provided by /dev/pts/. Rather
+ // than creating a device directly (which requires a hairy
+ // lookup on open to determine if a devpts exists), just create
+ // a symlink to the ptmx provided by devpts. (The Linux devpts
+ // documentation recommends this).
+ //
+ // If no devpts is mounted, this will simply be a dangling
+ // symlink, which is fine.
+ "ptmx": newSymlink(ctx, "pts/ptmx", msrc),
+ }
+
+ if binderEnabled {
+ binder := binder.NewDevice(ctx, fs.RootOwner, fs.FilePermsFromMode(0666))
+ contents["binder"] = newCharacterDevice(binder, msrc)
+ }
+
+ if ashmemEnabled {
+ ashmem := ashmem.NewDevice(ctx, fs.RootOwner, fs.FilePermsFromMode(0666))
+ contents["ashmem"] = newCharacterDevice(ashmem, msrc)
+ }
+
+ iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return fs.NewInode(iops, msrc, fs.StableAttr{
+ DeviceID: devDevice.DeviceID(),
+ InodeID: devDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
+
+// readZeros implements fs.FileOperations.Read with infinite null bytes.
+type readZeros struct{}
+
+// Read implements fs.FileOperations.Read.
+func (*readZeros) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ return dst.ZeroOut(ctx, math.MaxInt64)
+}
diff --git a/pkg/sentry/fs/dev/dev_state_autogen.go b/pkg/sentry/fs/dev/dev_state_autogen.go
new file mode 100755
index 000000000..afe94ff86
--- /dev/null
+++ b/pkg/sentry/fs/dev/dev_state_autogen.go
@@ -0,0 +1,108 @@
+// automatically generated by stateify.
+
+package dev
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *filesystem) beforeSave() {}
+func (x *filesystem) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *filesystem) afterLoad() {}
+func (x *filesystem) load(m state.Map) {
+}
+
+func (x *fullDevice) beforeSave() {}
+func (x *fullDevice) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *fullDevice) afterLoad() {}
+func (x *fullDevice) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *fullFileOperations) beforeSave() {}
+func (x *fullFileOperations) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *fullFileOperations) afterLoad() {}
+func (x *fullFileOperations) load(m state.Map) {
+}
+
+func (x *nullDevice) beforeSave() {}
+func (x *nullDevice) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *nullDevice) afterLoad() {}
+func (x *nullDevice) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *nullFileOperations) beforeSave() {}
+func (x *nullFileOperations) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *nullFileOperations) afterLoad() {}
+func (x *nullFileOperations) load(m state.Map) {
+}
+
+func (x *zeroDevice) beforeSave() {}
+func (x *zeroDevice) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nullDevice", &x.nullDevice)
+}
+
+func (x *zeroDevice) afterLoad() {}
+func (x *zeroDevice) load(m state.Map) {
+ m.Load("nullDevice", &x.nullDevice)
+}
+
+func (x *zeroFileOperations) beforeSave() {}
+func (x *zeroFileOperations) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *zeroFileOperations) afterLoad() {}
+func (x *zeroFileOperations) load(m state.Map) {
+}
+
+func (x *randomDevice) beforeSave() {}
+func (x *randomDevice) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *randomDevice) afterLoad() {}
+func (x *randomDevice) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *randomFileOperations) beforeSave() {}
+func (x *randomFileOperations) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *randomFileOperations) afterLoad() {}
+func (x *randomFileOperations) load(m state.Map) {
+}
+
+func init() {
+ state.Register("dev.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load})
+ state.Register("dev.fullDevice", (*fullDevice)(nil), state.Fns{Save: (*fullDevice).save, Load: (*fullDevice).load})
+ state.Register("dev.fullFileOperations", (*fullFileOperations)(nil), state.Fns{Save: (*fullFileOperations).save, Load: (*fullFileOperations).load})
+ state.Register("dev.nullDevice", (*nullDevice)(nil), state.Fns{Save: (*nullDevice).save, Load: (*nullDevice).load})
+ state.Register("dev.nullFileOperations", (*nullFileOperations)(nil), state.Fns{Save: (*nullFileOperations).save, Load: (*nullFileOperations).load})
+ state.Register("dev.zeroDevice", (*zeroDevice)(nil), state.Fns{Save: (*zeroDevice).save, Load: (*zeroDevice).load})
+ state.Register("dev.zeroFileOperations", (*zeroFileOperations)(nil), state.Fns{Save: (*zeroFileOperations).save, Load: (*zeroFileOperations).load})
+ state.Register("dev.randomDevice", (*randomDevice)(nil), state.Fns{Save: (*randomDevice).save, Load: (*randomDevice).load})
+ state.Register("dev.randomFileOperations", (*randomFileOperations)(nil), state.Fns{Save: (*randomFileOperations).save, Load: (*randomFileOperations).load})
+}
diff --git a/pkg/sentry/fs/dev/device.go b/pkg/sentry/fs/dev/device.go
new file mode 100644
index 000000000..9f4e41fc9
--- /dev/null
+++ b/pkg/sentry/fs/dev/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// devDevice is the pseudo-filesystem device.
+var devDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/dev/fs.go b/pkg/sentry/fs/dev/fs.go
new file mode 100644
index 000000000..6096a40f8
--- /dev/null
+++ b/pkg/sentry/fs/dev/fs.go
@@ -0,0 +1,99 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Optional key containing boolean flag which specifies if Android Binder IPC should be enabled.
+const binderEnabledKey = "binder_enabled"
+
+// Optional key containing boolean flag which specifies if Android ashmem should be enabled.
+const ashmemEnabledKey = "ashmem_enabled"
+
+// filesystem is a devtmpfs.
+//
+// +stateify savable
+type filesystem struct{}
+
+var _ fs.Filesystem = (*filesystem)(nil)
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name underwhich the filesystem is registered.
+// Name matches drivers/base/devtmpfs.c:dev_fs_type.name.
+const FilesystemName = "devtmpfs"
+
+// Name is the name of the file system.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*filesystem) AllowUserMount() bool {
+ return true
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+//
+// In Linux, devtmpfs does the same thing.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns a devtmpfs root that can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // device is always ignored.
+ // devtmpfs backed by ramfs ignores bad options. See fs/ramfs/inode.c:ramfs_parse_options.
+ // -> we should consider parsing the mode and backing devtmpfs by this.
+
+ // Parse generic comma-separated key=value options.
+ options := fs.GenericMountSourceOptions(data)
+
+ // binerEnabledKey is optional and binder is disabled by default.
+ binderEnabled := false
+ if beStr, exists := options[binderEnabledKey]; exists {
+ var err error
+ binderEnabled, err = strconv.ParseBool(beStr)
+ if err != nil {
+ return nil, syserror.EINVAL
+ }
+ }
+
+ // ashmemEnabledKey is optional and ashmem is disabled by default.
+ ashmemEnabled := false
+ if aeStr, exists := options[ashmemEnabledKey]; exists {
+ var err error
+ ashmemEnabled, err = strconv.ParseBool(aeStr)
+ if err != nil {
+ return nil, syserror.EINVAL
+ }
+ }
+
+ // Construct the devtmpfs root.
+ return New(ctx, fs.NewNonCachingMountSource(f, flags), binderEnabled, ashmemEnabled), nil
+}
diff --git a/pkg/sentry/fs/dev/full.go b/pkg/sentry/fs/dev/full.go
new file mode 100644
index 000000000..8f6c6da2d
--- /dev/null
+++ b/pkg/sentry/fs/dev/full.go
@@ -0,0 +1,81 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// fullDevice is used to implement /dev/full.
+//
+// +stateify savable
+type fullDevice struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*fullDevice)(nil)
+
+func newFullDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *fullDevice {
+ f := &fullDevice{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+ return f
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (f *fullDevice) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ return fs.NewFile(ctx, dirent, flags, &fullFileOperations{}), nil
+}
+
+// +stateify savable
+type fullFileOperations struct {
+ waiter.AlwaysReady `state:"nosave"`
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ readZeros `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*fullFileOperations)(nil)
+
+// Write implements FileOperations.Write.
+func (*fullFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.ENOSPC
+}
diff --git a/pkg/sentry/fs/dev/null.go b/pkg/sentry/fs/dev/null.go
new file mode 100644
index 000000000..3f1accef8
--- /dev/null
+++ b/pkg/sentry/fs/dev/null.go
@@ -0,0 +1,130 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type nullDevice struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*nullDevice)(nil)
+
+func newNullDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *nullDevice {
+ n := &nullDevice{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+ return n
+}
+
+// GetFile implements fs.FileOperations.GetFile.
+func (n *nullDevice) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+
+ return fs.NewFile(ctx, dirent, flags, &nullFileOperations{}), nil
+}
+
+// +stateify savable
+type nullFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRead `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopWrite `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*nullFileOperations)(nil)
+
+// +stateify savable
+type zeroDevice struct {
+ nullDevice
+}
+
+var _ fs.InodeOperations = (*zeroDevice)(nil)
+
+func newZeroDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *zeroDevice {
+ zd := &zeroDevice{
+ nullDevice: nullDevice{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ },
+ }
+ return zd
+}
+
+// GetFile implements fs.FileOperations.GetFile.
+func (zd *zeroDevice) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+
+ return fs.NewFile(ctx, dirent, flags, &zeroFileOperations{}), nil
+}
+
+// +stateify savable
+type zeroFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopWrite `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+ readZeros `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*zeroFileOperations)(nil)
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (*zeroFileOperations) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx))
+ if err != nil {
+ return err
+ }
+ opts.MappingIdentity = m
+ opts.Mappable = m
+ return nil
+}
diff --git a/pkg/sentry/fs/dev/random.go b/pkg/sentry/fs/dev/random.go
new file mode 100644
index 000000000..e5a01a906
--- /dev/null
+++ b/pkg/sentry/fs/dev/random.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type randomDevice struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*randomDevice)(nil)
+
+func newRandomDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *randomDevice {
+ r := &randomDevice{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+ return r
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (*randomDevice) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &randomFileOperations{}), nil
+}
+
+// +stateify savable
+type randomFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopWrite `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*randomFileOperations)(nil)
+
+// Read implements fs.FileOperations.Read.
+func (*randomFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ return dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader})
+}
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
new file mode 100644
index 000000000..c0bc261a2
--- /dev/null
+++ b/pkg/sentry/fs/dirent.go
@@ -0,0 +1,1675 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "path"
+ "sort"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+type globalDirentMap struct {
+ mu sync.Mutex
+ dirents map[*Dirent]struct{}
+}
+
+func (g *globalDirentMap) add(d *Dirent) {
+ g.mu.Lock()
+ g.dirents[d] = struct{}{}
+ g.mu.Unlock()
+}
+
+func (g *globalDirentMap) remove(d *Dirent) {
+ g.mu.Lock()
+ delete(g.dirents, d)
+ g.mu.Unlock()
+}
+
+// allDirents keeps track of all Dirents that need to be considered in
+// Save/Restore for inode mappings.
+//
+// Because inodes do not hold paths, but inodes for external file systems map
+// to an external path, every user-visible Dirent is stored in this map and
+// iterated through upon save to keep inode ID -> restore path mappings.
+var allDirents = globalDirentMap{
+ dirents: map[*Dirent]struct{}{},
+}
+
+// renameMu protects the parent of *all* Dirents. (See explanation in
+// lockForRename.)
+//
+// See fs.go for lock ordering.
+var renameMu sync.RWMutex
+
+// Dirent holds an Inode in memory.
+//
+// A Dirent may be negative or positive:
+//
+// A negative Dirent contains a nil Inode and indicates that a path does not exist. This
+// is a convention taken from the Linux dcache, see fs/dcache.c. A negative Dirent remains
+// cached until a create operation replaces it with a positive Dirent. A negative Dirent
+// always has one reference owned by its parent and takes _no_ reference on its parent. This
+// ensures that its parent can be unhashed regardless of negative children.
+//
+// A positive Dirent contains a non-nil Inode. It remains cached for as long as there remain
+// references to it. A positive Dirent always takes a reference on its parent.
+//
+// A Dirent may be a root Dirent (parent is nil) or be parented (non-nil parent).
+//
+// Dirents currently do not attempt to free entries that lack application references under
+// memory pressure.
+//
+// +stateify savable
+type Dirent struct {
+ // AtomicRefCount is our reference count.
+ refs.AtomicRefCount
+
+ // userVisible indicates whether the Dirent is visible to the user or
+ // not. Only user-visible Dirents should save inode mappings in
+ // save/restore, as only they hold the real path to the underlying
+ // inode.
+ //
+ // See newDirent and Dirent.afterLoad.
+ userVisible bool
+
+ // Inode is the underlying file object.
+ //
+ // Inode is exported currently to assist in implementing overlay Inodes (where a
+ // Inode.InodeOperations.Lookup may need to merge the Inode contained in a positive Dirent with
+ // another Inode). This is normally done before the Dirent is parented (there are
+ // no external references to it).
+ //
+ // Other objects in the VFS may take a reference to this Inode but only while holding
+ // a reference to this Dirent.
+ Inode *Inode
+
+ // name is the name (i.e. basename) of this entry.
+ //
+ // N.B. name is protected by parent.mu, not this node's mu!
+ name string
+
+ // parent is the parent directory.
+ //
+ // We hold a hard reference to the parent.
+ //
+ // parent is protected by renameMu.
+ parent *Dirent
+
+ // deleted may be set atomically when removed.
+ deleted int32
+
+ // frozen indicates this entry can't walk to unknown nodes.
+ frozen bool
+
+ // mounted is true if Dirent is a mount point, similar to include/linux/dcache.h:DCACHE_MOUNTED.
+ mounted bool
+
+ // direntEntry identifies this Dirent as an element in a DirentCache. DirentCaches
+ // and their contents are not saved.
+ direntEntry `state:"nosave"`
+
+ // dirMu is a read-write mutex that protects caching decisions made by directory operations.
+ // Lock ordering: dirMu must be taken before mu (see below). Details:
+ //
+ // dirMu does not participate in Rename; instead mu and renameMu are used, see lockForRename.
+ //
+ // Creation and Removal operations must be synchronized with Walk to prevent stale negative
+ // caching. Note that this requirement is not specific to a _Dirent_ doing negative caching.
+ // The following race exists at any level of the VFS:
+ //
+ // For an object D that represents a directory, containing a cache of non-existent paths,
+ // protected by D.cacheMu:
+ //
+ // T1: T2:
+ // D.lookup(name)
+ // --> ENOENT
+ // D.create(name)
+ // --> success
+ // D.cacheMu.Lock
+ // delete(D.cache, name)
+ // D.cacheMu.Unlock
+ // D.cacheMu.Lock
+ // D.cache[name] = true
+ // D.cacheMu.Unlock
+ //
+ // D.lookup(name)
+ // D.cacheMu.Lock
+ // if D.cache[name] {
+ // --> ENOENT (wrong)
+ // }
+ // D.cacheMu.Lock
+ //
+ // Correct:
+ //
+ // T1: T2:
+ // D.cacheMu.Lock
+ // D.lookup(name)
+ // --> ENOENT
+ // D.cache[name] = true
+ // D.cacheMu.Unlock
+ // D.cacheMu.Lock
+ // D.create(name)
+ // --> success
+ // delete(D.cache, name)
+ // D.cacheMu.Unlock
+ //
+ // D.cacheMu.Lock
+ // D.lookup(name)
+ // --> EXISTS (right)
+ // D.cacheMu.Unlock
+ //
+ // Note that the above "correct" solution causes too much lock contention: all lookups are
+ // synchronized with each other. This is a problem because lookups are involved in any VFS
+ // path operation.
+ //
+ // A Dirent diverges from the single D.cacheMu and instead uses two locks: dirMu to protect
+ // concurrent creation/removal/lookup caching, and mu to protect the Dirent's children map
+ // in general.
+ //
+ // This allows for concurrent Walks to be executed in order to pipeline lookups. For instance
+ // for a hot directory /a/b, threads T1, T2, T3 will only block on each other update the
+ // children map of /a/b when their individual lookups complete.
+ //
+ // T1: T2: T3:
+ // stat(/a/b/c) stat(/a/b/d) stat(/a/b/e)
+ dirMu sync.RWMutex `state:"nosave"`
+
+ // mu protects the below fields. Lock ordering: mu must be taken after dirMu.
+ mu sync.Mutex `state:"nosave"`
+
+ // children are cached via weak references.
+ children map[string]*refs.WeakRef `state:".(map[string]*Dirent)"`
+}
+
+// NewDirent returns a new root Dirent, taking the caller's reference on inode. The caller
+// holds the only reference to the Dirent. Parents may call hashChild to parent this Dirent.
+func NewDirent(inode *Inode, name string) *Dirent {
+ d := newDirent(inode, name)
+ allDirents.add(d)
+ d.userVisible = true
+ return d
+}
+
+// NewTransientDirent creates a transient Dirent that shouldn't actually be
+// visible to users.
+//
+// An Inode is required.
+func NewTransientDirent(inode *Inode) *Dirent {
+ if inode == nil {
+ panic("an inode is required")
+ }
+ return newDirent(inode, "transient")
+}
+
+func newDirent(inode *Inode, name string) *Dirent {
+ // The Dirent needs to maintain one reference to MountSource.
+ if inode != nil {
+ inode.MountSource.IncDirentRefs()
+ }
+ return &Dirent{
+ Inode: inode,
+ name: name,
+ children: make(map[string]*refs.WeakRef),
+ }
+}
+
+// NewNegativeDirent returns a new root negative Dirent. Otherwise same as NewDirent.
+func NewNegativeDirent(name string) *Dirent {
+ return newDirent(nil, name)
+}
+
+// IsRoot returns true if d is a root Dirent.
+func (d *Dirent) IsRoot() bool {
+ return d.parent == nil
+}
+
+// IsNegative returns true if d represents a path that does not exist.
+func (d *Dirent) IsNegative() bool {
+ return d.Inode == nil
+}
+
+// hashChild will hash child into the children list of its new parent d, carrying over
+// any "frozen" state from d.
+//
+// Returns (*WeakRef, true) if hashing child caused a Dirent to be unhashed. The caller must
+// validate the returned unhashed weak reference. Common cases:
+//
+// * Remove: hashing a negative Dirent unhashes a positive Dirent (unimplemented).
+// * Create: hashing a positive Dirent unhashes a negative Dirent.
+// * Lookup: hashing any Dirent should not unhash any other Dirent.
+//
+// Preconditions:
+// * d.mu must be held.
+// * child must be a root Dirent.
+func (d *Dirent) hashChild(child *Dirent) (*refs.WeakRef, bool) {
+ if !child.IsRoot() {
+ panic("hashChild must be a root Dirent")
+ }
+
+ // Assign parentage.
+ child.parent = d
+
+ // Avoid letting negative Dirents take a reference on their parent; these Dirents
+ // don't have a role outside of the Dirent cache and should not keep their parent
+ // indefinitely pinned.
+ if !child.IsNegative() {
+ // Positive dirents must take a reference on their parent.
+ d.IncRef()
+ }
+
+ // Carry over parent's frozen state.
+ child.frozen = d.frozen
+
+ return d.hashChildParentSet(child)
+}
+
+// hashChildParentSet will rehash child into the children list of its parent d.
+//
+// Assumes that child.parent = d already.
+func (d *Dirent) hashChildParentSet(child *Dirent) (*refs.WeakRef, bool) {
+ if child.parent != d {
+ panic("hashChildParentSet assumes the child already belongs to the parent")
+ }
+
+ // Save any replaced child so our caller can validate it.
+ old, ok := d.children[child.name]
+
+ // Hash the child.
+ d.children[child.name] = refs.NewWeakRef(child, nil)
+
+ // Return any replaced child.
+ return old, ok
+}
+
+// SyncAll iterates through mount points under d and writes back their buffered
+// modifications to filesystems.
+func (d *Dirent) SyncAll(ctx context.Context) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // For negative Dirents there is nothing to sync. By definition these are
+ // leaves (there is nothing left to traverse).
+ if d.IsNegative() {
+ return
+ }
+
+ // There is nothing to sync for a read-only filesystem.
+ if !d.Inode.MountSource.Flags.ReadOnly {
+ // FIXME(b/34856369): This should be a mount traversal, not a
+ // Dirent traversal, because some Inodes that need to be synced
+ // may no longer be reachable by name (after sys_unlink).
+ //
+ // Write out metadata, dirty page cached pages, and sync disk/remote
+ // caches.
+ d.Inode.WriteOut(ctx)
+ }
+
+ // Continue iterating through other mounted filesystems.
+ for _, w := range d.children {
+ if child := w.Get(); child != nil {
+ child.(*Dirent).SyncAll(ctx)
+ child.DecRef()
+ }
+ }
+}
+
+// BaseName returns the base name of the dirent.
+func (d *Dirent) BaseName() string {
+ p := d.parent
+ if p == nil {
+ return d.name
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return d.name
+}
+
+// FullName returns the fully-qualified name and a boolean value representing
+// whether this Dirent was a descendant of root.
+// If the root argument is nil it is assumed to be the root of the Dirent tree.
+func (d *Dirent) FullName(root *Dirent) (string, bool) {
+ renameMu.RLock()
+ defer renameMu.RUnlock()
+ return d.fullName(root)
+}
+
+// fullName returns the fully-qualified name and a boolean value representing
+// if the root node was reachable from this Dirent.
+func (d *Dirent) fullName(root *Dirent) (string, bool) {
+ if d == root {
+ return "/", true
+ }
+
+ if d.IsRoot() {
+ if root != nil {
+ // We reached the top of the Dirent tree but did not encounter
+ // the given root. Return false for reachable so the caller
+ // can handle this situation accordingly.
+ return d.name, false
+ }
+ return d.name, true
+ }
+
+ // Traverse up to parent.
+ d.parent.mu.Lock()
+ name := d.name
+ d.parent.mu.Unlock()
+ parentName, reachable := d.parent.fullName(root)
+ s := path.Join(parentName, name)
+ if atomic.LoadInt32(&d.deleted) != 0 {
+ return s + " (deleted)", reachable
+ }
+ return s, reachable
+}
+
+// MountRoot finds and returns the mount-root for a given dirent.
+func (d *Dirent) MountRoot() *Dirent {
+ renameMu.RLock()
+ defer renameMu.RUnlock()
+
+ mountRoot := d
+ for !mountRoot.mounted && mountRoot.parent != nil {
+ mountRoot = mountRoot.parent
+ }
+ mountRoot.IncRef()
+ return mountRoot
+}
+
+// Freeze prevents this dirent from walking to more nodes. Freeze is applied
+// recursively to all children.
+//
+// If this particular Dirent represents a Virtual node, then Walks and Creates
+// may proceed as before.
+//
+// Freeze can only be called before the application starts running, otherwise
+// the root it might be out of sync with the application root if modified by
+// sys_chroot.
+func (d *Dirent) Freeze() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ if d.frozen {
+ // Already frozen.
+ return
+ }
+ d.frozen = true
+
+ // Take a reference when freezing.
+ for _, w := range d.children {
+ if child := w.Get(); child != nil {
+ // NOTE: We would normally drop the reference here. But
+ // instead we're hanging on to it.
+ ch := child.(*Dirent)
+ ch.Freeze()
+ }
+ }
+
+ // Drop all expired weak references.
+ d.flush()
+}
+
+// descendantOf returns true if the receiver dirent is equal to, or a
+// descendant of, the argument dirent.
+//
+// d.mu must be held.
+func (d *Dirent) descendantOf(p *Dirent) bool {
+ if d == p {
+ return true
+ }
+ if d.IsRoot() {
+ return false
+ }
+ return d.parent.descendantOf(p)
+}
+
+// walk walks to path name starting at the dirent, and will not traverse above
+// root Dirent.
+//
+// If walkMayUnlock is true then walk can unlock d.mu to execute a slow
+// Inode.Lookup, otherwise walk will keep d.mu locked.
+//
+// Preconditions:
+// - renameMu must be held for reading.
+// - d.mu must be held.
+// - name must must not contain "/"s.
+func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnlock bool) (*Dirent, error) {
+ if !IsDir(d.Inode.StableAttr) {
+ return nil, syscall.ENOTDIR
+ }
+
+ if name == "" || name == "." {
+ d.IncRef()
+ return d, nil
+ } else if name == ".." {
+ // Respect the chroot. Note that in Linux there is no check to enforce
+ // that d is a descendant of root.
+ if d == root {
+ d.IncRef()
+ return d, nil
+ }
+ // Are we already at the root? Then ".." is ".".
+ if d.IsRoot() {
+ d.IncRef()
+ return d, nil
+ }
+ d.parent.IncRef()
+ return d.parent, nil
+ }
+
+ if w, ok := d.children[name]; ok {
+ // Try to resolve the weak reference to a hard reference.
+ if child := w.Get(); child != nil {
+ cd := child.(*Dirent)
+
+ // Is this a negative Dirent?
+ if cd.IsNegative() {
+ // Don't leak a reference; this doesn't matter as much for negative Dirents,
+ // which don't hold a hard reference on their parent (their parent holds a
+ // hard reference on them, and they contain virtually no state). But this is
+ // good house-keeping.
+ child.DecRef()
+ return nil, syscall.ENOENT
+ }
+
+ // Do we need to revalidate this child?
+ //
+ // We never allow the file system to revalidate mounts, that could cause them
+ // to unexpectedly drop out before umount.
+ if cd.mounted || !cd.Inode.MountSource.Revalidate(ctx, name, d.Inode, cd.Inode) {
+ // Good to go. This is the fast-path.
+ return cd, nil
+ }
+
+ // If we're revalidating a child, we must ensure all inotify watches release
+ // their pins on the child. Inotify doesn't properly support filesystems that
+ // revalidate dirents (since watches are lost on revalidation), but if we fail
+ // to unpin the watches child will never be GCed.
+ cd.Inode.Watches.Unpin(cd)
+
+ // This child needs to be revalidated, fallthrough to unhash it. Make sure
+ // to not leak a reference from Get().
+ //
+ // Note that previous lookups may still have a reference to this stale child;
+ // this can't be helped, but we can ensure that *new* lookups are up-to-date.
+ child.DecRef()
+ }
+
+ // Either our weak reference expired or we need to revalidate it. Unhash child first, we're
+ // about to replace it.
+ delete(d.children, name)
+ w.Drop()
+ }
+
+ // Are we allowed to do the lookup?
+ if d.frozen && !d.Inode.IsVirtual() {
+ return nil, syscall.ENOENT
+ }
+
+ // Slow path: load the InodeOperations into memory. Since this is a hot path and the lookup may be
+ // expensive, if possible release the lock and re-acquire it.
+ if walkMayUnlock {
+ d.mu.Unlock()
+ }
+ c, err := d.Inode.Lookup(ctx, name)
+ if walkMayUnlock {
+ d.mu.Lock()
+ }
+ // No dice.
+ if err != nil {
+ return nil, err
+ }
+
+ // Sanity check c, its name must be consistent.
+ if c.name != name {
+ panic(fmt.Sprintf("lookup from %q to %q returned unexpected name %q", d.name, name, c.name))
+ }
+
+ // Now that we have the lock again, check if we raced.
+ if w, ok := d.children[name]; ok {
+ // Someone else looked up or created a child at name before us.
+ if child := w.Get(); child != nil {
+ cd := child.(*Dirent)
+
+ // There are active references to the existing child, prefer it to the one we
+ // retrieved from Lookup. Likely the Lookup happened very close to the insertion
+ // of child, so considering one stale over the other is fairly arbitrary.
+ c.DecRef()
+
+ // The child that was installed could be negative.
+ if cd.IsNegative() {
+ // If so, don't leak a reference and short circuit.
+ child.DecRef()
+ return nil, syscall.ENOENT
+ }
+
+ // We make the judgement call that if c raced with cd they are close enough to have
+ // the same staleness, so we don't attempt to revalidate cd. In Linux revalidations
+ // can continue indefinitely (see fs/namei.c, retry_estale); we try to avoid this.
+ return cd, nil
+ }
+
+ // Weak reference expired. We went through a full cycle of create/destroy in the time
+ // we did the Inode.Lookup. Fully drop the weak reference and fallback to using the child
+ // we looked up.
+ delete(d.children, name)
+ w.Drop()
+ }
+
+ // Give the looked up child a parent. We cannot kick out entries, since we just checked above
+ // that there is nothing at name in d's children list.
+ if _, kicked := d.hashChild(c); kicked {
+ // Yell loudly.
+ panic(fmt.Sprintf("hashed child %q over existing child", c.name))
+ }
+
+ // Is this a negative Dirent?
+ if c.IsNegative() {
+ // Don't drop a reference on the negative Dirent, it was just installed and this is the
+ // only reference we'll ever get. d owns the reference.
+ return nil, syscall.ENOENT
+ }
+
+ // Return the positive Dirent.
+ return c, nil
+}
+
+// Walk walks to a new dirent, and will not walk higher than the given root
+// Dirent, which must not be nil.
+func (d *Dirent) Walk(ctx context.Context, root *Dirent, name string) (*Dirent, error) {
+ if root == nil {
+ panic("Dirent.Walk: root must not be nil")
+ }
+
+ // We could use lockDirectory here, but this is a hot path and we want
+ // to avoid defer.
+ renameMu.RLock()
+ d.dirMu.RLock()
+ d.mu.Lock()
+
+ child, err := d.walk(ctx, root, name, true /* may unlock */)
+
+ d.mu.Unlock()
+ d.dirMu.RUnlock()
+ renameMu.RUnlock()
+
+ return child, err
+}
+
+// exists returns true if name exists in relation to d.
+//
+// Preconditions:
+// - renameMu must be held for reading.
+// - d.mu must be held.
+// - name must must not contain "/"s.
+func (d *Dirent) exists(ctx context.Context, root *Dirent, name string) bool {
+ child, err := d.walk(ctx, root, name, false /* may unlock */)
+ if err != nil {
+ // Child may not exist.
+ return false
+ }
+ // Child exists.
+ child.DecRef()
+ return true
+}
+
+// lockDirectory should be called for any operation that changes this `d`s
+// children (creating or removing them).
+func (d *Dirent) lockDirectory() func() {
+ renameMu.RLock()
+ d.dirMu.Lock()
+ d.mu.Lock()
+ return func() {
+ d.mu.Unlock()
+ d.dirMu.Unlock()
+ renameMu.RUnlock()
+ }
+}
+
+// Create creates a new regular file in this directory.
+func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags FileFlags, perms FilePermissions) (*File, error) {
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ // Does something already exist?
+ if d.exists(ctx, root, name) {
+ return nil, syscall.EEXIST
+ }
+
+ // Are we frozen?
+ if d.frozen && !d.Inode.IsVirtual() {
+ return nil, syscall.ENOENT
+ }
+
+ // Try the create. We need to trust the file system to return EEXIST (or something
+ // that will translate to EEXIST) if name already exists.
+ file, err := d.Inode.Create(ctx, d, name, flags, perms)
+ if err != nil {
+ return nil, err
+ }
+ child := file.Dirent
+
+ d.finishCreate(child, name)
+
+ // Return the reference and the new file. When the last reference to
+ // the file is dropped, file.Dirent may no longer be cached.
+ return file, nil
+}
+
+// finishCreate validates the created file, adds it as a child of this dirent,
+// and notifies any watchers.
+func (d *Dirent) finishCreate(child *Dirent, name string) {
+ // Sanity check c, its name must be consistent.
+ if child.name != name {
+ panic(fmt.Sprintf("create from %q to %q returned unexpected name %q", d.name, name, child.name))
+ }
+
+ // File systems cannot return a negative Dirent on Create, that makes no sense.
+ if child.IsNegative() {
+ panic(fmt.Sprintf("create from %q to %q returned negative Dirent", d.name, name))
+ }
+
+ // Hash the child into its parent. We can only kick out a Dirent if it is negative
+ // (we are replacing something that does not exist with something that now does).
+ if w, kicked := d.hashChild(child); kicked {
+ if old := w.Get(); old != nil {
+ if !old.(*Dirent).IsNegative() {
+ panic(fmt.Sprintf("hashed child %q over a positive child", child.name))
+ }
+ // Don't leak a reference.
+ old.DecRef()
+
+ // Drop d's reference.
+ old.DecRef()
+ }
+
+ // Finally drop the useless weak reference on the floor.
+ w.Drop()
+ }
+
+ d.Inode.Watches.Notify(name, linux.IN_CREATE, 0)
+
+ // Allow the file system to take extra references on c.
+ child.maybeExtendReference()
+}
+
+// genericCreate executes create if name does not exist. Removes a negative Dirent at name if
+// create succeeds.
+func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, create func() error) error {
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ // Does something already exist?
+ if d.exists(ctx, root, name) {
+ return syscall.EEXIST
+ }
+
+ // Are we frozen?
+ if d.frozen && !d.Inode.IsVirtual() {
+ return syscall.ENOENT
+ }
+
+ // Remove any negative Dirent. We've already asserted above with d.exists
+ // that the only thing remaining here can be a negative Dirent.
+ if w, ok := d.children[name]; ok {
+ // Same as Create.
+ if old := w.Get(); old != nil {
+ if !old.(*Dirent).IsNegative() {
+ panic(fmt.Sprintf("hashed over a positive child %q", old.(*Dirent).name))
+ }
+ // Don't leak a reference.
+ old.DecRef()
+
+ // Drop d's reference.
+ old.DecRef()
+ }
+
+ // Unhash the negative Dirent, name needs to exist now.
+ delete(d.children, name)
+
+ // Finally drop the useless weak reference on the floor.
+ w.Drop()
+ }
+
+ // Execute the create operation.
+ return create()
+}
+
+// CreateLink creates a new link in this directory.
+func (d *Dirent) CreateLink(ctx context.Context, root *Dirent, oldname, newname string) error {
+ return d.genericCreate(ctx, root, newname, func() error {
+ if err := d.Inode.CreateLink(ctx, d, oldname, newname); err != nil {
+ return err
+ }
+ d.Inode.Watches.Notify(newname, linux.IN_CREATE, 0)
+ return nil
+ })
+}
+
+// CreateHardLink creates a new hard link in this directory.
+func (d *Dirent) CreateHardLink(ctx context.Context, root *Dirent, target *Dirent, name string) error {
+ // Make sure that target does not span filesystems.
+ if d.Inode.MountSource != target.Inode.MountSource {
+ return syscall.EXDEV
+ }
+
+ // Directories are never linkable. See fs/namei.c:vfs_link.
+ if IsDir(target.Inode.StableAttr) {
+ return syscall.EPERM
+ }
+
+ return d.genericCreate(ctx, root, name, func() error {
+ if err := d.Inode.CreateHardLink(ctx, d, target, name); err != nil {
+ return err
+ }
+ target.Inode.Watches.Notify("", linux.IN_ATTRIB, 0) // Link count change.
+ d.Inode.Watches.Notify(name, linux.IN_CREATE, 0)
+ return nil
+ })
+}
+
+// CreateDirectory creates a new directory under this dirent.
+func (d *Dirent) CreateDirectory(ctx context.Context, root *Dirent, name string, perms FilePermissions) error {
+ return d.genericCreate(ctx, root, name, func() error {
+ if err := d.Inode.CreateDirectory(ctx, d, name, perms); err != nil {
+ return err
+ }
+ d.Inode.Watches.Notify(name, linux.IN_ISDIR|linux.IN_CREATE, 0)
+ return nil
+ })
+}
+
+// Bind satisfies the InodeOperations interface; otherwise same as GetFile.
+func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data transport.BoundEndpoint, perms FilePermissions) (*Dirent, error) {
+ var childDir *Dirent
+ err := d.genericCreate(ctx, root, name, func() error {
+ var e error
+ childDir, e = d.Inode.Bind(ctx, name, data, perms)
+ if e != nil {
+ return e
+ }
+ d.finishCreate(childDir, name)
+ return nil
+ })
+ if err == syscall.EEXIST {
+ return nil, syscall.EADDRINUSE
+ }
+ if err != nil {
+ return nil, err
+ }
+ return childDir, err
+}
+
+// CreateFifo creates a new named pipe under this dirent.
+func (d *Dirent) CreateFifo(ctx context.Context, root *Dirent, name string, perms FilePermissions) error {
+ return d.genericCreate(ctx, root, name, func() error {
+ if err := d.Inode.CreateFifo(ctx, d, name, perms); err != nil {
+ return err
+ }
+ d.Inode.Watches.Notify(name, linux.IN_CREATE, 0)
+ return nil
+ })
+}
+
+// GetDotAttrs returns the DentAttrs corresponding to "." and ".." directories.
+func (d *Dirent) GetDotAttrs(root *Dirent) (DentAttr, DentAttr) {
+ // Get '.'.
+ sattr := d.Inode.StableAttr
+ dot := DentAttr{
+ Type: sattr.Type,
+ InodeID: sattr.InodeID,
+ }
+
+ // Hold d.mu while we call d.descendantOf.
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Get '..'.
+ if !d.IsRoot() && d.descendantOf(root) {
+ // Dirent is a descendant of the root. Get its parent's attrs.
+ psattr := d.parent.Inode.StableAttr
+ dotdot := DentAttr{
+ Type: psattr.Type,
+ InodeID: psattr.InodeID,
+ }
+ return dot, dotdot
+ }
+ // Dirent is either root or not a descendant of the root. ".." is the
+ // same as ".".
+ return dot, dot
+}
+
+// readdirFrozen returns readdir results based solely on the frozen children.
+func (d *Dirent) readdirFrozen(root *Dirent, offset int64, dirCtx *DirCtx) (int64, error) {
+ // Collect attrs for "." and "..".
+ attrs := make(map[string]DentAttr)
+ names := []string{".", ".."}
+ attrs["."], attrs[".."] = d.GetDotAttrs(root)
+
+ // Get info from all children.
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ for name, w := range d.children {
+ if child := w.Get(); child != nil {
+ defer child.DecRef()
+
+ // Skip negative children.
+ if child.(*Dirent).IsNegative() {
+ continue
+ }
+
+ sattr := child.(*Dirent).Inode.StableAttr
+ attrs[name] = DentAttr{
+ Type: sattr.Type,
+ InodeID: sattr.InodeID,
+ }
+ names = append(names, name)
+ }
+ }
+
+ sort.Strings(names)
+
+ if int(offset) >= len(names) {
+ return offset, nil
+ }
+ names = names[int(offset):]
+ for _, name := range names {
+ if err := dirCtx.DirEmit(name, attrs[name]); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+// DirIterator is an open directory containing directory entries that can be read.
+type DirIterator interface {
+ // IterateDir emits directory entries by calling dirCtx.EmitDir, beginning
+ // with the entry at offset and returning the next directory offset.
+ //
+ // Entries for "." and ".." must *not* be included.
+ //
+ // If the offset returned is the same as the argument offset, then
+ // nothing has been serialized. This is equivalent to reaching EOF.
+ // In this case serializer.Written() should return 0.
+ //
+ // The order of entries to emit must be consistent between Readdir
+ // calls, and must start with the given offset.
+ //
+ // The caller must ensure that this operation is permitted.
+ IterateDir(ctx context.Context, dirCtx *DirCtx, offset int) (int, error)
+}
+
+// DirentReaddir serializes the directory entries of d including "." and "..".
+//
+// Arguments:
+//
+// * d: the Dirent of the directory being read; required to provide "." and "..".
+// * it: the directory iterator; which represents an open directory handle.
+// * root: fs root; if d is equal to the root, then '..' will refer to d.
+// * ctx: context provided to file systems in order to select and serialize entries.
+// * offset: the current directory offset.
+//
+// Returns the offset of the *next* element which was not serialized.
+func DirentReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent, dirCtx *DirCtx, offset int64) (int64, error) {
+ offset, err := direntReaddir(ctx, d, it, root, dirCtx, offset)
+ // Serializing any directory entries at all means success.
+ if dirCtx.Serializer.Written() > 0 {
+ return offset, nil
+ }
+ return offset, err
+}
+
+func direntReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent, dirCtx *DirCtx, offset int64) (int64, error) {
+ if root == nil {
+ panic("Dirent.Readdir: root must not be nil")
+ }
+ if dirCtx.Serializer == nil {
+ panic("Dirent.Readdir: serializer must not be nil")
+ }
+ if d.frozen {
+ return d.readdirFrozen(root, offset, dirCtx)
+ }
+
+ // Check that this is actually a directory before emitting anything.
+ // Once we have written entries for "." and "..", future errors from
+ // IterateDir will be hidden.
+ if !IsDir(d.Inode.StableAttr) {
+ return 0, syserror.ENOTDIR
+ }
+
+ // Collect attrs for "." and "..".
+ dot, dotdot := d.GetDotAttrs(root)
+
+ // Emit "." and ".." if the offset is low enough.
+ if offset == 0 {
+ // Serialize ".".
+ if err := dirCtx.DirEmit(".", dot); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ if offset == 1 {
+ // Serialize "..".
+ if err := dirCtx.DirEmit("..", dotdot); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+
+ // it.IterateDir should be passed an offset that does not include the
+ // initial dot elements. We will add them back later.
+ offset -= 2
+ newOffset, err := it.IterateDir(ctx, dirCtx, int(offset))
+ if int64(newOffset) < offset {
+ panic(fmt.Sprintf("node.Readdir returned offset %v less than input offset %v", newOffset, offset))
+ }
+ // Add the initial nodes back to the offset count.
+ newOffset += 2
+ return int64(newOffset), err
+}
+
+// flush flushes all weak references recursively, and removes any cached
+// references to children.
+//
+// Preconditions: d.mu must be held.
+func (d *Dirent) flush() {
+ expired := make(map[string]*refs.WeakRef)
+ for n, w := range d.children {
+ // Call flush recursively on each child before removing our
+ // reference on it, and removing the cache's reference.
+ if child := w.Get(); child != nil {
+ cd := child.(*Dirent)
+
+ if !cd.IsNegative() {
+ // Flush the child.
+ cd.mu.Lock()
+ cd.flush()
+ cd.mu.Unlock()
+
+ // Allow the file system to drop extra references on child.
+ cd.dropExtendedReference()
+ }
+
+ // Don't leak a reference.
+ child.DecRef()
+ }
+ // Check if the child dirent is closed, and mark it as expired if it is.
+ // We must call w.Get() again here, since the child could have been closed
+ // by the calls to flush() and cache.Remove() in the above if-block.
+ if child := w.Get(); child != nil {
+ child.DecRef()
+ } else {
+ expired[n] = w
+ }
+ }
+
+ // Remove expired entries.
+ for n, w := range expired {
+ delete(d.children, n)
+ w.Drop()
+ }
+}
+
+// isMountPoint returns true if the dirent is a mount point or the root.
+func (d *Dirent) isMountPoint() bool {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.isMountPointLocked()
+}
+
+func (d *Dirent) isMountPointLocked() bool {
+ return d.mounted || d.parent == nil
+}
+
+// mount mounts a new dirent with the given inode over d.
+//
+// Precondition: must be called with mm.withMountLocked held on `d`.
+func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err error) {
+ // Did we race with deletion?
+ if atomic.LoadInt32(&d.deleted) != 0 {
+ return nil, syserror.ENOENT
+ }
+
+ // Refuse to mount a symlink.
+ //
+ // See Linux equivalent in fs/namespace.c:do_add_mount.
+ if IsSymlink(inode.StableAttr) {
+ return nil, syserror.EINVAL
+ }
+
+ // Are we frozen?
+ if d.parent.frozen && !d.parent.Inode.IsVirtual() {
+ return nil, syserror.ENOENT
+ }
+
+ // Dirent that'll replace d.
+ //
+ // Note that NewDirent returns with one reference taken; the reference
+ // is donated to the caller as the mount reference.
+ replacement := NewDirent(inode, d.name)
+ replacement.mounted = true
+
+ weakRef, ok := d.parent.hashChild(replacement)
+ if !ok {
+ panic("mount must mount over an existing dirent")
+ }
+ weakRef.Drop()
+
+ // Note that even though `d` is now hidden, it still holds a reference
+ // to its parent.
+ return replacement, nil
+}
+
+// unmount unmounts `d` and replaces it with the last Dirent that was in its
+// place, supplied by the MountNamespace as `replacement`.
+//
+// Precondition: must be called with mm.withMountLocked held on `d`.
+func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
+ // Did we race with deletion?
+ if atomic.LoadInt32(&d.deleted) != 0 {
+ return syserror.ENOENT
+ }
+
+ // Are we frozen?
+ if d.parent.frozen && !d.parent.Inode.IsVirtual() {
+ return syserror.ENOENT
+ }
+
+ // Remount our former child in its place.
+ //
+ // As replacement used to be our child, it must already have the right
+ // parent.
+ weakRef, ok := d.parent.hashChildParentSet(replacement)
+ if !ok {
+ panic("mount must mount over an existing dirent")
+ }
+ weakRef.Drop()
+
+ // d is not reachable anymore, and hence not mounted anymore.
+ d.mounted = false
+
+ // Drop mount reference.
+ d.DecRef()
+ return nil
+}
+
+// Remove removes the given file or symlink. The root dirent is used to
+// resolve name, and must not be nil.
+func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error {
+ // Check the root.
+ if root == nil {
+ panic("Dirent.Remove: root must not be nil")
+ }
+
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ // Are we frozen?
+ if d.frozen && !d.Inode.IsVirtual() {
+ return syscall.ENOENT
+ }
+
+ // Try to walk to the node.
+ child, err := d.walk(ctx, root, name, false /* may unlock */)
+ if err != nil {
+ // Child does not exist.
+ return err
+ }
+ defer child.DecRef()
+
+ // Remove cannot remove directories.
+ if IsDir(child.Inode.StableAttr) {
+ return syscall.EISDIR
+ }
+
+ // Remove cannot remove a mount point.
+ if child.isMountPoint() {
+ return syscall.EBUSY
+ }
+
+ // Try to remove name on the file system.
+ if err := d.Inode.Remove(ctx, d, child); err != nil {
+ return err
+ }
+
+ // Link count changed, this only applies to non-directory nodes.
+ child.Inode.Watches.Notify("", linux.IN_ATTRIB, 0)
+
+ // Mark name as deleted and remove from children.
+ atomic.StoreInt32(&child.deleted, 1)
+ if w, ok := d.children[name]; ok {
+ delete(d.children, name)
+ w.Drop()
+ }
+
+ // Allow the file system to drop extra references on child.
+ child.dropExtendedReference()
+
+ // Finally, let inotify know the child is being unlinked. Drop any extra
+ // refs from inotify to this child dirent. This doesn't necessarily mean the
+ // watches on the underlying inode will be destroyed, since the underlying
+ // inode may have other links. If this was the last link, the events for the
+ // watch removal will be queued by the inode destructor.
+ child.Inode.Watches.MarkUnlinked()
+ child.Inode.Watches.Unpin(child)
+ d.Inode.Watches.Notify(name, linux.IN_DELETE, 0)
+
+ return nil
+}
+
+// RemoveDirectory removes the given directory. The root dirent is used to
+// resolve name, and must not be nil.
+func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string) error {
+ // Check the root.
+ if root == nil {
+ panic("Dirent.Remove: root must not be nil")
+ }
+
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ // Are we frozen?
+ if d.frozen && !d.Inode.IsVirtual() {
+ return syscall.ENOENT
+ }
+
+ // Check for dots.
+ if name == "." {
+ // Rejected as the last component by rmdir(2).
+ return syscall.EINVAL
+ }
+ if name == ".." {
+ // If d was found, then its parent is not empty.
+ return syscall.ENOTEMPTY
+ }
+
+ // Try to walk to the node.
+ child, err := d.walk(ctx, root, name, false /* may unlock */)
+ if err != nil {
+ // Child does not exist.
+ return err
+ }
+ defer child.DecRef()
+
+ // RemoveDirectory can only remove directories.
+ if !IsDir(child.Inode.StableAttr) {
+ return syscall.ENOTDIR
+ }
+
+ // Remove cannot remove a mount point.
+ if child.isMountPoint() {
+ return syscall.EBUSY
+ }
+
+ // Try to remove name on the file system.
+ if err := d.Inode.Remove(ctx, d, child); err != nil {
+ return err
+ }
+
+ // Mark name as deleted and remove from children.
+ atomic.StoreInt32(&child.deleted, 1)
+ if w, ok := d.children[name]; ok {
+ delete(d.children, name)
+ w.Drop()
+ }
+
+ // Allow the file system to drop extra references on child.
+ child.dropExtendedReference()
+
+ // Finally, let inotify know the child is being unlinked. Drop any extra
+ // refs from inotify to this child dirent.
+ child.Inode.Watches.MarkUnlinked()
+ child.Inode.Watches.Unpin(child)
+ d.Inode.Watches.Notify(name, linux.IN_ISDIR|linux.IN_DELETE, 0)
+
+ return nil
+}
+
+// destroy closes this node and all children.
+func (d *Dirent) destroy() {
+ if d.IsNegative() {
+ // Nothing to tear-down and no parent references to drop, since a negative
+ // Dirent does not take a references on its parent, has no Inode and no children.
+ return
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Drop all weak references.
+ for _, w := range d.children {
+ if c := w.Get(); c != nil {
+ if c.(*Dirent).IsNegative() {
+ // The parent holds both weak and strong refs in the case of
+ // negative dirents.
+ c.DecRef()
+ }
+ // Drop the reference we just acquired in WeakRef.Get.
+ c.DecRef()
+ }
+ w.Drop()
+ }
+ d.children = nil
+
+ allDirents.remove(d)
+
+ // Drop our reference to the Inode.
+ d.Inode.DecRef()
+
+ // Allow the Dirent to be GC'ed after this point, since the Inode may still
+ // be referenced after the Dirent is destroyed (for instance by filesystem
+ // internal caches or hard links).
+ d.Inode = nil
+
+ // Drop the reference we have on our parent if we took one. renameMu doesn't need to be
+ // held because d can't be reparented without any references to it left.
+ if d.parent != nil {
+ d.parent.DecRef()
+ }
+}
+
+// IncRef increases the Dirent's refcount as well as its mount's refcount.
+//
+// IncRef implements RefCounter.IncRef.
+func (d *Dirent) IncRef() {
+ if d.Inode != nil {
+ d.Inode.MountSource.IncDirentRefs()
+ }
+ d.AtomicRefCount.IncRef()
+}
+
+// TryIncRef implements RefCounter.TryIncRef.
+func (d *Dirent) TryIncRef() bool {
+ ok := d.AtomicRefCount.TryIncRef()
+ if ok && d.Inode != nil {
+ d.Inode.MountSource.IncDirentRefs()
+ }
+ return ok
+}
+
+// DecRef decreases the Dirent's refcount and drops its reference on its mount.
+//
+// DecRef implements RefCounter.DecRef with destructor d.destroy.
+func (d *Dirent) DecRef() {
+ if d.Inode != nil {
+ // Keep mount around, since DecRef may destroy d.Inode.
+ msrc := d.Inode.MountSource
+ d.DecRefWithDestructor(d.destroy)
+ msrc.DecDirentRefs()
+ } else {
+ d.DecRefWithDestructor(d.destroy)
+ }
+}
+
+// InotifyEvent notifies all watches on the inode for this dirent and its parent
+// of potential events. The events may not actually propagate up to the user,
+// depending on the event masks. InotifyEvent automatically provides the name of
+// the current dirent as the subject of the event as required, and adds the
+// IN_ISDIR flag for dirents that refer to directories.
+func (d *Dirent) InotifyEvent(events, cookie uint32) {
+ // N.B. We don't defer the unlocks because InotifyEvent is in the hot
+ // path of all IO operations, and the defers cost too much for small IO
+ // operations.
+ renameMu.RLock()
+
+ if IsDir(d.Inode.StableAttr) {
+ events |= linux.IN_ISDIR
+ }
+
+ // The ordering below is important, Linux always notifies the parent first.
+ if d.parent != nil {
+ // name is immediately stale w.r.t. renames (renameMu doesn't
+ // protect against renames in the same directory). Holding
+ // d.parent.mu around Notify() wouldn't matter since Notify
+ // doesn't provide a synchronous mechanism for reading the name
+ // anyway.
+ d.parent.mu.Lock()
+ name := d.name
+ d.parent.mu.Unlock()
+ d.parent.Inode.Watches.Notify(name, events, cookie)
+ }
+ d.Inode.Watches.Notify("", events, cookie)
+
+ renameMu.RUnlock()
+}
+
+// maybeExtendReference caches a reference on this Dirent if
+// MountSourceOperations.Keep returns true.
+func (d *Dirent) maybeExtendReference() {
+ if msrc := d.Inode.MountSource; msrc.Keep(d) {
+ msrc.fscache.Add(d)
+ }
+}
+
+// dropExtendedReference drops any cached reference held by the
+// MountSource on the dirent.
+func (d *Dirent) dropExtendedReference() {
+ d.Inode.MountSource.fscache.Remove(d)
+}
+
+// lockForRename takes locks on oldParent and newParent as required by Rename
+// and returns a function that will unlock the locks taken. The returned
+// function must be called even if a non-nil error is returned.
+func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName string) (func(), error) {
+ renameMu.Lock()
+ if oldParent == newParent {
+ oldParent.mu.Lock()
+ return func() {
+ oldParent.mu.Unlock()
+ renameMu.Unlock()
+ }, nil
+ }
+
+ // Renaming between directories is a bit subtle:
+ //
+ // - A concurrent cross-directory Rename may try to lock in the opposite
+ // order; take renameMu to prevent this from happening.
+ //
+ // - If either directory is an ancestor of the other, then a concurrent
+ // Remove may lock the descendant (in DecRef -> closeAll) while holding a
+ // lock on the ancestor; to avoid this, ensure we take locks in the same
+ // ancestor-to-descendant order. (Holding renameMu prevents this
+ // relationship from changing.)
+
+ // First check if newParent is a descendant of oldParent.
+ child := newParent
+ for p := newParent.parent; p != nil; p = p.parent {
+ if p == oldParent {
+ oldParent.mu.Lock()
+ newParent.mu.Lock()
+ var err error
+ if child.name == oldName {
+ // newParent is not just a descendant of oldParent, but
+ // more specifically of oldParent/oldName. That is, we're
+ // trying to rename something into a subdirectory of
+ // itself.
+ err = syscall.EINVAL
+ }
+ return func() {
+ newParent.mu.Unlock()
+ oldParent.mu.Unlock()
+ renameMu.Unlock()
+ }, err
+ }
+ child = p
+ }
+
+ // Otherwise, either oldParent is a descendant of newParent or the two
+ // have no relationship; in either case we can do this:
+ newParent.mu.Lock()
+ oldParent.mu.Lock()
+ return func() {
+ oldParent.mu.Unlock()
+ newParent.mu.Unlock()
+ renameMu.Unlock()
+ }, nil
+}
+
+func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error {
+ uattr, err := dir.Inode.UnstableAttr(ctx)
+ if err != nil {
+ return syserror.EPERM
+ }
+ if !uattr.Perms.Sticky {
+ return nil
+ }
+
+ creds := auth.CredentialsFromContext(ctx)
+ if uattr.Owner.UID == creds.EffectiveKUID {
+ return nil
+ }
+
+ vuattr, err := victim.Inode.UnstableAttr(ctx)
+ if err != nil {
+ return syserror.EPERM
+ }
+ if vuattr.Owner.UID == creds.EffectiveKUID {
+ return nil
+ }
+ if victim.Inode.CheckCapability(ctx, linux.CAP_FOWNER) {
+ return nil
+ }
+ return syserror.EPERM
+}
+
+// MayDelete determines whether `name`, a child of `dir`, can be deleted or
+// renamed by `ctx`.
+//
+// Compare Linux kernel fs/namei.c:may_delete.
+func MayDelete(ctx context.Context, root, dir *Dirent, name string) error {
+ if err := dir.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ victim, err := dir.Walk(ctx, root, name)
+ if err != nil {
+ return err
+ }
+ defer victim.DecRef()
+
+ return mayDelete(ctx, dir, victim)
+}
+
+// mayDelete determines whether `victim`, a child of `dir`, can be deleted or
+// renamed by `ctx`.
+//
+// Preconditions: `dir` is writable and executable by `ctx`.
+func mayDelete(ctx context.Context, dir, victim *Dirent) error {
+ if err := checkSticky(ctx, dir, victim); err != nil {
+ return err
+ }
+
+ if victim.IsRoot() {
+ return syserror.EBUSY
+ }
+
+ return nil
+}
+
+// Rename atomically converts the child of oldParent named oldName to a
+// child of newParent named newName.
+func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string, newParent *Dirent, newName string) error {
+ if root == nil {
+ panic("Rename: root must not be nil")
+ }
+ if oldParent == newParent && oldName == newName {
+ return nil
+ }
+
+ // Acquire global renameMu lock, and mu locks on oldParent/newParent.
+ unlock, err := lockForRename(oldParent, oldName, newParent, newName)
+ defer unlock()
+ if err != nil {
+ return err
+ }
+
+ // Are we frozen?
+ // TODO(jamieliu): Is this the right errno?
+ if oldParent.frozen && !oldParent.Inode.IsVirtual() {
+ return syscall.ENOENT
+ }
+ if newParent.frozen && !newParent.Inode.IsVirtual() {
+ return syscall.ENOENT
+ }
+
+ // Do we have general permission to remove from oldParent and
+ // create/replace in newParent?
+ if err := oldParent.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+ if err := newParent.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ // renamed is the dirent that will be renamed to something else.
+ renamed, err := oldParent.walk(ctx, root, oldName, false /* may unlock */)
+ if err != nil {
+ return err
+ }
+ defer renamed.DecRef()
+
+ // Check that the renamed dirent is deletable.
+ if err := mayDelete(ctx, oldParent, renamed); err != nil {
+ return err
+ }
+
+ // Check that the renamed dirent is not a mount point.
+ if renamed.isMountPointLocked() {
+ return syscall.EBUSY
+ }
+
+ // Source should not be an ancestor of the target.
+ if newParent.descendantOf(renamed) {
+ return syscall.EINVAL
+ }
+
+ // Per rename(2): "... EACCES: ... or oldpath is a directory and does not
+ // allow write permission (needed to update the .. entry)."
+ if IsDir(renamed.Inode.StableAttr) {
+ if err := renamed.Inode.CheckPermission(ctx, PermMask{Write: true}); err != nil {
+ return err
+ }
+ }
+
+ // replaced is the dirent that is being overwritten by rename.
+ replaced, err := newParent.walk(ctx, root, newName, false /* may unlock */)
+ if err != nil {
+ if err != syserror.ENOENT {
+ return err
+ }
+
+ // newName doesn't exist; simply create it below.
+ replaced = nil
+ } else {
+ // Check constraints on the dirent being replaced.
+
+ // NOTE(b/111808347): We don't want to keep replaced alive
+ // across the Rename, so must call DecRef manually (no defer).
+
+ // Check that we can delete replaced.
+ if err := mayDelete(ctx, newParent, replaced); err != nil {
+ replaced.DecRef()
+ return err
+ }
+
+ // Target should not be an ancestor of source.
+ if oldParent.descendantOf(replaced) {
+ replaced.DecRef()
+
+ // Note that Linux returns EINVAL if the source is an
+ // ancestor of target, but ENOTEMPTY if the target is
+ // an ancestor of source (unless RENAME_EXCHANGE flag
+ // is present). See fs/namei.c:renameat2.
+ return syscall.ENOTEMPTY
+ }
+
+ // Check that replaced is not a mount point.
+ if replaced.isMountPointLocked() {
+ replaced.DecRef()
+ return syscall.EBUSY
+ }
+
+ // Require that a directory is replaced by a directory.
+ oldIsDir := IsDir(renamed.Inode.StableAttr)
+ newIsDir := IsDir(replaced.Inode.StableAttr)
+ if !newIsDir && oldIsDir {
+ replaced.DecRef()
+ return syscall.ENOTDIR
+ }
+ if !oldIsDir && newIsDir {
+ replaced.DecRef()
+ return syscall.EISDIR
+ }
+
+ // Allow the file system to drop extra references on replaced.
+ replaced.dropExtendedReference()
+
+ // NOTE(b/31798319,b/31867149,b/31867671): Keeping a dirent
+ // open across renames is currently broken for multiple
+ // reasons, so we flush all references on the replaced node and
+ // its children.
+ replaced.Inode.Watches.Unpin(replaced)
+ replaced.mu.Lock()
+ replaced.flush()
+ replaced.mu.Unlock()
+
+ // Done with replaced.
+ replaced.DecRef()
+ }
+
+ if err := renamed.Inode.Rename(ctx, oldParent, renamed, newParent, newName, replaced != nil); err != nil {
+ return err
+ }
+
+ renamed.name = newName
+ renamed.parent = newParent
+ if oldParent != newParent {
+ // Reparent the reference held by renamed.parent. oldParent.DecRef
+ // can't destroy oldParent (and try to retake its lock) because
+ // Rename's caller must be holding a reference.
+ newParent.IncRef()
+ oldParent.DecRef()
+ }
+ if w, ok := newParent.children[newName]; ok {
+ w.Drop()
+ delete(newParent.children, newName)
+ }
+ if w, ok := oldParent.children[oldName]; ok {
+ w.Drop()
+ delete(oldParent.children, oldName)
+ }
+
+ // Add a weak reference from the new parent. This ensures that the child
+ // can still be found from the new parent if a prior hard reference is
+ // held on renamed.
+ //
+ // This is required for file lock correctness because file locks are per-Dirent
+ // and without maintaining the a cached child (via a weak reference) for renamed,
+ // multiple Dirents can correspond to the same resource (by virtue of the renamed
+ // Dirent being unreachable by its parent and it being looked up).
+ newParent.children[newName] = refs.NewWeakRef(renamed, nil)
+
+ // Queue inotify events for the rename.
+ var ev uint32
+ if IsDir(renamed.Inode.StableAttr) {
+ ev |= linux.IN_ISDIR
+ }
+
+ cookie := uniqueid.InotifyCookie(ctx)
+ oldParent.Inode.Watches.Notify(oldName, ev|linux.IN_MOVED_FROM, cookie)
+ newParent.Inode.Watches.Notify(newName, ev|linux.IN_MOVED_TO, cookie)
+ // Somewhat surprisingly, self move events do not have a cookie.
+ renamed.Inode.Watches.Notify("", linux.IN_MOVE_SELF, 0)
+
+ // Allow the file system to drop extra references on renamed.
+ renamed.dropExtendedReference()
+
+ // Same as replaced.flush above.
+ renamed.mu.Lock()
+ renamed.flush()
+ renamed.mu.Unlock()
+
+ return nil
+}
diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go
new file mode 100644
index 000000000..71f2d11de
--- /dev/null
+++ b/pkg/sentry/fs/dirent_cache.go
@@ -0,0 +1,175 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "sync"
+)
+
+// DirentCache is an LRU cache of Dirents. The Dirent's refCount is
+// incremented when it is added to the cache, and decremented when it is
+// removed.
+//
+// A nil DirentCache corresponds to a cache with size 0. All methods can be
+// called, but nothing is actually cached.
+//
+// +stateify savable
+type DirentCache struct {
+ // Maximum size of the cache. This must be saved manually, to handle the case
+ // when cache is nil.
+ maxSize uint64
+
+ // limit restricts the number of entries in the cache amoung multiple caches.
+ // It may be nil if there are no global limit for this cache.
+ limit *DirentCacheLimiter
+
+ // mu protects currentSize and direntList.
+ mu sync.Mutex `state:"nosave"`
+
+ // currentSize is the number of elements in the cache. It must be zero (i.e.
+ // the cache must be empty) on Save.
+ currentSize uint64 `state:"zerovalue"`
+
+ // list is a direntList, an ilist of Dirents. New Dirents are added
+ // to the front of the list. Old Dirents are removed from the back of
+ // the list. It must be zerovalue (i.e. the cache must be empty) on Save.
+ list direntList `state:"zerovalue"`
+}
+
+// NewDirentCache returns a new DirentCache with the given maxSize.
+func NewDirentCache(maxSize uint64) *DirentCache {
+ return &DirentCache{
+ maxSize: maxSize,
+ }
+}
+
+// Add adds the element to the cache and increments the refCount. If the
+// argument is already in the cache, it is moved to the front. An element is
+// removed from the back if the cache is over capacity.
+func (c *DirentCache) Add(d *Dirent) {
+ if c == nil || c.maxSize == 0 {
+ return
+ }
+
+ c.mu.Lock()
+ if c.contains(d) {
+ // d is already in cache. Bump it to the front.
+ // currentSize and refCount are unaffected.
+ c.list.Remove(d)
+ c.list.PushFront(d)
+ c.mu.Unlock()
+ return
+ }
+
+ // First check against the global limit.
+ for c.limit != nil && !c.limit.tryInc() {
+ if c.currentSize == 0 {
+ // If the global limit is reached, but there is nothing more to drop from
+ // this cache, there is not much else to do.
+ c.mu.Unlock()
+ return
+ }
+ c.remove(c.list.Back())
+ }
+
+ // d is not in cache. Add it and take a reference.
+ c.list.PushFront(d)
+ d.IncRef()
+ c.currentSize++
+
+ c.maybeShrink()
+
+ c.mu.Unlock()
+}
+
+func (c *DirentCache) remove(d *Dirent) {
+ if !c.contains(d) {
+ panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d))
+ }
+ c.list.Remove(d)
+ d.SetPrev(nil)
+ d.SetNext(nil)
+ d.DecRef()
+ c.currentSize--
+ if c.limit != nil {
+ c.limit.dec()
+ }
+}
+
+// Remove removes the element from the cache and decrements its refCount. It
+// also sets the previous and next elements to nil, which allows us to
+// determine if a given element is in the cache.
+func (c *DirentCache) Remove(d *Dirent) {
+ if c == nil || c.maxSize == 0 {
+ return
+ }
+ c.mu.Lock()
+ if !c.contains(d) {
+ c.mu.Unlock()
+ return
+ }
+ c.remove(d)
+ c.mu.Unlock()
+}
+
+// Size returns the number of elements in the cache.
+func (c *DirentCache) Size() uint64 {
+ if c == nil {
+ return 0
+ }
+ c.mu.Lock()
+ size := c.currentSize
+ c.mu.Unlock()
+ return size
+}
+
+func (c *DirentCache) contains(d *Dirent) bool {
+ // If d has a Prev or Next element, then it is in the cache.
+ if d.Prev() != nil || d.Next() != nil {
+ return true
+ }
+ // Otherwise, d is in the cache if it is the only element (and thus the
+ // first element).
+ return c.list.Front() == d
+}
+
+// Invalidate removes all Dirents from the cache, caling DecRef on each.
+func (c *DirentCache) Invalidate() {
+ if c == nil {
+ return
+ }
+ c.mu.Lock()
+ for c.list.Front() != nil {
+ c.remove(c.list.Front())
+ }
+ c.mu.Unlock()
+}
+
+// setMaxSize sets cache max size. If current size is larger than max size, the
+// cache shrinks to acommodate the new max.
+func (c *DirentCache) setMaxSize(max uint64) {
+ c.mu.Lock()
+ c.maxSize = max
+ c.maybeShrink()
+ c.mu.Unlock()
+}
+
+// shrink removes the oldest element until the list is under the size limit.
+func (c *DirentCache) maybeShrink() {
+ for c.maxSize > 0 && c.currentSize > c.maxSize {
+ c.remove(c.list.Back())
+ }
+}
diff --git a/pkg/sentry/fs/dirent_cache_limiter.go b/pkg/sentry/fs/dirent_cache_limiter.go
new file mode 100644
index 000000000..ebb80bd50
--- /dev/null
+++ b/pkg/sentry/fs/dirent_cache_limiter.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "sync"
+)
+
+// DirentCacheLimiter acts as a global limit for all dirent caches in the
+// process.
+//
+// +stateify savable
+type DirentCacheLimiter struct {
+ mu sync.Mutex `state:"nosave"`
+ max uint64
+ count uint64 `state:"zerovalue"`
+}
+
+// NewDirentCacheLimiter creates a new DirentCacheLimiter.
+func NewDirentCacheLimiter(max uint64) *DirentCacheLimiter {
+ return &DirentCacheLimiter{max: max}
+}
+
+func (d *DirentCacheLimiter) tryInc() bool {
+ d.mu.Lock()
+ if d.count >= d.max {
+ d.mu.Unlock()
+ return false
+ }
+ d.count++
+ d.mu.Unlock()
+ return true
+}
+
+func (d *DirentCacheLimiter) dec() {
+ d.mu.Lock()
+ if d.count == 0 {
+ panic(fmt.Sprintf("underflowing DirentCacheLimiter count: %+v", d))
+ }
+ d.count--
+ d.mu.Unlock()
+}
diff --git a/pkg/sentry/fs/dirent_list.go b/pkg/sentry/fs/dirent_list.go
new file mode 100755
index 000000000..750961a48
--- /dev/null
+++ b/pkg/sentry/fs/dirent_list.go
@@ -0,0 +1,173 @@
+package fs
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type direntElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (direntElementMapper) linkerFor(elem *Dirent) *Dirent { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type direntList struct {
+ head *Dirent
+ tail *Dirent
+}
+
+// Reset resets list l to the empty state.
+func (l *direntList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *direntList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *direntList) Front() *Dirent {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *direntList) Back() *Dirent {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *direntList) PushFront(e *Dirent) {
+ direntElementMapper{}.linkerFor(e).SetNext(l.head)
+ direntElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ direntElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *direntList) PushBack(e *Dirent) {
+ direntElementMapper{}.linkerFor(e).SetNext(nil)
+ direntElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ direntElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *direntList) PushBackList(m *direntList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ direntElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ direntElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *direntList) InsertAfter(b, e *Dirent) {
+ a := direntElementMapper{}.linkerFor(b).Next()
+ direntElementMapper{}.linkerFor(e).SetNext(a)
+ direntElementMapper{}.linkerFor(e).SetPrev(b)
+ direntElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ direntElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *direntList) InsertBefore(a, e *Dirent) {
+ b := direntElementMapper{}.linkerFor(a).Prev()
+ direntElementMapper{}.linkerFor(e).SetNext(a)
+ direntElementMapper{}.linkerFor(e).SetPrev(b)
+ direntElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ direntElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *direntList) Remove(e *Dirent) {
+ prev := direntElementMapper{}.linkerFor(e).Prev()
+ next := direntElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ direntElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ direntElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type direntEntry struct {
+ next *Dirent
+ prev *Dirent
+}
+
+// Next returns the entry that follows e in the list.
+func (e *direntEntry) Next() *Dirent {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *direntEntry) Prev() *Dirent {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *direntEntry) SetNext(elem *Dirent) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *direntEntry) SetPrev(elem *Dirent) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/fs/dirent_state.go b/pkg/sentry/fs/dirent_state.go
new file mode 100644
index 000000000..18652b809
--- /dev/null
+++ b/pkg/sentry/fs/dirent_state.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+)
+
+// beforeSave is invoked by stateify.
+func (d *Dirent) beforeSave() {
+ // Refuse to save if the file is on a non-virtual file system and has
+ // already been deleted (but still has open fds, which is why the Dirent
+ // is still accessible). We know the the restore re-opening of the file
+ // will always fail. This condition will last until all the open fds and
+ // this Dirent are closed and released.
+ //
+ // Such "dangling" open files on virtual file systems (e.g., tmpfs) is
+ // OK to save as their restore does not require re-opening the files.
+ //
+ // Note that this is rejection rather than failure---it would be
+ // perfectly OK to save---we are simply disallowing it here to prevent
+ // generating non-restorable state dumps. As the program continues its
+ // execution, it may become allowed to save again.
+ if !d.Inode.IsVirtual() && atomic.LoadInt32(&d.deleted) != 0 {
+ n, _ := d.FullName(nil /* root */)
+ panic(ErrSaveRejection{fmt.Errorf("deleted file %q still has open fds", n)})
+ }
+}
+
+// saveChildren is invoked by stateify.
+func (d *Dirent) saveChildren() map[string]*Dirent {
+ c := make(map[string]*Dirent)
+ for name, w := range d.children {
+ if rc := w.Get(); rc != nil {
+ // Drop the reference count obtain in w.Get()
+ rc.DecRef()
+
+ cd := rc.(*Dirent)
+ if cd.IsNegative() {
+ // Don't bother saving negative Dirents.
+ continue
+ }
+ c[name] = cd
+ }
+ }
+ return c
+}
+
+// loadChildren is invoked by stateify.
+func (d *Dirent) loadChildren(children map[string]*Dirent) {
+ d.children = make(map[string]*refs.WeakRef)
+ for name, c := range children {
+ d.children[name] = refs.NewWeakRef(c, nil)
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *Dirent) afterLoad() {
+ if d.userVisible {
+ allDirents.add(d)
+ }
+}
diff --git a/pkg/sentry/fs/event_list.go b/pkg/sentry/fs/event_list.go
new file mode 100755
index 000000000..c94cb03e1
--- /dev/null
+++ b/pkg/sentry/fs/event_list.go
@@ -0,0 +1,173 @@
+package fs
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type eventElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (eventElementMapper) linkerFor(elem *Event) *Event { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type eventList struct {
+ head *Event
+ tail *Event
+}
+
+// Reset resets list l to the empty state.
+func (l *eventList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *eventList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *eventList) Front() *Event {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *eventList) Back() *Event {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *eventList) PushFront(e *Event) {
+ eventElementMapper{}.linkerFor(e).SetNext(l.head)
+ eventElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ eventElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *eventList) PushBack(e *Event) {
+ eventElementMapper{}.linkerFor(e).SetNext(nil)
+ eventElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ eventElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *eventList) PushBackList(m *eventList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ eventElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ eventElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *eventList) InsertAfter(b, e *Event) {
+ a := eventElementMapper{}.linkerFor(b).Next()
+ eventElementMapper{}.linkerFor(e).SetNext(a)
+ eventElementMapper{}.linkerFor(e).SetPrev(b)
+ eventElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ eventElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *eventList) InsertBefore(a, e *Event) {
+ b := eventElementMapper{}.linkerFor(a).Prev()
+ eventElementMapper{}.linkerFor(e).SetNext(a)
+ eventElementMapper{}.linkerFor(e).SetPrev(b)
+ eventElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ eventElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *eventList) Remove(e *Event) {
+ prev := eventElementMapper{}.linkerFor(e).Prev()
+ next := eventElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ eventElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ eventElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type eventEntry struct {
+ next *Event
+ prev *Event
+}
+
+// Next returns the entry that follows e in the list.
+func (e *eventEntry) Next() *Event {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *eventEntry) Prev() *Event {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *eventEntry) SetNext(elem *Event) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *eventEntry) SetPrev(elem *Event) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go b/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go
new file mode 100755
index 000000000..46192664c
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go
@@ -0,0 +1,27 @@
+// automatically generated by stateify.
+
+package fdpipe
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+func (x *pipeOperations) save(m state.Map) {
+ x.beforeSave()
+ var flags fs.FileFlags = x.saveFlags()
+ m.SaveValue("flags", flags)
+ m.Save("opener", &x.opener)
+ m.Save("readAheadBuffer", &x.readAheadBuffer)
+}
+
+func (x *pipeOperations) load(m state.Map) {
+ m.LoadWait("opener", &x.opener)
+ m.Load("readAheadBuffer", &x.readAheadBuffer)
+ m.LoadValue("flags", new(fs.FileFlags), func(y interface{}) { x.loadFlags(y.(fs.FileFlags)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func init() {
+ state.Register("fdpipe.pipeOperations", (*pipeOperations)(nil), state.Fns{Save: (*pipeOperations).save, Load: (*pipeOperations).load})
+}
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
new file mode 100644
index 000000000..4ef7ea08a
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -0,0 +1,168 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fdpipe implements common namedpipe opening and accessing logic.
+package fdpipe
+
+import (
+ "os"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/fdnotifier"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/secio"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// pipeOperations are the fs.FileOperations of a host pipe.
+//
+// +stateify savable
+type pipeOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.Queue `state:"nosave"`
+
+ // flags are the flags used to open the pipe.
+ flags fs.FileFlags `state:".(fs.FileFlags)"`
+
+ // opener is how the pipe was opened.
+ opener NonBlockingOpener `state:"wait"`
+
+ // file represents the host pipe.
+ file *fd.FD `state:"nosave"`
+
+ // mu protects readAheadBuffer access below.
+ mu sync.Mutex `state:"nosave"`
+
+ // readAheadBuffer contains read bytes that have not yet been read
+ // by the application but need to be buffered for save-restore for correct
+ // opening semantics. The readAheadBuffer will only be non-empty when the
+ // is first opened and will be drained by subsequent reads on the pipe.
+ readAheadBuffer []byte
+}
+
+// newPipeOperations returns an implementation of fs.FileOperations for a pipe.
+func newPipeOperations(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags, file *fd.FD, readAheadBuffer []byte) (*pipeOperations, error) {
+ pipeOps := &pipeOperations{
+ flags: flags,
+ opener: opener,
+ file: file,
+ readAheadBuffer: readAheadBuffer,
+ }
+ if err := pipeOps.init(); err != nil {
+ return nil, err
+ }
+ return pipeOps, nil
+}
+
+// init initializes p.file.
+func (p *pipeOperations) init() error {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(p.file.FD(), &s); err != nil {
+ log.Warningf("pipe: cannot stat fd %d: %v", p.file.FD(), err)
+ return syscall.EINVAL
+ }
+ if s.Mode&syscall.S_IFIFO != syscall.S_IFIFO {
+ log.Warningf("pipe: cannot load fd %d as pipe, file type: %o", p.file.FD(), s.Mode)
+ return syscall.EINVAL
+ }
+ if err := syscall.SetNonblock(p.file.FD(), true); err != nil {
+ return err
+ }
+ return fdnotifier.AddFD(int32(p.file.FD()), &p.Queue)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (p *pipeOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ p.Queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(p.file.FD()))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (p *pipeOperations) EventUnregister(e *waiter.Entry) {
+ p.Queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(p.file.FD()))
+}
+
+// Readiness returns a mask of ready events for stream.
+func (p *pipeOperations) Readiness(mask waiter.EventMask) (eventMask waiter.EventMask) {
+ return fdnotifier.NonBlockingPoll(int32(p.file.FD()), mask)
+}
+
+// Release implements fs.FileOperations.Release.
+func (p *pipeOperations) Release() {
+ fdnotifier.RemoveFD(int32(p.file.FD()))
+ p.file.Close()
+ p.file = nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (p *pipeOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ // Drain the read ahead buffer, if it contains anything first.
+ var bufN int
+ var bufErr error
+ p.mu.Lock()
+ if len(p.readAheadBuffer) > 0 {
+ bufN, bufErr = dst.CopyOut(ctx, p.readAheadBuffer)
+ p.readAheadBuffer = p.readAheadBuffer[bufN:]
+ dst = dst.DropFirst(bufN)
+ }
+ p.mu.Unlock()
+ if dst.NumBytes() == 0 || bufErr != nil {
+ return int64(bufN), bufErr
+ }
+
+ // Pipes expect full reads.
+ n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{secio.FullReader{p.file}})
+ total := int64(bufN) + n
+ if err != nil && isBlockError(err) {
+ return total, syserror.ErrWouldBlock
+ }
+ return total, err
+}
+
+// Write implements fs.FileOperations.Write.
+func (p *pipeOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ n, err := src.CopyInTo(ctx, safemem.FromIOWriter{p.file})
+ if err != nil && isBlockError(err) {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
+// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
+func isBlockError(err error) bool {
+ if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK {
+ return true
+ }
+ if pe, ok := err.(*os.PathError); ok {
+ return isBlockError(pe.Err)
+ }
+ return false
+}
diff --git a/pkg/sentry/fs/fdpipe/pipe_opener.go b/pkg/sentry/fs/fdpipe/pipe_opener.go
new file mode 100644
index 000000000..0cabe2e18
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/pipe_opener.go
@@ -0,0 +1,193 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdpipe
+
+import (
+ "io"
+ "os"
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// NonBlockingOpener is a generic host file opener used to retry opening host
+// pipes if necessary.
+type NonBlockingOpener interface {
+ // NonBlockingOpen tries to open a host pipe in a non-blocking way,
+ // and otherwise returns an error. Implementations should be idempotent.
+ NonBlockingOpen(context.Context, fs.PermMask) (*fd.FD, error)
+}
+
+// Open blocks until a host pipe can be opened or the action was cancelled.
+// On success, returns fs.FileOperations wrapping the opened host pipe.
+func Open(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags) (fs.FileOperations, error) {
+ p := &pipeOpenState{}
+ canceled := false
+ for {
+ if file, err := p.TryOpen(ctx, opener, flags); err != syserror.ErrWouldBlock {
+ return file, err
+ }
+
+ // Honor the cancellation request if open still blocks.
+ if canceled {
+ // If we were canceled but we have a handle to a host
+ // file, we need to close it.
+ if p.hostFile != nil {
+ p.hostFile.Close()
+ }
+ return nil, syserror.ErrInterrupted
+ }
+
+ cancel := ctx.SleepStart()
+ select {
+ case <-cancel:
+ // The cancellation request received here really says
+ // "cancel from now on (or ASAP)". Any environmental
+ // changes happened before receiving it, that might have
+ // caused open to not block anymore, should still be
+ // respected. So we cannot just return here. We have to
+ // give open another try below first.
+ canceled = true
+ ctx.SleepFinish(false)
+ case <-time.After(100 * time.Millisecond):
+ // If we would block, then delay retrying for a bit, since there
+ // is no way to know when the pipe would be ready to be
+ // re-opened. This is identical to sending an event notification
+ // to stop blocking in Task.Block, given that this routine will
+ // stop retrying if a cancelation is received.
+ ctx.SleepFinish(true)
+ }
+ }
+}
+
+// pipeOpenState holds state needed to open a blocking named pipe read only, for instance the
+// file that has been opened but doesn't yet have a corresponding writer.
+type pipeOpenState struct {
+ // hostFile is the read only named pipe which lacks a corresponding writer.
+ hostFile *fd.FD
+}
+
+// unwrapError is needed to match against ENXIO primarily.
+func unwrapError(err error) error {
+ if pe, ok := err.(*os.PathError); ok {
+ return pe.Err
+ }
+ return err
+}
+
+// TryOpen uses a NonBlockingOpener to try to open a host pipe, respecting the fs.FileFlags.
+func (p *pipeOpenState) TryOpen(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags) (*pipeOperations, error) {
+ switch {
+ // Reject invalid configurations so they don't accidentally succeed below.
+ case !flags.Read && !flags.Write:
+ return nil, syscall.EINVAL
+
+ // Handle opening RDWR or with O_NONBLOCK: will never block, so try only once.
+ case (flags.Read && flags.Write) || flags.NonBlocking:
+ f, err := opener.NonBlockingOpen(ctx, fs.PermMask{Read: flags.Read, Write: flags.Write})
+ if err != nil {
+ return nil, err
+ }
+ return newPipeOperations(ctx, opener, flags, f, nil)
+
+ // Handle opening O_WRONLY blocking: convert ENXIO to syserror.ErrWouldBlock.
+ // See TryOpenWriteOnly for more details.
+ case flags.Write:
+ return p.TryOpenWriteOnly(ctx, opener)
+
+ default:
+ // Handle opening O_RDONLY blocking: convert EOF from read to syserror.ErrWouldBlock.
+ // See TryOpenReadOnly for more details.
+ return p.TryOpenReadOnly(ctx, opener)
+ }
+}
+
+// TryOpenReadOnly tries to open a host pipe read only but only returns a fs.File when
+// there is a coordinating writer. Call TryOpenReadOnly repeatedly on the same pipeOpenState
+// until syserror.ErrWouldBlock is no longer returned.
+//
+// How it works:
+//
+// Opening a pipe read only will return no error, but each non zero Read will return EOF
+// until a writer becomes available, then EWOULDBLOCK. This is the only state change
+// available to us. We keep a read ahead buffer in case we read bytes instead of getting
+// EWOULDBLOCK, to be read from on the first read request to this fs.File.
+func (p *pipeOpenState) TryOpenReadOnly(ctx context.Context, opener NonBlockingOpener) (*pipeOperations, error) {
+ // Waiting for a blocking read only open involves reading from the host pipe until
+ // bytes or other writers are available, so instead of retrying opening the pipe,
+ // it's necessary to retry reading from the pipe. To do this we need to keep around
+ // the read only pipe we opened, until success or an irrecoverable read error (at
+ // which point it must be closed).
+ if p.hostFile == nil {
+ var err error
+ p.hostFile, err = opener.NonBlockingOpen(ctx, fs.PermMask{Read: true})
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Try to read from the pipe to see if writers are around.
+ tryReadBuffer := make([]byte, 1)
+ n, rerr := p.hostFile.Read(tryReadBuffer)
+
+ // No bytes were read.
+ if n == 0 {
+ // EOF means that we're not ready yet.
+ if rerr == nil || rerr == io.EOF {
+ return nil, syserror.ErrWouldBlock
+ }
+ // Any error that is not EWOULDBLOCK also means we're not
+ // ready yet, and probably never will be ready. In this
+ // case we need to close the host pipe we opened.
+ if unwrapError(rerr) != syscall.EWOULDBLOCK {
+ p.hostFile.Close()
+ return nil, rerr
+ }
+ }
+
+ // If any bytes were read, no matter the corresponding error, we need
+ // to keep them around so they can be read by the application.
+ var readAheadBuffer []byte
+ if n > 0 {
+ readAheadBuffer = tryReadBuffer
+ }
+
+ // Successfully opened read only blocking pipe with either bytes available
+ // to read and/or a writer available.
+ return newPipeOperations(ctx, opener, fs.FileFlags{Read: true}, p.hostFile, readAheadBuffer)
+}
+
+// TryOpenWriteOnly tries to open a host pipe write only but only returns a fs.File when
+// there is a coordinating reader. Call TryOpenWriteOnly repeatedly on the same pipeOpenState
+// until syserror.ErrWouldBlock is no longer returned.
+//
+// How it works:
+//
+// Opening a pipe write only will return ENXIO until readers are available. Converts the ENXIO
+// to an syserror.ErrWouldBlock, to tell callers to retry.
+func (*pipeOpenState) TryOpenWriteOnly(ctx context.Context, opener NonBlockingOpener) (*pipeOperations, error) {
+ hostFile, err := opener.NonBlockingOpen(ctx, fs.PermMask{Write: true})
+ if unwrapError(err) == syscall.ENXIO {
+ return nil, syserror.ErrWouldBlock
+ }
+ if err != nil {
+ return nil, err
+ }
+ return newPipeOperations(ctx, opener, fs.FileFlags{Write: true}, hostFile, nil)
+}
diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go
new file mode 100644
index 000000000..8b347aa11
--- /dev/null
+++ b/pkg/sentry/fs/fdpipe/pipe_state.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdpipe
+
+import (
+ "fmt"
+ "io/ioutil"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// beforeSave is invoked by stateify.
+func (p *pipeOperations) beforeSave() {
+ if p.flags.Read {
+ data, err := ioutil.ReadAll(p.file)
+ if err != nil && !isBlockError(err) {
+ panic(fmt.Sprintf("failed to read from pipe: %v", err))
+ }
+ p.readAheadBuffer = append(p.readAheadBuffer, data...)
+ } else if p.flags.Write {
+ file, err := p.opener.NonBlockingOpen(context.Background(), fs.PermMask{Write: true})
+ if err != nil {
+ panic(fs.ErrSaveRejection{fmt.Errorf("write-only pipe end cannot be re-opened as %v: %v", p, err)})
+ }
+ file.Close()
+ }
+}
+
+// saveFlags is invoked by stateify.
+func (p *pipeOperations) saveFlags() fs.FileFlags {
+ return p.flags
+}
+
+// readPipeOperationsLoading is used to ensure that write-only pipe fds are
+// opened after read/write and read-only pipe fds, to avoid ENXIO when
+// multiple pipe fds refer to different ends of the same pipe.
+var readPipeOperationsLoading sync.WaitGroup
+
+// loadFlags is invoked by stateify.
+func (p *pipeOperations) loadFlags(flags fs.FileFlags) {
+ // This is a hack to ensure that readPipeOperationsLoading includes all
+ // readable pipe fds before any asynchronous calls to
+ // readPipeOperationsLoading.Wait().
+ if flags.Read {
+ readPipeOperationsLoading.Add(1)
+ }
+ p.flags = flags
+}
+
+// afterLoad is invoked by stateify.
+func (p *pipeOperations) afterLoad() {
+ load := func() error {
+ if !p.flags.Read {
+ readPipeOperationsLoading.Wait()
+ } else {
+ defer readPipeOperationsLoading.Done()
+ }
+ var err error
+ p.file, err = p.opener.NonBlockingOpen(context.Background(), fs.PermMask{
+ Read: p.flags.Read,
+ Write: p.flags.Write,
+ })
+ if err != nil {
+ return fmt.Errorf("unable to open pipe %v: %v", p, err)
+ }
+ if err := p.init(); err != nil {
+ return fmt.Errorf("unable to initialize pipe %v: %v", p, err)
+ }
+ return nil
+ }
+
+ // Do background opening of pipe ends. Note for write-only pipe ends we
+ // have to do it asynchronously to avoid blocking the restore.
+ fs.Async(fs.CatchError(load))
+}
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
new file mode 100644
index 000000000..8c1307235
--- /dev/null
+++ b/pkg/sentry/fs/file.go
@@ -0,0 +1,556 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/amutex"
+ "gvisor.googlesource.com/gvisor/pkg/metric"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/lock"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+var (
+ // RecordWaitTime controls writing metrics for filesystem reads.
+ // Enabling this comes at a small CPU cost due to performing two
+ // monotonic clock reads per read call.
+ //
+ // Note that this is only performed in the direct read path, and may
+ // not be consistently applied for other forms of reads, such as
+ // splice.
+ RecordWaitTime = false
+
+ reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.")
+ readWait = metric.MustCreateNewUint64Metric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.")
+)
+
+// IncrementWait increments the given wait time metric, if enabled.
+func IncrementWait(m *metric.Uint64Metric, start time.Time) {
+ if !RecordWaitTime {
+ return
+ }
+ m.IncrementBy(uint64(time.Since(start)))
+}
+
+// FileMaxOffset is the maximum possible file offset.
+const FileMaxOffset = math.MaxInt64
+
+// File is an open file handle. It is thread-safe.
+//
+// File provides stronger synchronization guarantees than Linux. Linux
+// synchronizes lseek(2), read(2), and write(2) with respect to the file
+// offset for regular files and only for those interfaces. See
+// fs/read_write.c:fdget_pos, fs.read_write.c:fdput_pos and FMODE_ATOMIC_POS.
+//
+// In contrast, File synchronizes any operation that could take a long time
+// under a single abortable mutex which also synchronizes lseek(2), read(2),
+// and write(2).
+//
+// FIXME(b/38451980): Split synchronization from cancellation.
+//
+// +stateify savable
+type File struct {
+ refs.AtomicRefCount
+
+ // UniqueID is the globally unique identifier of the File.
+ UniqueID uint64
+
+ // Dirent is the Dirent backing this File. This encodes the name
+ // of the File via Dirent.FullName() as well as its identity via the
+ // Dirent's Inode. The Dirent is non-nil.
+ //
+ // A File holds a reference to this Dirent. Using the returned Dirent is
+ // only safe as long as a reference on the File is held. The association
+ // between a File and a Dirent is immutable.
+ //
+ // Files that are not parented in a filesystem return a root Dirent
+ // that holds a reference to their Inode.
+ //
+ // The name of the Dirent may reflect parentage if the Dirent is not a
+ // root Dirent or the identity of the File on a pseudo filesystem (pipefs,
+ // sockfs, etc).
+ //
+ // Multiple Files may hold a reference to the same Dirent. This is the
+ // common case for Files that are parented and maintain consistency with
+ // other files via the Dirent cache.
+ Dirent *Dirent
+
+ // flagsMu protects flags and async below.
+ flagsMu sync.Mutex `state:"nosave"`
+
+ // flags are the File's flags. Setting or getting flags is fully atomic
+ // and is not protected by mu (below).
+ flags FileFlags
+
+ // async handles O_ASYNC notifications.
+ async FileAsync
+
+ // saving indicates that this file is in the process of being saved.
+ saving bool `state:"nosave"`
+
+ // mu is dual-purpose: first, to make read(2) and write(2) thread-safe
+ // in conformity with POSIX, and second, to cancel operations before they
+ // begin in response to interruptions (i.e. signals).
+ mu amutex.AbortableMutex `state:"nosave"`
+
+ // FileOperations implements file system specific behavior for this File.
+ FileOperations FileOperations `state:"wait"`
+
+ // offset is the File's offset. Updating offset is protected by mu but
+ // can be read atomically via File.Offset() outside of mu.
+ offset int64
+}
+
+// NewFile returns a File. It takes a reference on the Dirent and owns the
+// lifetime of the FileOperations. Files that do not support reading and
+// writing at an arbitrary offset should set flags.Pread and flags.Pwrite
+// to false respectively.
+func NewFile(ctx context.Context, dirent *Dirent, flags FileFlags, fops FileOperations) *File {
+ dirent.IncRef()
+ f := &File{
+ UniqueID: uniqueid.GlobalFromContext(ctx),
+ Dirent: dirent,
+ FileOperations: fops,
+ flags: flags,
+ }
+ f.mu.Init()
+ return f
+}
+
+// DecRef destroys the File when it is no longer referenced.
+func (f *File) DecRef() {
+ f.DecRefWithDestructor(func() {
+ // Drop BSD style locks.
+ lockRng := lock.LockRange{Start: 0, End: lock.LockEOF}
+ f.Dirent.Inode.LockCtx.BSD.UnlockRegion(lock.UniqueID(f.UniqueID), lockRng)
+
+ // Release resources held by the FileOperations.
+ f.FileOperations.Release()
+
+ // Release a reference on the Dirent.
+ f.Dirent.DecRef()
+
+ // Only unregister if we are currently registered. There is nothing
+ // to register if f.async is nil (this happens when async mode is
+ // enabled without setting an owner). Also, we unregister during
+ // save.
+ f.flagsMu.Lock()
+ if !f.saving && f.flags.Async && f.async != nil {
+ f.async.Unregister(f)
+ }
+ f.async = nil
+ f.flagsMu.Unlock()
+ })
+}
+
+// Flags atomically loads the File's flags.
+func (f *File) Flags() FileFlags {
+ f.flagsMu.Lock()
+ flags := f.flags
+ f.flagsMu.Unlock()
+ return flags
+}
+
+// SetFlags atomically changes the File's flags to the values contained
+// in newFlags. See SettableFileFlags for values that can be set.
+func (f *File) SetFlags(newFlags SettableFileFlags) {
+ f.flagsMu.Lock()
+ f.flags.Direct = newFlags.Direct
+ f.flags.NonBlocking = newFlags.NonBlocking
+ f.flags.Append = newFlags.Append
+ if f.async != nil {
+ if newFlags.Async && !f.flags.Async {
+ f.async.Register(f)
+ }
+ if !newFlags.Async && f.flags.Async {
+ f.async.Unregister(f)
+ }
+ }
+ f.flags.Async = newFlags.Async
+ f.flagsMu.Unlock()
+}
+
+// Offset atomically loads the File's offset.
+func (f *File) Offset() int64 {
+ return atomic.LoadInt64(&f.offset)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (f *File) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return f.FileOperations.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (f *File) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ f.FileOperations.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (f *File) EventUnregister(e *waiter.Entry) {
+ f.FileOperations.EventUnregister(e)
+}
+
+// Seek calls f.FileOperations.Seek with f as the File, updating the file
+// offset to the value returned by f.FileOperations.Seek if the operation
+// is successful.
+//
+// Returns syserror.ErrInterrupted if seeking was interrupted.
+func (f *File) Seek(ctx context.Context, whence SeekWhence, offset int64) (int64, error) {
+ if !f.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ newOffset, err := f.FileOperations.Seek(ctx, f, whence, offset)
+ if err == nil {
+ atomic.StoreInt64(&f.offset, newOffset)
+ }
+ return newOffset, err
+}
+
+// Readdir reads the directory entries of this File and writes them out
+// to the DentrySerializer until entries can no longer be written. If even
+// a single directory entry is written then Readdir returns a nil error
+// and the directory offset is advanced.
+//
+// Readdir unconditionally updates the access time on the File's Inode,
+// see fs/readdir.c:iterate_dir.
+//
+// Returns syserror.ErrInterrupted if reading was interrupted.
+func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error {
+ if !f.mu.Lock(ctx) {
+ return syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ offset, err := f.FileOperations.Readdir(ctx, f, serializer)
+ atomic.StoreInt64(&f.offset, offset)
+ return err
+}
+
+// Readv calls f.FileOperations.Read with f as the File, advancing the file
+// offset if f.FileOperations.Read returns bytes read > 0.
+//
+// Returns syserror.ErrInterrupted if reading was interrupted.
+func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ var start time.Time
+ if RecordWaitTime {
+ start = time.Now()
+ }
+ if !f.mu.Lock(ctx) {
+ IncrementWait(readWait, start)
+ return 0, syserror.ErrInterrupted
+ }
+
+ reads.Increment()
+ n, err := f.FileOperations.Read(ctx, f, dst, f.offset)
+ if n > 0 {
+ atomic.AddInt64(&f.offset, n)
+ }
+ f.mu.Unlock()
+ IncrementWait(readWait, start)
+ return n, err
+}
+
+// Preadv calls f.FileOperations.Read with f as the File. It does not
+// advance the file offset. If !f.Flags().Pread, Preadv should not be
+// called.
+//
+// Otherwise same as Readv.
+func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var start time.Time
+ if RecordWaitTime {
+ start = time.Now()
+ }
+ if !f.mu.Lock(ctx) {
+ IncrementWait(readWait, start)
+ return 0, syserror.ErrInterrupted
+ }
+
+ reads.Increment()
+ n, err := f.FileOperations.Read(ctx, f, dst, offset)
+ f.mu.Unlock()
+ IncrementWait(readWait, start)
+ return n, err
+}
+
+// Writev calls f.FileOperations.Write with f as the File, advancing the
+// file offset if f.FileOperations.Write returns bytes written > 0.
+//
+// Writev positions the write offset at EOF if f.Flags().Append. This is
+// unavoidably racy for network file systems. Writev also truncates src
+// to avoid overrunning the current file size limit if necessary.
+//
+// Returns syserror.ErrInterrupted if writing was interrupted.
+func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ if !f.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+
+ // Handle append mode.
+ if f.Flags().Append {
+ if err := f.offsetForAppend(ctx, &f.offset); err != nil {
+ f.mu.Unlock()
+ return 0, err
+ }
+ }
+
+ // Enforce file limits.
+ limit, ok := f.checkLimit(ctx, f.offset)
+ switch {
+ case ok && limit == 0:
+ f.mu.Unlock()
+ return 0, syserror.ErrExceedsFileSizeLimit
+ case ok:
+ src = src.TakeFirst64(limit)
+ }
+
+ // We must hold the lock during the write.
+ n, err := f.FileOperations.Write(ctx, f, src, f.offset)
+ if n >= 0 {
+ atomic.StoreInt64(&f.offset, f.offset+n)
+ }
+ f.mu.Unlock()
+ return n, err
+}
+
+// Pwritev calls f.FileOperations.Write with f as the File. It does not
+// advance the file offset. If !f.Flags().Pwritev, Pwritev should not be
+// called.
+//
+// Otherwise same as Writev.
+func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // "POSIX requires that opening a file with the O_APPEND flag should
+ // have no effect on the location at which pwrite() writes data.
+ // However, on Linux, if a file is opened with O_APPEND, pwrite()
+ // appends data to the end of the file, regardless of the value of
+ // offset."
+ if f.Flags().Append {
+ if !f.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+ if err := f.offsetForAppend(ctx, &offset); err != nil {
+ f.mu.Unlock()
+ return 0, err
+ }
+ }
+
+ // Enforce file limits.
+ limit, ok := f.checkLimit(ctx, offset)
+ switch {
+ case ok && limit == 0:
+ return 0, syserror.ErrExceedsFileSizeLimit
+ case ok:
+ src = src.TakeFirst64(limit)
+ }
+
+ return f.FileOperations.Write(ctx, f, src, offset)
+}
+
+// offsetForAppend sets the given offset to the end of the file.
+//
+// Precondition: the underlying file mutex should be held.
+func (f *File) offsetForAppend(ctx context.Context, offset *int64) error {
+ uattr, err := f.Dirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ // This is an odd error, we treat it as evidence that
+ // something is terribly wrong with the filesystem.
+ return syserror.EIO
+ }
+
+ // Update the offset.
+ *offset = uattr.Size
+
+ return nil
+}
+
+// checkLimit checks the offset that the write will be performed at. The
+// returned boolean indicates that the write must be limited. The returned
+// integer indicates the new maximum write length.
+func (f *File) checkLimit(ctx context.Context, offset int64) (int64, bool) {
+ if IsRegular(f.Dirent.Inode.StableAttr) {
+ // Enforce size limits.
+ fileSizeLimit := limits.FromContext(ctx).Get(limits.FileSize).Cur
+ if fileSizeLimit <= math.MaxInt64 {
+ if offset >= int64(fileSizeLimit) {
+ return 0, true
+ }
+ return int64(fileSizeLimit) - offset, true
+ }
+ }
+
+ return 0, false
+}
+
+// Fsync calls f.FileOperations.Fsync with f as the File.
+//
+// Returns syserror.ErrInterrupted if syncing was interrupted.
+func (f *File) Fsync(ctx context.Context, start int64, end int64, syncType SyncType) error {
+ if !f.mu.Lock(ctx) {
+ return syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ return f.FileOperations.Fsync(ctx, f, start, end, syncType)
+}
+
+// Flush calls f.FileOperations.Flush with f as the File.
+//
+// Returns syserror.ErrInterrupted if syncing was interrupted.
+func (f *File) Flush(ctx context.Context) error {
+ if !f.mu.Lock(ctx) {
+ return syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ return f.FileOperations.Flush(ctx, f)
+}
+
+// ConfigureMMap calls f.FileOperations.ConfigureMMap with f as the File.
+//
+// Returns syserror.ErrInterrupted if interrupted.
+func (f *File) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ if !f.mu.Lock(ctx) {
+ return syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ return f.FileOperations.ConfigureMMap(ctx, f, opts)
+}
+
+// UnstableAttr calls f.FileOperations.UnstableAttr with f as the File.
+//
+// Returns syserror.ErrInterrupted if interrupted.
+func (f *File) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
+ if !f.mu.Lock(ctx) {
+ return UnstableAttr{}, syserror.ErrInterrupted
+ }
+ defer f.mu.Unlock()
+
+ return f.FileOperations.UnstableAttr(ctx, f)
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (f *File) MappedName(ctx context.Context) string {
+ root := RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ name, _ := f.Dirent.FullName(root)
+ return name
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (f *File) DeviceID() uint64 {
+ return f.Dirent.Inode.StableAttr.DeviceID
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (f *File) InodeID() uint64 {
+ return f.Dirent.Inode.StableAttr.InodeID
+}
+
+// Msync implements memmap.MappingIdentity.Msync.
+func (f *File) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ return f.Fsync(ctx, int64(mr.Start), int64(mr.End-1), SyncData)
+}
+
+// A FileAsync sends signals to its owner when w is ready for IO.
+type FileAsync interface {
+ Register(w waiter.Waitable)
+ Unregister(w waiter.Waitable)
+}
+
+// Async gets the stored FileAsync or creates a new one with the supplied
+// function. If the supplied function is nil, no FileAsync is created and the
+// current value is returned.
+func (f *File) Async(newAsync func() FileAsync) FileAsync {
+ f.flagsMu.Lock()
+ defer f.flagsMu.Unlock()
+ if f.async == nil && newAsync != nil {
+ f.async = newAsync()
+ if f.flags.Async {
+ f.async.Register(f)
+ }
+ }
+ return f.async
+}
+
+// lockedReader implements io.Reader and io.ReaderAt.
+//
+// Note this reads the underlying file using the file operations directly. It
+// is the responsibility of the caller to ensure that locks are appropriately
+// held and offsets updated if required. This should be used only by internal
+// functions that perform these operations and checks at other times.
+type lockedReader struct {
+ // Ctx is the context for the file reader.
+ Ctx context.Context
+
+ // File is the file to read from.
+ File *File
+}
+
+// Read implements io.Reader.Read.
+func (r *lockedReader) Read(buf []byte) (int, error) {
+ if r.Ctx.Interrupted() {
+ return 0, syserror.ErrInterrupted
+ }
+ n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset)
+ return int(n), err
+}
+
+// ReadAt implements io.Reader.ReadAt.
+func (r *lockedReader) ReadAt(buf []byte, offset int64) (int, error) {
+ if r.Ctx.Interrupted() {
+ return 0, syserror.ErrInterrupted
+ }
+ n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), offset)
+ return int(n), err
+}
+
+// lockedWriter implements io.Writer and io.WriterAt.
+//
+// The same constraints as lockedReader apply; see above.
+type lockedWriter struct {
+ // Ctx is the context for the file writer.
+ Ctx context.Context
+
+ // File is the file to write to.
+ File *File
+}
+
+// Write implements io.Writer.Write.
+func (w *lockedWriter) Write(buf []byte) (int, error) {
+ n, err := w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf), w.File.offset)
+ return int(n), err
+}
+
+// WriteAt implements io.Writer.WriteAt.
+func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
+ n, err := w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf), offset)
+ return int(n), err
+}
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
new file mode 100644
index 000000000..0f2dfa273
--- /dev/null
+++ b/pkg/sentry/fs/file_operations.go
@@ -0,0 +1,159 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// SpliceOpts define how a splice works.
+type SpliceOpts struct {
+ // Length is the length of the splice operation.
+ Length int64
+
+ // SrcOffset indicates whether the existing source file offset should
+ // be used. If this is true, then the Start value below is used.
+ //
+ // When passed to FileOperations object, this should always be true as
+ // the offset will be provided by a layer above, unless the object in
+ // question is a pipe or socket. This value can be relied upon for such
+ // an indicator.
+ SrcOffset bool
+
+ // SrcStart is the start of the source file. This is used only if
+ // SrcOffset is false.
+ SrcStart int64
+
+ // Dup indicates that the contents should not be consumed from the
+ // source (e.g. in the case of a socket or a pipe), but duplicated.
+ Dup bool
+
+ // DstOffset indicates that the destination file offset should be used.
+ //
+ // See SrcOffset for additional information.
+ DstOffset bool
+
+ // DstStart is the start of the destination file. This is used only if
+ // DstOffset is false.
+ DstStart int64
+}
+
+// FileOperations are operations on a File that diverge per file system.
+//
+// Operations that take a *File may use only the following interfaces:
+//
+// - File.UniqueID: Operations may only read this value.
+// - File.Dirent: Operations must not take or drop a reference.
+// - File.Offset(): This value is guaranteed to not change for the
+// duration of the operation.
+// - File.Flags(): This value may change during the operation.
+type FileOperations interface {
+ // Release release resources held by FileOperations.
+ Release()
+
+ // Waitable defines how this File can be waited on for read and
+ // write readiness.
+ waiter.Waitable
+
+ // Seek seeks to offset based on SeekWhence. Returns the new
+ // offset or no change in the offset and an error.
+ Seek(ctx context.Context, file *File, whence SeekWhence, offset int64) (int64, error)
+
+ // Readdir reads the directory entries of file and serializes them
+ // using serializer.
+ //
+ // Returns the new directory offset or no change in the offset and
+ // an error. The offset returned must not be less than file.Offset().
+ //
+ // Serialization of directory entries must not happen asynchronously.
+ Readdir(ctx context.Context, file *File, serializer DentrySerializer) (int64, error)
+
+ // Read reads from file into dst at offset and returns the number
+ // of bytes read which must be greater than or equal to 0. File
+ // systems that do not support reading at an offset, (i.e. pipefs,
+ // sockfs) may ignore the offset. These file systems are expected
+ // to construct Files with !FileFlags.Pread.
+ //
+ // Read may return a nil error and only partially fill dst (at or
+ // before EOF). If the file represents a symlink, Read reads the target
+ // value of the symlink.
+ //
+ // Read does not check permissions nor flags.
+ //
+ // Read must not be called if !FileFlags.Read.
+ Read(ctx context.Context, file *File, dst usermem.IOSequence, offset int64) (int64, error)
+
+ // WriteTo is a variant of read that takes another file as a
+ // destination. For a splice (copy or move from one file to another),
+ // first a WriteTo on the source is attempted, followed by a ReadFrom
+ // on the destination, following by a buffered copy with standard Read
+ // and Write operations.
+ //
+ // The same preconditions as Read apply.
+ WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error)
+
+ // Write writes src to file at offset and returns the number of bytes
+ // written which must be greater than or equal to 0. Like Read, file
+ // systems that do not support writing at an offset (i.e. pipefs, sockfs)
+ // may ignore the offset. These file systems are expected to construct
+ // Files with !FileFlags.Pwrite.
+ //
+ // If only part of src could be written, Write must return an error
+ // indicating why (e.g. syserror.ErrWouldBlock).
+ //
+ // Write does not check permissions nor flags.
+ //
+ // Write must not be called if !FileFlags.Write.
+ Write(ctx context.Context, file *File, src usermem.IOSequence, offset int64) (int64, error)
+
+ // ReadFrom is a variant of write that takes a another file as a
+ // source. See WriteTo for details regarding how this is called.
+ //
+ // The same preconditions as Write apply; FileFlags.Write must be set.
+ ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error)
+
+ // Fsync writes buffered modifications of file and/or flushes in-flight
+ // operations to backing storage based on syncType. The range to sync is
+ // [start, end]. The end is inclusive so that the last byte of a maximally
+ // sized file can be synced.
+ Fsync(ctx context.Context, file *File, start, end int64, syncType SyncType) error
+
+ // Flush this file's buffers/state (on close(2)).
+ Flush(ctx context.Context, file *File) error
+
+ // ConfigureMMap mutates opts to implement mmap(2) for the file. Most
+ // implementations can either embed fsutil.FileNoMMap (if they don't support
+ // memory mapping) or call fsutil.GenericConfigureMMap with the appropriate
+ // memmap.Mappable.
+ ConfigureMMap(ctx context.Context, file *File, opts *memmap.MMapOpts) error
+
+ // UnstableAttr returns the "unstable" attributes of the inode represented
+ // by the file. Most implementations can embed
+ // fsutil.FileUseInodeUnstableAttr, which delegates to
+ // InodeOperations.UnstableAttr.
+ UnstableAttr(ctx context.Context, file *File) (UnstableAttr, error)
+
+ // Ioctl implements the ioctl(2) linux syscall.
+ //
+ // io provides access to the virtual memory space to which pointers in args
+ // refer.
+ //
+ // Preconditions: The AddressSpace (if any) that io refers to is activated.
+ Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error)
+}
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
new file mode 100644
index 000000000..273de1e14
--- /dev/null
+++ b/pkg/sentry/fs/file_overlay.go
@@ -0,0 +1,505 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// overlayFile gets a handle to a file from the upper or lower filesystem
+// in an overlay. The caller is responsible for calling File.DecRef on
+// the returned file.
+func overlayFile(ctx context.Context, inode *Inode, flags FileFlags) (*File, error) {
+ // Do a song and dance to eventually get to:
+ //
+ // File -> single reference
+ // Dirent -> single reference
+ // Inode -> multiple references
+ //
+ // So that File.DecRef() -> File.destroy -> Dirent.DecRef -> Dirent.destroy,
+ // and both the transitory File and Dirent can be GC'ed but the Inode
+ // remains.
+
+ // Take another reference on the Inode.
+ inode.IncRef()
+
+ // Start with a single reference on the Dirent. It inherits the reference
+ // we just took on the Inode above.
+ dirent := NewTransientDirent(inode)
+
+ // Get a File. This will take another reference on the Dirent.
+ f, err := inode.GetFile(ctx, dirent, flags)
+
+ // Drop the extra reference on the Dirent. Now there's only one reference
+ // on the dirent, either owned by f (if non-nil), or the Dirent is about
+ // to be destroyed (if GetFile failed).
+ dirent.DecRef()
+
+ return f, err
+}
+
+// overlayFileOperations implements FileOperations for a file in an overlay.
+//
+// +stateify savable
+type overlayFileOperations struct {
+ // upperMu protects upper below. In contrast lower is stable.
+ upperMu sync.Mutex `state:"nosave"`
+
+ // We can't share Files in upper and lower filesystems between all Files
+ // in an overlay because some file systems expect to get distinct handles
+ // that are not consistent with each other on open(2).
+ //
+ // So we lazily acquire an upper File when the overlayEntry acquires an
+ // upper Inode (it might have one from the start). This synchronizes with
+ // copy up.
+ //
+ // If upper is non-nil and this is not a directory, then lower is ignored.
+ //
+ // For directories, upper and lower are ignored because it is always
+ // necessary to acquire new directory handles so that the directory cursors
+ // of the upper and lower Files are not exhausted.
+ upper *File
+ lower *File
+
+ // dirCursor is a directory cursor for a directory in an overlay. It is
+ // protected by File.mu of the owning file, which is held during
+ // Readdir and Seek calls.
+ dirCursor string
+
+ // dirCacheMu protects dirCache.
+ dirCacheMu sync.RWMutex `state:"nosave"`
+
+ // dirCache is cache of DentAttrs from upper and lower Inodes.
+ dirCache *SortedDentryMap
+}
+
+// Release implements FileOperations.Release.
+func (f *overlayFileOperations) Release() {
+ if f.upper != nil {
+ f.upper.DecRef()
+ }
+ if f.lower != nil {
+ f.lower.DecRef()
+ }
+}
+
+// EventRegister implements FileOperations.EventRegister.
+func (f *overlayFileOperations) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+ if f.upper != nil {
+ f.upper.EventRegister(we, mask)
+ return
+ }
+ f.lower.EventRegister(we, mask)
+}
+
+// EventUnregister implements FileOperations.Unregister.
+func (f *overlayFileOperations) EventUnregister(we *waiter.Entry) {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+ if f.upper != nil {
+ f.upper.EventUnregister(we)
+ return
+ }
+ f.lower.EventUnregister(we)
+}
+
+// Readiness implements FileOperations.Readiness.
+func (f *overlayFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+ if f.upper != nil {
+ return f.upper.Readiness(mask)
+ }
+ return f.lower.Readiness(mask)
+}
+
+// Seek implements FileOperations.Seek.
+func (f *overlayFileOperations) Seek(ctx context.Context, file *File, whence SeekWhence, offset int64) (int64, error) {
+ f.upperMu.Lock()
+ defer f.upperMu.Unlock()
+
+ var seekDir bool
+ var n int64
+ if f.upper != nil {
+ var err error
+ if n, err = f.upper.FileOperations.Seek(ctx, file, whence, offset); err != nil {
+ return n, err
+ }
+ seekDir = IsDir(f.upper.Dirent.Inode.StableAttr)
+ } else {
+ var err error
+ if n, err = f.lower.FileOperations.Seek(ctx, file, whence, offset); err != nil {
+ return n, err
+ }
+ seekDir = IsDir(f.lower.Dirent.Inode.StableAttr)
+ }
+
+ // If this was a seek on a directory, we must update the cursor.
+ if seekDir && whence == SeekSet && offset == 0 {
+ // Currently only seeking to 0 on a directory is supported.
+ // FIXME(b/33075855): Lift directory seeking limitations.
+ f.dirCursor = ""
+ }
+ return n, nil
+}
+
+// Readdir implements FileOperations.Readdir.
+func (f *overlayFileOperations) Readdir(ctx context.Context, file *File, serializer DentrySerializer) (int64, error) {
+ root := RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dirCtx := &DirCtx{
+ Serializer: serializer,
+ DirCursor: &f.dirCursor,
+ }
+
+ // If the directory dirent is frozen, then DirentReaddir will calculate
+ // the children based off the frozen dirent tree. There is no need to
+ // call readdir on the upper/lower layers.
+ if file.Dirent.frozen {
+ return DirentReaddir(ctx, file.Dirent, f, root, dirCtx, file.Offset())
+ }
+
+ // Otherwise proceed with usual overlay readdir.
+ o := file.Dirent.Inode.overlay
+
+ // readdirEntries holds o.copyUpMu to ensure that copy-up does not
+ // occur while calculating the readir results.
+ //
+ // However, it is possible for a copy-up to occur after the call to
+ // readdirEntries, but before setting f.dirCache. This is OK, since
+ // copy-up only does not change the children in a way that would affect
+ // the children returned in dirCache. Copy-up only moves
+ // files/directories between layers in the overlay.
+ //
+ // It is also possible for Readdir to race with a Create operation
+ // (which may trigger a copy-up during it's execution). Depending on
+ // whether the Create happens before or after the readdirEntries call,
+ // the newly created file may or may not appear in the readdir results.
+ // But this can only be caused by a real race between readdir and
+ // create syscalls, so it's also OK.
+ dirCache, err := readdirEntries(ctx, o)
+ if err != nil {
+ return file.Offset(), err
+ }
+
+ f.dirCacheMu.Lock()
+ f.dirCache = dirCache
+ f.dirCacheMu.Unlock()
+
+ return DirentReaddir(ctx, file.Dirent, f, root, dirCtx, file.Offset())
+}
+
+// IterateDir implements DirIterator.IterateDir.
+func (f *overlayFileOperations) IterateDir(ctx context.Context, dirCtx *DirCtx, offset int) (int, error) {
+ f.dirCacheMu.RLock()
+ n, err := GenericReaddir(dirCtx, f.dirCache)
+ f.dirCacheMu.RUnlock()
+ return offset + n, err
+}
+
+// onTop performs the given operation on the top-most available layer.
+func (f *overlayFileOperations) onTop(ctx context.Context, file *File, fn func(*File, FileOperations) error) error {
+ file.Dirent.Inode.overlay.copyMu.RLock()
+ defer file.Dirent.Inode.overlay.copyMu.RUnlock()
+
+ // Only lower layer is available.
+ if file.Dirent.Inode.overlay.upper == nil {
+ return fn(f.lower, f.lower.FileOperations)
+ }
+
+ f.upperMu.Lock()
+ if f.upper == nil {
+ upper, err := overlayFile(ctx, file.Dirent.Inode.overlay.upper, file.Flags())
+ if err != nil {
+ // Something very wrong; return a generic filesystem
+ // error to avoid propagating internals.
+ f.upperMu.Unlock()
+ return syserror.EIO
+ }
+
+ // Save upper file.
+ f.upper = upper
+ }
+ f.upperMu.Unlock()
+
+ return fn(f.upper, f.upper.FileOperations)
+}
+
+// Read implements FileOperations.Read.
+func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst usermem.IOSequence, offset int64) (n int64, err error) {
+ err = f.onTop(ctx, file, func(file *File, ops FileOperations) error {
+ n, err = ops.Read(ctx, file, dst, offset)
+ return err // Will overwrite itself.
+ })
+ return
+}
+
+// WriteTo implements FileOperations.WriteTo.
+func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) {
+ err = f.onTop(ctx, file, func(file *File, ops FileOperations) error {
+ n, err = ops.WriteTo(ctx, file, dst, opts)
+ return err // Will overwrite itself.
+ })
+ return
+}
+
+// Write implements FileOperations.Write.
+func (f *overlayFileOperations) Write(ctx context.Context, file *File, src usermem.IOSequence, offset int64) (int64, error) {
+ // f.upper must be non-nil. See inode_overlay.go:overlayGetFile, where the
+ // file is copied up and opened in the upper filesystem if FileFlags.Write.
+ // Write cannot be called if !FileFlags.Write, see FileOperations.Write.
+ return f.upper.FileOperations.Write(ctx, f.upper, src, offset)
+}
+
+// ReadFrom implements FileOperations.ReadFrom.
+func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) {
+ // See above; f.upper must be non-nil.
+ return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts)
+}
+
+// Fsync implements FileOperations.Fsync.
+func (f *overlayFileOperations) Fsync(ctx context.Context, file *File, start, end int64, syncType SyncType) (err error) {
+ f.upperMu.Lock()
+ if f.upper != nil {
+ err = f.upper.FileOperations.Fsync(ctx, f.upper, start, end, syncType)
+ }
+ f.upperMu.Unlock()
+ if err == nil && f.lower != nil {
+ // N.B. Fsync on the lower filesystem can cause writes of file
+ // attributes (i.e. access time) despite the fact that we must
+ // treat the lower filesystem as read-only.
+ //
+ // This matches the semantics of fsync(2) in Linux overlayfs.
+ err = f.lower.FileOperations.Fsync(ctx, f.lower, start, end, syncType)
+ }
+ return err
+}
+
+// Flush implements FileOperations.Flush.
+func (f *overlayFileOperations) Flush(ctx context.Context, file *File) (err error) {
+ // Flush whatever handles we have.
+ f.upperMu.Lock()
+ if f.upper != nil {
+ err = f.upper.FileOperations.Flush(ctx, f.upper)
+ }
+ f.upperMu.Unlock()
+ if err == nil && f.lower != nil {
+ err = f.lower.FileOperations.Flush(ctx, f.lower)
+ }
+ return err
+}
+
+// ConfigureMMap implements FileOperations.ConfigureMMap.
+func (*overlayFileOperations) ConfigureMMap(ctx context.Context, file *File, opts *memmap.MMapOpts) error {
+ o := file.Dirent.Inode.overlay
+
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ // If there is no lower inode, the overlay will never need to do a
+ // copy-up, and thus will never need to invalidate any mappings. We can
+ // call ConfigureMMap directly on the upper file.
+ if o.lower == nil {
+ f := file.FileOperations.(*overlayFileOperations)
+ if err := f.upper.ConfigureMMap(ctx, opts); err != nil {
+ return err
+ }
+
+ // ConfigureMMap will set the MappableIdentity to the upper
+ // file and take a reference on it, but we must also hold a
+ // reference to the overlay file during the lifetime of the
+ // Mappable. If we do not do this, the overlay file can be
+ // Released before the upper file is Released, and we will be
+ // unable to traverse to the upper file during Save, thus
+ // preventing us from saving a proper inode mapping for the
+ // file.
+ file.IncRef()
+ id := &overlayMappingIdentity{
+ id: opts.MappingIdentity,
+ overlayFile: file,
+ }
+
+ // Swap out the old MappingIdentity for the wrapped one.
+ opts.MappingIdentity = id
+ return nil
+ }
+
+ if !o.isMappableLocked() {
+ return syserror.ENODEV
+ }
+
+ // FIXME(jamieliu): This is a copy/paste of fsutil.GenericConfigureMMap,
+ // which we can't use because the overlay implementation is in package fs,
+ // so depending on fs/fsutil would create a circular dependency. Move
+ // overlay to fs/overlay.
+ opts.Mappable = o
+ opts.MappingIdentity = file
+ file.IncRef()
+ return nil
+}
+
+// UnstableAttr implements fs.FileOperations.UnstableAttr.
+func (f *overlayFileOperations) UnstableAttr(ctx context.Context, file *File) (UnstableAttr, error) {
+ // Hot path. Avoid defers.
+ f.upperMu.Lock()
+ if f.upper != nil {
+ attr, err := f.upper.UnstableAttr(ctx)
+ f.upperMu.Unlock()
+ return attr, err
+ }
+ f.upperMu.Unlock()
+
+ // It's possible that copy-up has occurred, but we haven't opened a upper
+ // file yet. If this is the case, just use the upper inode's UnstableAttr
+ // rather than opening a file.
+ o := file.Dirent.Inode.overlay
+ o.copyMu.RLock()
+ if o.upper != nil {
+ attr, err := o.upper.UnstableAttr(ctx)
+ o.copyMu.RUnlock()
+ return attr, err
+ }
+ o.copyMu.RUnlock()
+
+ return f.lower.UnstableAttr(ctx)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl and always returns ENOTTY.
+func (*overlayFileOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return 0, syserror.ENOTTY
+}
+
+// readdirEntries returns a sorted map of directory entries from the
+// upper and/or lower filesystem.
+func readdirEntries(ctx context.Context, o *overlayEntry) (*SortedDentryMap, error) {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ // Assert that there is at least one upper or lower entry.
+ if o.upper == nil && o.lower == nil {
+ panic("invalid overlayEntry, needs at least one Inode")
+ }
+ entries := make(map[string]DentAttr)
+
+ // Try the upper filesystem first.
+ if o.upper != nil {
+ var err error
+ entries, err = readdirOne(ctx, NewTransientDirent(o.upper))
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Try the lower filesystem next.
+ if o.lower != nil {
+ lowerEntries, err := readdirOne(ctx, NewTransientDirent(o.lower))
+ if err != nil {
+ return nil, err
+ }
+ for name, entry := range lowerEntries {
+ // Skip this name if it is a negative entry in the
+ // upper or there exists a whiteout for it.
+ if o.upper != nil {
+ if overlayHasWhiteout(o.upper, name) {
+ continue
+ }
+ }
+ // Prefer the entries from the upper filesystem
+ // when names overlap.
+ if _, ok := entries[name]; !ok {
+ entries[name] = entry
+ }
+ }
+ }
+
+ // Sort and return the entries.
+ return NewSortedDentryMap(entries), nil
+}
+
+// readdirOne reads all of the directory entries from d.
+func readdirOne(ctx context.Context, d *Dirent) (map[string]DentAttr, error) {
+ dir, err := d.Inode.GetFile(ctx, d, FileFlags{Read: true})
+ if err != nil {
+ return nil, err
+ }
+ defer dir.DecRef()
+
+ // Use a stub serializer to read the entries into memory.
+ stubSerializer := &CollectEntriesSerializer{}
+ if err := dir.Readdir(ctx, stubSerializer); err != nil {
+ return nil, err
+ }
+ // The "." and ".." entries are from the overlay Inode's Dirent, not the stub.
+ delete(stubSerializer.Entries, ".")
+ delete(stubSerializer.Entries, "..")
+ return stubSerializer.Entries, nil
+}
+
+// overlayMappingIdentity wraps a MappingIdentity, and also holds a reference
+// on a file during its lifetime.
+//
+// +stateify savable
+type overlayMappingIdentity struct {
+ refs.AtomicRefCount
+ id memmap.MappingIdentity
+ overlayFile *File
+}
+
+// DecRef implements AtomicRefCount.DecRef.
+func (omi *overlayMappingIdentity) DecRef() {
+ omi.AtomicRefCount.DecRefWithDestructor(func() {
+ omi.overlayFile.DecRef()
+ omi.id.DecRef()
+ })
+}
+
+// DeviceID implements MappingIdentity.DeviceID using the device id from the
+// overlayFile.
+func (omi *overlayMappingIdentity) DeviceID() uint64 {
+ return omi.overlayFile.Dirent.Inode.StableAttr.DeviceID
+}
+
+// DeviceID implements MappingIdentity.InodeID using the inode id from the
+// overlayFile.
+func (omi *overlayMappingIdentity) InodeID() uint64 {
+ return omi.overlayFile.Dirent.Inode.StableAttr.InodeID
+}
+
+// MappedName implements MappingIdentity.MappedName.
+func (omi *overlayMappingIdentity) MappedName(ctx context.Context) string {
+ root := RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ name, _ := omi.overlayFile.Dirent.FullName(root)
+ return name
+}
+
+// Msync implements MappingIdentity.Msync.
+func (omi *overlayMappingIdentity) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ return omi.id.Msync(ctx, mr)
+}
diff --git a/pkg/sentry/fs/file_state.go b/pkg/sentry/fs/file_state.go
new file mode 100644
index 000000000..523182d59
--- /dev/null
+++ b/pkg/sentry/fs/file_state.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+// beforeSave is invoked by stateify.
+func (f *File) beforeSave() {
+ f.saving = true
+ if f.flags.Async && f.async != nil {
+ f.async.Unregister(f)
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (f *File) afterLoad() {
+ f.mu.Init()
+ if f.flags.Async && f.async != nil {
+ f.async.Register(f)
+ }
+}
diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go
new file mode 100644
index 000000000..acd84dfcc
--- /dev/null
+++ b/pkg/sentry/fs/filesystems.go
@@ -0,0 +1,174 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// FilesystemFlags matches include/linux/fs.h:file_system_type.fs_flags.
+type FilesystemFlags int
+
+const (
+ // FilesystemRequiresDev indicates that the file system requires a device name
+ // on mount. It is used to construct the output of /proc/filesystems.
+ FilesystemRequiresDev FilesystemFlags = 1
+
+ // Currently other flags are not used, but can be pulled in from
+ // include/linux/fs.h:file_system_type as needed.
+)
+
+// Filesystem is a mountable file system.
+type Filesystem interface {
+ // Name is the unique identifier of the file system. It corresponds to the
+ // filesystemtype argument of sys_mount and will appear in the output of
+ // /proc/filesystems.
+ Name() string
+
+ // Flags indicate common properties of the file system.
+ Flags() FilesystemFlags
+
+ // Mount generates a mountable Inode backed by device and configured
+ // using file system independent flags and file system dependent
+ // data options.
+ //
+ // Mount may return arbitrary errors. They do not need syserr translations.
+ Mount(ctx context.Context, device string, flags MountSourceFlags, data string, dataObj interface{}) (*Inode, error)
+
+ // AllowUserMount determines whether mount(2) is allowed to mount a
+ // file system of this type.
+ AllowUserMount() bool
+
+ // AllowUserList determines whether this filesystem is listed in
+ // /proc/filesystems
+ AllowUserList() bool
+}
+
+// filesystems is the global set of registered file systems. It does not need
+// to be saved. Packages registering and unregistering file systems must do so
+// before calling save/restore methods.
+var filesystems = struct {
+ // mu protects registered below.
+ mu sync.Mutex
+
+ // registered is a set of registered Filesystems.
+ registered map[string]Filesystem
+}{
+ registered: make(map[string]Filesystem),
+}
+
+// RegisterFilesystem registers a new file system that is visible to mount and
+// the /proc/filesystems list. Packages implementing Filesystem should call
+// RegisterFilesystem in init().
+func RegisterFilesystem(f Filesystem) {
+ filesystems.mu.Lock()
+ defer filesystems.mu.Unlock()
+
+ if _, ok := filesystems.registered[f.Name()]; ok {
+ panic(fmt.Sprintf("filesystem already registered at %q", f.Name()))
+ }
+ filesystems.registered[f.Name()] = f
+}
+
+// UnregisterFilesystem removes a file system from the global set. To keep the
+// file system set compatible with save/restore, UnregisterFilesystem must be
+// called before save/restore methods.
+//
+// For instance, packages may unregister their file system after it is mounted.
+// This makes sense for pseudo file systems that should not be visible or
+// mountable. See whitelistfs in fs/host/fs.go for one example.
+func UnregisterFilesystem(name string) {
+ filesystems.mu.Lock()
+ defer filesystems.mu.Unlock()
+
+ delete(filesystems.registered, name)
+}
+
+// FindFilesystem returns a Filesystem registered at name or (nil, false) if name
+// is not a file system type that can be found in /proc/filesystems.
+func FindFilesystem(name string) (Filesystem, bool) {
+ filesystems.mu.Lock()
+ defer filesystems.mu.Unlock()
+
+ f, ok := filesystems.registered[name]
+ return f, ok
+}
+
+// GetFilesystems returns the set of registered filesystems in a consistent order.
+func GetFilesystems() []Filesystem {
+ filesystems.mu.Lock()
+ defer filesystems.mu.Unlock()
+
+ var ss []Filesystem
+ for _, s := range filesystems.registered {
+ ss = append(ss, s)
+ }
+ sort.Slice(ss, func(i, j int) bool { return ss[i].Name() < ss[j].Name() })
+ return ss
+}
+
+// MountSourceFlags represents all mount option flags as a struct.
+//
+// +stateify savable
+type MountSourceFlags struct {
+ // ReadOnly corresponds to mount(2)'s "MS_RDONLY" and indicates that
+ // the filesystem should be mounted read-only.
+ ReadOnly bool
+
+ // NoAtime corresponds to mount(2)'s "MS_NOATIME" and indicates that
+ // the filesystem should not update access time in-place.
+ NoAtime bool
+
+ // ForcePageCache causes all filesystem I/O operations to use the page
+ // cache, even when the platform supports direct mapped I/O. This
+ // doesn't correspond to any Linux mount options.
+ ForcePageCache bool
+
+ // NoExec corresponds to mount(2)'s "MS_NOEXEC" and indicates that
+ // binaries from this file system can't be executed.
+ NoExec bool
+}
+
+// GenericMountSourceOptions splits a string containing comma separated tokens of the
+// format 'key=value' or 'key' into a map of keys and values. For example:
+//
+// data = "key0=value0,key1,key2=value2" -> map{'key0':'value0','key1':'','key2':'value2'}
+//
+// If data contains duplicate keys, then the last token wins.
+func GenericMountSourceOptions(data string) map[string]string {
+ options := make(map[string]string)
+ if len(data) == 0 {
+ // Don't return a nil map, callers might not be expecting that.
+ return options
+ }
+
+ // Parse options and skip empty ones.
+ for _, opt := range strings.Split(data, ",") {
+ if len(opt) > 0 {
+ res := strings.SplitN(opt, "=", 2)
+ if len(res) == 2 {
+ options[res[0]] = res[1]
+ } else {
+ options[opt] = ""
+ }
+ }
+ }
+ return options
+}
diff --git a/pkg/sentry/fs/flags.go b/pkg/sentry/fs/flags.go
new file mode 100644
index 000000000..5c8cb773f
--- /dev/null
+++ b/pkg/sentry/fs/flags.go
@@ -0,0 +1,121 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// FileFlags encodes file flags.
+//
+// +stateify savable
+type FileFlags struct {
+ // Direct indicates that I/O should be done directly.
+ Direct bool
+
+ // NonBlocking indicates that I/O should not block.
+ NonBlocking bool
+
+ // Sync indicates that any writes should be synchronous.
+ Sync bool
+
+ // Append indicates this file is append only.
+ Append bool
+
+ // Read indicates this file is readable.
+ Read bool
+
+ // Write indicates this file is writeable.
+ Write bool
+
+ // Pread indicates this file is readable at an arbitrary offset.
+ Pread bool
+
+ // Pwrite indicates this file is writable at an arbitrary offset.
+ Pwrite bool
+
+ // Directory indicates that this file must be a directory.
+ Directory bool
+
+ // Async indicates that this file sends signals on IO events.
+ Async bool
+
+ // LargeFile indicates that this file should be opened even if it has
+ // size greater than linux's off_t. When running in 64-bit mode,
+ // Linux sets this flag for all files. Since gVisor is only compatible
+ // with 64-bit Linux, it also sets this flag for all files.
+ LargeFile bool
+}
+
+// SettableFileFlags is a subset of FileFlags above that can be changed
+// via fcntl(2) using the F_SETFL command.
+type SettableFileFlags struct {
+ // Direct indicates that I/O should be done directly.
+ Direct bool
+
+ // NonBlocking indicates that I/O should not block.
+ NonBlocking bool
+
+ // Append indicates this file is append only.
+ Append bool
+
+ // Async indicates that this file sends signals on IO events.
+ Async bool
+}
+
+// Settable returns the subset of f that are settable.
+func (f FileFlags) Settable() SettableFileFlags {
+ return SettableFileFlags{
+ Direct: f.Direct,
+ NonBlocking: f.NonBlocking,
+ Append: f.Append,
+ Async: f.Async,
+ }
+}
+
+// ToLinux converts a FileFlags object to a Linux representation.
+func (f FileFlags) ToLinux() (mask uint) {
+ if f.Direct {
+ mask |= linux.O_DIRECT
+ }
+ if f.NonBlocking {
+ mask |= linux.O_NONBLOCK
+ }
+ if f.Sync {
+ mask |= linux.O_SYNC
+ }
+ if f.Append {
+ mask |= linux.O_APPEND
+ }
+ if f.Directory {
+ mask |= linux.O_DIRECTORY
+ }
+ if f.Async {
+ mask |= linux.O_ASYNC
+ }
+ if f.LargeFile {
+ mask |= linux.O_LARGEFILE
+ }
+
+ switch {
+ case f.Read && f.Write:
+ mask |= linux.O_RDWR
+ case f.Write:
+ mask |= linux.O_WRONLY
+ case f.Read:
+ mask |= linux.O_RDONLY
+ }
+ return
+}
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
new file mode 100644
index 000000000..632055cce
--- /dev/null
+++ b/pkg/sentry/fs/fs.go
@@ -0,0 +1,161 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fs implements a virtual filesystem layer.
+//
+// Specific filesystem implementations must implement the InodeOperations
+// interface (inode.go).
+//
+// The MountNamespace (mounts.go) is used to create a collection of mounts in
+// a filesystem rooted at a given Inode.
+//
+// MountSources (mount.go) form a tree, with each mount holding pointers to its
+// parent and children.
+//
+// Dirents (dirents.go) wrap Inodes in a caching layer.
+//
+// When multiple locks are to be held at the same time, they should be acquired
+// in the following order.
+//
+// Either:
+// File.mu
+// Locks in FileOperations implementations
+// goto Dirent-Locks
+//
+// Or:
+// MountNamespace.mu
+// goto Dirent-Locks
+//
+// Dirent-Locks:
+// renameMu
+// Dirent.dirMu
+// Dirent.mu
+// DirentCache.mu
+// Locks in InodeOperations implementations or overlayEntry
+// Inode.Watches.mu (see `Inotify` for other lock ordering)
+// MountSource.mu
+//
+// If multiple Dirent or MountSource locks must be taken, locks in the parent must be
+// taken before locks in their children.
+//
+// If locks must be taken on multiple unrelated Dirents, renameMu must be taken
+// first. See lockForRename.
+package fs
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+var (
+ // workMu is used to synchronize pending asynchronous work. Async work
+ // runs with the lock held for reading. AsyncBarrier will take the lock
+ // for writing, thus ensuring that all Async work completes before
+ // AsyncBarrier returns.
+ workMu sync.RWMutex
+
+ // asyncError is used to store up to one asynchronous execution error.
+ asyncError = make(chan error, 1)
+)
+
+// AsyncBarrier waits for all outstanding asynchronous work to complete.
+func AsyncBarrier() {
+ workMu.Lock()
+ workMu.Unlock()
+}
+
+// Async executes a function asynchronously.
+//
+// Async must not be called recursively.
+func Async(f func()) {
+ workMu.RLock()
+ go func() { // S/R-SAFE: AsyncBarrier must be called.
+ defer workMu.RUnlock() // Ensure RUnlock in case of panic.
+ f()
+ }()
+}
+
+// AsyncWithContext is just like Async, except that it calls the asynchronous
+// function with the given context as argument. This function exists to avoid
+// needing to allocate an extra function on the heap in a hot path.
+func AsyncWithContext(ctx context.Context, f func(context.Context)) {
+ workMu.RLock()
+ go func() { // S/R-SAFE: AsyncBarrier must be called.
+ defer workMu.RUnlock() // Ensure RUnlock in case of panic.
+ f(ctx)
+ }()
+}
+
+// AsyncErrorBarrier waits for all outstanding asynchronous work to complete, or
+// the first async error to arrive. Other unfinished async executions will
+// continue in the background. Other past and future async errors are ignored.
+func AsyncErrorBarrier() error {
+ wait := make(chan struct{}, 1)
+ go func() { // S/R-SAFE: Does not touch persistent state.
+ AsyncBarrier()
+ wait <- struct{}{}
+ }()
+ select {
+ case <-wait:
+ select {
+ case err := <-asyncError:
+ return err
+ default:
+ return nil
+ }
+ case err := <-asyncError:
+ return err
+ }
+}
+
+// CatchError tries to capture the potential async error returned by the
+// function. At most one async error will be captured globally so excessive
+// errors will be dropped.
+func CatchError(f func() error) func() {
+ return func() {
+ if err := f(); err != nil {
+ select {
+ case asyncError <- err:
+ default:
+ log.Warningf("excessive async error dropped: %v", err)
+ }
+ }
+ }
+}
+
+// ErrSaveRejection indicates a failed save due to unsupported file system state
+// such as dangling open fd, etc.
+type ErrSaveRejection struct {
+ // Err is the wrapped error.
+ Err error
+}
+
+// Error returns a sensible description of the save rejection error.
+func (e ErrSaveRejection) Error() string {
+ return "save rejected due to unsupported file system state: " + e.Err.Error()
+}
+
+// ErrCorruption indicates a failed restore due to external file system state in
+// corruption.
+type ErrCorruption struct {
+ // Err is the wrapped error.
+ Err error
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrCorruption) Error() string {
+ return "restore failed due to external file system state in corruption: " + e.Err.Error()
+}
diff --git a/pkg/sentry/fs/fs_state_autogen.go b/pkg/sentry/fs/fs_state_autogen.go
new file mode 100755
index 000000000..4af22a474
--- /dev/null
+++ b/pkg/sentry/fs/fs_state_autogen.go
@@ -0,0 +1,626 @@
+// automatically generated by stateify.
+
+package fs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *StableAttr) beforeSave() {}
+func (x *StableAttr) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Type", &x.Type)
+ m.Save("DeviceID", &x.DeviceID)
+ m.Save("InodeID", &x.InodeID)
+ m.Save("BlockSize", &x.BlockSize)
+ m.Save("DeviceFileMajor", &x.DeviceFileMajor)
+ m.Save("DeviceFileMinor", &x.DeviceFileMinor)
+}
+
+func (x *StableAttr) afterLoad() {}
+func (x *StableAttr) load(m state.Map) {
+ m.Load("Type", &x.Type)
+ m.Load("DeviceID", &x.DeviceID)
+ m.Load("InodeID", &x.InodeID)
+ m.Load("BlockSize", &x.BlockSize)
+ m.Load("DeviceFileMajor", &x.DeviceFileMajor)
+ m.Load("DeviceFileMinor", &x.DeviceFileMinor)
+}
+
+func (x *UnstableAttr) beforeSave() {}
+func (x *UnstableAttr) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Size", &x.Size)
+ m.Save("Usage", &x.Usage)
+ m.Save("Perms", &x.Perms)
+ m.Save("Owner", &x.Owner)
+ m.Save("AccessTime", &x.AccessTime)
+ m.Save("ModificationTime", &x.ModificationTime)
+ m.Save("StatusChangeTime", &x.StatusChangeTime)
+ m.Save("Links", &x.Links)
+}
+
+func (x *UnstableAttr) afterLoad() {}
+func (x *UnstableAttr) load(m state.Map) {
+ m.Load("Size", &x.Size)
+ m.Load("Usage", &x.Usage)
+ m.Load("Perms", &x.Perms)
+ m.Load("Owner", &x.Owner)
+ m.Load("AccessTime", &x.AccessTime)
+ m.Load("ModificationTime", &x.ModificationTime)
+ m.Load("StatusChangeTime", &x.StatusChangeTime)
+ m.Load("Links", &x.Links)
+}
+
+func (x *AttrMask) beforeSave() {}
+func (x *AttrMask) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Type", &x.Type)
+ m.Save("DeviceID", &x.DeviceID)
+ m.Save("InodeID", &x.InodeID)
+ m.Save("BlockSize", &x.BlockSize)
+ m.Save("Size", &x.Size)
+ m.Save("Usage", &x.Usage)
+ m.Save("Perms", &x.Perms)
+ m.Save("UID", &x.UID)
+ m.Save("GID", &x.GID)
+ m.Save("AccessTime", &x.AccessTime)
+ m.Save("ModificationTime", &x.ModificationTime)
+ m.Save("StatusChangeTime", &x.StatusChangeTime)
+ m.Save("Links", &x.Links)
+}
+
+func (x *AttrMask) afterLoad() {}
+func (x *AttrMask) load(m state.Map) {
+ m.Load("Type", &x.Type)
+ m.Load("DeviceID", &x.DeviceID)
+ m.Load("InodeID", &x.InodeID)
+ m.Load("BlockSize", &x.BlockSize)
+ m.Load("Size", &x.Size)
+ m.Load("Usage", &x.Usage)
+ m.Load("Perms", &x.Perms)
+ m.Load("UID", &x.UID)
+ m.Load("GID", &x.GID)
+ m.Load("AccessTime", &x.AccessTime)
+ m.Load("ModificationTime", &x.ModificationTime)
+ m.Load("StatusChangeTime", &x.StatusChangeTime)
+ m.Load("Links", &x.Links)
+}
+
+func (x *PermMask) beforeSave() {}
+func (x *PermMask) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Read", &x.Read)
+ m.Save("Write", &x.Write)
+ m.Save("Execute", &x.Execute)
+}
+
+func (x *PermMask) afterLoad() {}
+func (x *PermMask) load(m state.Map) {
+ m.Load("Read", &x.Read)
+ m.Load("Write", &x.Write)
+ m.Load("Execute", &x.Execute)
+}
+
+func (x *FilePermissions) beforeSave() {}
+func (x *FilePermissions) save(m state.Map) {
+ x.beforeSave()
+ m.Save("User", &x.User)
+ m.Save("Group", &x.Group)
+ m.Save("Other", &x.Other)
+ m.Save("Sticky", &x.Sticky)
+ m.Save("SetUID", &x.SetUID)
+ m.Save("SetGID", &x.SetGID)
+}
+
+func (x *FilePermissions) afterLoad() {}
+func (x *FilePermissions) load(m state.Map) {
+ m.Load("User", &x.User)
+ m.Load("Group", &x.Group)
+ m.Load("Other", &x.Other)
+ m.Load("Sticky", &x.Sticky)
+ m.Load("SetUID", &x.SetUID)
+ m.Load("SetGID", &x.SetGID)
+}
+
+func (x *FileOwner) beforeSave() {}
+func (x *FileOwner) save(m state.Map) {
+ x.beforeSave()
+ m.Save("UID", &x.UID)
+ m.Save("GID", &x.GID)
+}
+
+func (x *FileOwner) afterLoad() {}
+func (x *FileOwner) load(m state.Map) {
+ m.Load("UID", &x.UID)
+ m.Load("GID", &x.GID)
+}
+
+func (x *DentAttr) beforeSave() {}
+func (x *DentAttr) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Type", &x.Type)
+ m.Save("InodeID", &x.InodeID)
+}
+
+func (x *DentAttr) afterLoad() {}
+func (x *DentAttr) load(m state.Map) {
+ m.Load("Type", &x.Type)
+ m.Load("InodeID", &x.InodeID)
+}
+
+func (x *SortedDentryMap) beforeSave() {}
+func (x *SortedDentryMap) save(m state.Map) {
+ x.beforeSave()
+ m.Save("names", &x.names)
+ m.Save("entries", &x.entries)
+}
+
+func (x *SortedDentryMap) afterLoad() {}
+func (x *SortedDentryMap) load(m state.Map) {
+ m.Load("names", &x.names)
+ m.Load("entries", &x.entries)
+}
+
+func (x *Dirent) save(m state.Map) {
+ x.beforeSave()
+ var children map[string]*Dirent = x.saveChildren()
+ m.SaveValue("children", children)
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("userVisible", &x.userVisible)
+ m.Save("Inode", &x.Inode)
+ m.Save("name", &x.name)
+ m.Save("parent", &x.parent)
+ m.Save("deleted", &x.deleted)
+ m.Save("frozen", &x.frozen)
+ m.Save("mounted", &x.mounted)
+}
+
+func (x *Dirent) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("userVisible", &x.userVisible)
+ m.Load("Inode", &x.Inode)
+ m.Load("name", &x.name)
+ m.Load("parent", &x.parent)
+ m.Load("deleted", &x.deleted)
+ m.Load("frozen", &x.frozen)
+ m.Load("mounted", &x.mounted)
+ m.LoadValue("children", new(map[string]*Dirent), func(y interface{}) { x.loadChildren(y.(map[string]*Dirent)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *DirentCache) beforeSave() {}
+func (x *DirentCache) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.currentSize) { m.Failf("currentSize is %v, expected zero", x.currentSize) }
+ if !state.IsZeroValue(x.list) { m.Failf("list is %v, expected zero", x.list) }
+ m.Save("maxSize", &x.maxSize)
+ m.Save("limit", &x.limit)
+}
+
+func (x *DirentCache) afterLoad() {}
+func (x *DirentCache) load(m state.Map) {
+ m.Load("maxSize", &x.maxSize)
+ m.Load("limit", &x.limit)
+}
+
+func (x *DirentCacheLimiter) beforeSave() {}
+func (x *DirentCacheLimiter) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.count) { m.Failf("count is %v, expected zero", x.count) }
+ m.Save("max", &x.max)
+}
+
+func (x *DirentCacheLimiter) afterLoad() {}
+func (x *DirentCacheLimiter) load(m state.Map) {
+ m.Load("max", &x.max)
+}
+
+func (x *direntList) beforeSave() {}
+func (x *direntList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *direntList) afterLoad() {}
+func (x *direntList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *direntEntry) beforeSave() {}
+func (x *direntEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *direntEntry) afterLoad() {}
+func (x *direntEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *eventList) beforeSave() {}
+func (x *eventList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *eventList) afterLoad() {}
+func (x *eventList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *eventEntry) beforeSave() {}
+func (x *eventEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *eventEntry) afterLoad() {}
+func (x *eventEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *File) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("UniqueID", &x.UniqueID)
+ m.Save("Dirent", &x.Dirent)
+ m.Save("flags", &x.flags)
+ m.Save("async", &x.async)
+ m.Save("FileOperations", &x.FileOperations)
+ m.Save("offset", &x.offset)
+}
+
+func (x *File) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("UniqueID", &x.UniqueID)
+ m.Load("Dirent", &x.Dirent)
+ m.Load("flags", &x.flags)
+ m.Load("async", &x.async)
+ m.LoadWait("FileOperations", &x.FileOperations)
+ m.Load("offset", &x.offset)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *overlayFileOperations) beforeSave() {}
+func (x *overlayFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("upper", &x.upper)
+ m.Save("lower", &x.lower)
+ m.Save("dirCursor", &x.dirCursor)
+ m.Save("dirCache", &x.dirCache)
+}
+
+func (x *overlayFileOperations) afterLoad() {}
+func (x *overlayFileOperations) load(m state.Map) {
+ m.Load("upper", &x.upper)
+ m.Load("lower", &x.lower)
+ m.Load("dirCursor", &x.dirCursor)
+ m.Load("dirCache", &x.dirCache)
+}
+
+func (x *overlayMappingIdentity) beforeSave() {}
+func (x *overlayMappingIdentity) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("id", &x.id)
+ m.Save("overlayFile", &x.overlayFile)
+}
+
+func (x *overlayMappingIdentity) afterLoad() {}
+func (x *overlayMappingIdentity) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("id", &x.id)
+ m.Load("overlayFile", &x.overlayFile)
+}
+
+func (x *MountSourceFlags) beforeSave() {}
+func (x *MountSourceFlags) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ReadOnly", &x.ReadOnly)
+ m.Save("NoAtime", &x.NoAtime)
+ m.Save("ForcePageCache", &x.ForcePageCache)
+ m.Save("NoExec", &x.NoExec)
+}
+
+func (x *MountSourceFlags) afterLoad() {}
+func (x *MountSourceFlags) load(m state.Map) {
+ m.Load("ReadOnly", &x.ReadOnly)
+ m.Load("NoAtime", &x.NoAtime)
+ m.Load("ForcePageCache", &x.ForcePageCache)
+ m.Load("NoExec", &x.NoExec)
+}
+
+func (x *FileFlags) beforeSave() {}
+func (x *FileFlags) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Direct", &x.Direct)
+ m.Save("NonBlocking", &x.NonBlocking)
+ m.Save("Sync", &x.Sync)
+ m.Save("Append", &x.Append)
+ m.Save("Read", &x.Read)
+ m.Save("Write", &x.Write)
+ m.Save("Pread", &x.Pread)
+ m.Save("Pwrite", &x.Pwrite)
+ m.Save("Directory", &x.Directory)
+ m.Save("Async", &x.Async)
+ m.Save("LargeFile", &x.LargeFile)
+}
+
+func (x *FileFlags) afterLoad() {}
+func (x *FileFlags) load(m state.Map) {
+ m.Load("Direct", &x.Direct)
+ m.Load("NonBlocking", &x.NonBlocking)
+ m.Load("Sync", &x.Sync)
+ m.Load("Append", &x.Append)
+ m.Load("Read", &x.Read)
+ m.Load("Write", &x.Write)
+ m.Load("Pread", &x.Pread)
+ m.Load("Pwrite", &x.Pwrite)
+ m.Load("Directory", &x.Directory)
+ m.Load("Async", &x.Async)
+ m.Load("LargeFile", &x.LargeFile)
+}
+
+func (x *Inode) beforeSave() {}
+func (x *Inode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("InodeOperations", &x.InodeOperations)
+ m.Save("StableAttr", &x.StableAttr)
+ m.Save("LockCtx", &x.LockCtx)
+ m.Save("Watches", &x.Watches)
+ m.Save("MountSource", &x.MountSource)
+ m.Save("overlay", &x.overlay)
+}
+
+func (x *Inode) afterLoad() {}
+func (x *Inode) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("InodeOperations", &x.InodeOperations)
+ m.Load("StableAttr", &x.StableAttr)
+ m.Load("LockCtx", &x.LockCtx)
+ m.Load("Watches", &x.Watches)
+ m.Load("MountSource", &x.MountSource)
+ m.Load("overlay", &x.overlay)
+}
+
+func (x *LockCtx) beforeSave() {}
+func (x *LockCtx) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Posix", &x.Posix)
+ m.Save("BSD", &x.BSD)
+}
+
+func (x *LockCtx) afterLoad() {}
+func (x *LockCtx) load(m state.Map) {
+ m.Load("Posix", &x.Posix)
+ m.Load("BSD", &x.BSD)
+}
+
+func (x *Watches) beforeSave() {}
+func (x *Watches) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ws", &x.ws)
+ m.Save("unlinked", &x.unlinked)
+}
+
+func (x *Watches) afterLoad() {}
+func (x *Watches) load(m state.Map) {
+ m.Load("ws", &x.ws)
+ m.Load("unlinked", &x.unlinked)
+}
+
+func (x *Inotify) beforeSave() {}
+func (x *Inotify) save(m state.Map) {
+ x.beforeSave()
+ m.Save("id", &x.id)
+ m.Save("events", &x.events)
+ m.Save("scratch", &x.scratch)
+ m.Save("nextWatch", &x.nextWatch)
+ m.Save("watches", &x.watches)
+}
+
+func (x *Inotify) afterLoad() {}
+func (x *Inotify) load(m state.Map) {
+ m.Load("id", &x.id)
+ m.Load("events", &x.events)
+ m.Load("scratch", &x.scratch)
+ m.Load("nextWatch", &x.nextWatch)
+ m.Load("watches", &x.watches)
+}
+
+func (x *Event) beforeSave() {}
+func (x *Event) save(m state.Map) {
+ x.beforeSave()
+ m.Save("eventEntry", &x.eventEntry)
+ m.Save("wd", &x.wd)
+ m.Save("mask", &x.mask)
+ m.Save("cookie", &x.cookie)
+ m.Save("len", &x.len)
+ m.Save("name", &x.name)
+}
+
+func (x *Event) afterLoad() {}
+func (x *Event) load(m state.Map) {
+ m.Load("eventEntry", &x.eventEntry)
+ m.Load("wd", &x.wd)
+ m.Load("mask", &x.mask)
+ m.Load("cookie", &x.cookie)
+ m.Load("len", &x.len)
+ m.Load("name", &x.name)
+}
+
+func (x *Watch) beforeSave() {}
+func (x *Watch) save(m state.Map) {
+ x.beforeSave()
+ m.Save("owner", &x.owner)
+ m.Save("wd", &x.wd)
+ m.Save("target", &x.target)
+ m.Save("unpinned", &x.unpinned)
+ m.Save("mask", &x.mask)
+ m.Save("pins", &x.pins)
+}
+
+func (x *Watch) afterLoad() {}
+func (x *Watch) load(m state.Map) {
+ m.Load("owner", &x.owner)
+ m.Load("wd", &x.wd)
+ m.Load("target", &x.target)
+ m.Load("unpinned", &x.unpinned)
+ m.Load("mask", &x.mask)
+ m.Load("pins", &x.pins)
+}
+
+func (x *MountSource) beforeSave() {}
+func (x *MountSource) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("MountSourceOperations", &x.MountSourceOperations)
+ m.Save("FilesystemType", &x.FilesystemType)
+ m.Save("Flags", &x.Flags)
+ m.Save("fscache", &x.fscache)
+ m.Save("direntRefs", &x.direntRefs)
+}
+
+func (x *MountSource) afterLoad() {}
+func (x *MountSource) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("MountSourceOperations", &x.MountSourceOperations)
+ m.Load("FilesystemType", &x.FilesystemType)
+ m.Load("Flags", &x.Flags)
+ m.Load("fscache", &x.fscache)
+ m.Load("direntRefs", &x.direntRefs)
+}
+
+func (x *SimpleMountSourceOperations) beforeSave() {}
+func (x *SimpleMountSourceOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("keep", &x.keep)
+ m.Save("revalidate", &x.revalidate)
+}
+
+func (x *SimpleMountSourceOperations) afterLoad() {}
+func (x *SimpleMountSourceOperations) load(m state.Map) {
+ m.Load("keep", &x.keep)
+ m.Load("revalidate", &x.revalidate)
+}
+
+func (x *overlayMountSourceOperations) beforeSave() {}
+func (x *overlayMountSourceOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("upper", &x.upper)
+ m.Save("lower", &x.lower)
+}
+
+func (x *overlayMountSourceOperations) afterLoad() {}
+func (x *overlayMountSourceOperations) load(m state.Map) {
+ m.Load("upper", &x.upper)
+ m.Load("lower", &x.lower)
+}
+
+func (x *overlayFilesystem) beforeSave() {}
+func (x *overlayFilesystem) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *overlayFilesystem) afterLoad() {}
+func (x *overlayFilesystem) load(m state.Map) {
+}
+
+func (x *Mount) beforeSave() {}
+func (x *Mount) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ID", &x.ID)
+ m.Save("ParentID", &x.ParentID)
+ m.Save("root", &x.root)
+ m.Save("previous", &x.previous)
+}
+
+func (x *Mount) afterLoad() {}
+func (x *Mount) load(m state.Map) {
+ m.Load("ID", &x.ID)
+ m.Load("ParentID", &x.ParentID)
+ m.Load("root", &x.root)
+ m.Load("previous", &x.previous)
+}
+
+func (x *MountNamespace) beforeSave() {}
+func (x *MountNamespace) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("userns", &x.userns)
+ m.Save("root", &x.root)
+ m.Save("mounts", &x.mounts)
+ m.Save("mountID", &x.mountID)
+}
+
+func (x *MountNamespace) afterLoad() {}
+func (x *MountNamespace) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("userns", &x.userns)
+ m.Load("root", &x.root)
+ m.Load("mounts", &x.mounts)
+ m.Load("mountID", &x.mountID)
+}
+
+func (x *overlayEntry) beforeSave() {}
+func (x *overlayEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("lowerExists", &x.lowerExists)
+ m.Save("lower", &x.lower)
+ m.Save("mappings", &x.mappings)
+ m.Save("upper", &x.upper)
+}
+
+func (x *overlayEntry) afterLoad() {}
+func (x *overlayEntry) load(m state.Map) {
+ m.Load("lowerExists", &x.lowerExists)
+ m.Load("lower", &x.lower)
+ m.Load("mappings", &x.mappings)
+ m.Load("upper", &x.upper)
+}
+
+func init() {
+ state.Register("fs.StableAttr", (*StableAttr)(nil), state.Fns{Save: (*StableAttr).save, Load: (*StableAttr).load})
+ state.Register("fs.UnstableAttr", (*UnstableAttr)(nil), state.Fns{Save: (*UnstableAttr).save, Load: (*UnstableAttr).load})
+ state.Register("fs.AttrMask", (*AttrMask)(nil), state.Fns{Save: (*AttrMask).save, Load: (*AttrMask).load})
+ state.Register("fs.PermMask", (*PermMask)(nil), state.Fns{Save: (*PermMask).save, Load: (*PermMask).load})
+ state.Register("fs.FilePermissions", (*FilePermissions)(nil), state.Fns{Save: (*FilePermissions).save, Load: (*FilePermissions).load})
+ state.Register("fs.FileOwner", (*FileOwner)(nil), state.Fns{Save: (*FileOwner).save, Load: (*FileOwner).load})
+ state.Register("fs.DentAttr", (*DentAttr)(nil), state.Fns{Save: (*DentAttr).save, Load: (*DentAttr).load})
+ state.Register("fs.SortedDentryMap", (*SortedDentryMap)(nil), state.Fns{Save: (*SortedDentryMap).save, Load: (*SortedDentryMap).load})
+ state.Register("fs.Dirent", (*Dirent)(nil), state.Fns{Save: (*Dirent).save, Load: (*Dirent).load})
+ state.Register("fs.DirentCache", (*DirentCache)(nil), state.Fns{Save: (*DirentCache).save, Load: (*DirentCache).load})
+ state.Register("fs.DirentCacheLimiter", (*DirentCacheLimiter)(nil), state.Fns{Save: (*DirentCacheLimiter).save, Load: (*DirentCacheLimiter).load})
+ state.Register("fs.direntList", (*direntList)(nil), state.Fns{Save: (*direntList).save, Load: (*direntList).load})
+ state.Register("fs.direntEntry", (*direntEntry)(nil), state.Fns{Save: (*direntEntry).save, Load: (*direntEntry).load})
+ state.Register("fs.eventList", (*eventList)(nil), state.Fns{Save: (*eventList).save, Load: (*eventList).load})
+ state.Register("fs.eventEntry", (*eventEntry)(nil), state.Fns{Save: (*eventEntry).save, Load: (*eventEntry).load})
+ state.Register("fs.File", (*File)(nil), state.Fns{Save: (*File).save, Load: (*File).load})
+ state.Register("fs.overlayFileOperations", (*overlayFileOperations)(nil), state.Fns{Save: (*overlayFileOperations).save, Load: (*overlayFileOperations).load})
+ state.Register("fs.overlayMappingIdentity", (*overlayMappingIdentity)(nil), state.Fns{Save: (*overlayMappingIdentity).save, Load: (*overlayMappingIdentity).load})
+ state.Register("fs.MountSourceFlags", (*MountSourceFlags)(nil), state.Fns{Save: (*MountSourceFlags).save, Load: (*MountSourceFlags).load})
+ state.Register("fs.FileFlags", (*FileFlags)(nil), state.Fns{Save: (*FileFlags).save, Load: (*FileFlags).load})
+ state.Register("fs.Inode", (*Inode)(nil), state.Fns{Save: (*Inode).save, Load: (*Inode).load})
+ state.Register("fs.LockCtx", (*LockCtx)(nil), state.Fns{Save: (*LockCtx).save, Load: (*LockCtx).load})
+ state.Register("fs.Watches", (*Watches)(nil), state.Fns{Save: (*Watches).save, Load: (*Watches).load})
+ state.Register("fs.Inotify", (*Inotify)(nil), state.Fns{Save: (*Inotify).save, Load: (*Inotify).load})
+ state.Register("fs.Event", (*Event)(nil), state.Fns{Save: (*Event).save, Load: (*Event).load})
+ state.Register("fs.Watch", (*Watch)(nil), state.Fns{Save: (*Watch).save, Load: (*Watch).load})
+ state.Register("fs.MountSource", (*MountSource)(nil), state.Fns{Save: (*MountSource).save, Load: (*MountSource).load})
+ state.Register("fs.SimpleMountSourceOperations", (*SimpleMountSourceOperations)(nil), state.Fns{Save: (*SimpleMountSourceOperations).save, Load: (*SimpleMountSourceOperations).load})
+ state.Register("fs.overlayMountSourceOperations", (*overlayMountSourceOperations)(nil), state.Fns{Save: (*overlayMountSourceOperations).save, Load: (*overlayMountSourceOperations).load})
+ state.Register("fs.overlayFilesystem", (*overlayFilesystem)(nil), state.Fns{Save: (*overlayFilesystem).save, Load: (*overlayFilesystem).load})
+ state.Register("fs.Mount", (*Mount)(nil), state.Fns{Save: (*Mount).save, Load: (*Mount).load})
+ state.Register("fs.MountNamespace", (*MountNamespace)(nil), state.Fns{Save: (*MountNamespace).save, Load: (*MountNamespace).load})
+ state.Register("fs.overlayEntry", (*overlayEntry)(nil), state.Fns{Save: (*overlayEntry).save, Load: (*overlayEntry).load})
+}
diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go
new file mode 100644
index 000000000..f1451d77a
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/dirty_set.go
@@ -0,0 +1,237 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// DirtySet maps offsets into a memmap.Mappable to DirtyInfo. It is used to
+// implement Mappables that cache data from another source.
+//
+// type DirtySet <generated by go_generics>
+
+// DirtyInfo is the value type of DirtySet, and represents information about a
+// Mappable offset that is dirty (the cached data for that offset is newer than
+// its source).
+//
+// +stateify savable
+type DirtyInfo struct {
+ // Keep is true if the represented offset is concurrently writable, such
+ // that writing the data for that offset back to the source does not
+ // guarantee that the offset is clean (since it may be concurrently
+ // rewritten after the writeback).
+ Keep bool
+}
+
+// dirtySetFunctions implements segment.Functions for DirtySet.
+type dirtySetFunctions struct{}
+
+// MinKey implements segment.Functions.MinKey.
+func (dirtySetFunctions) MinKey() uint64 {
+ return 0
+}
+
+// MaxKey implements segment.Functions.MaxKey.
+func (dirtySetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+// ClearValue implements segment.Functions.ClearValue.
+func (dirtySetFunctions) ClearValue(val *DirtyInfo) {
+}
+
+// Merge implements segment.Functions.Merge.
+func (dirtySetFunctions) Merge(_ memmap.MappableRange, val1 DirtyInfo, _ memmap.MappableRange, val2 DirtyInfo) (DirtyInfo, bool) {
+ if val1 != val2 {
+ return DirtyInfo{}, false
+ }
+ return val1, true
+}
+
+// Split implements segment.Functions.Split.
+func (dirtySetFunctions) Split(_ memmap.MappableRange, val DirtyInfo, _ uint64) (DirtyInfo, DirtyInfo) {
+ return val, val
+}
+
+// MarkClean marks all offsets in mr as not dirty, except for those to which
+// KeepDirty has been applied.
+func (ds *DirtySet) MarkClean(mr memmap.MappableRange) {
+ seg := ds.LowerBoundSegment(mr.Start)
+ for seg.Ok() && seg.Start() < mr.End {
+ if seg.Value().Keep {
+ seg = seg.NextSegment()
+ continue
+ }
+ seg = ds.Isolate(seg, mr)
+ seg = ds.Remove(seg).NextSegment()
+ }
+}
+
+// KeepClean marks all offsets in mr as not dirty, even those that were
+// previously kept dirty by KeepDirty.
+func (ds *DirtySet) KeepClean(mr memmap.MappableRange) {
+ ds.RemoveRange(mr)
+}
+
+// MarkDirty marks all offsets in mr as dirty.
+func (ds *DirtySet) MarkDirty(mr memmap.MappableRange) {
+ ds.setDirty(mr, false)
+}
+
+// KeepDirty marks all offsets in mr as dirty and prevents them from being
+// marked as clean by MarkClean.
+func (ds *DirtySet) KeepDirty(mr memmap.MappableRange) {
+ ds.setDirty(mr, true)
+}
+
+func (ds *DirtySet) setDirty(mr memmap.MappableRange, keep bool) {
+ var changedAny bool
+ defer func() {
+ if changedAny {
+ // Merge segments split by Isolate to reduce cost of iteration.
+ ds.MergeRange(mr)
+ }
+ }()
+ seg, gap := ds.Find(mr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < mr.End:
+ if keep && !seg.Value().Keep {
+ changedAny = true
+ seg = ds.Isolate(seg, mr)
+ seg.ValuePtr().Keep = true
+ }
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok() && gap.Start() < mr.End:
+ changedAny = true
+ seg = ds.Insert(gap, gap.Range().Intersect(mr), DirtyInfo{keep})
+ seg, gap = seg.NextNonEmpty()
+
+ default:
+ return
+ }
+ }
+}
+
+// AllowClean allows MarkClean to mark offsets in mr as not dirty, ending the
+// effect of a previous call to KeepDirty. (It does not itself mark those
+// offsets as not dirty.)
+func (ds *DirtySet) AllowClean(mr memmap.MappableRange) {
+ var changedAny bool
+ defer func() {
+ if changedAny {
+ // Merge segments split by Isolate to reduce cost of iteration.
+ ds.MergeRange(mr)
+ }
+ }()
+ for seg := ds.LowerBoundSegment(mr.Start); seg.Ok() && seg.Start() < mr.End; seg = seg.NextSegment() {
+ if seg.Value().Keep {
+ changedAny = true
+ seg = ds.Isolate(seg, mr)
+ seg.ValuePtr().Keep = false
+ }
+ }
+}
+
+// SyncDirty passes pages in the range mr that are stored in cache and
+// identified as dirty to writeAt, updating dirty to reflect successful writes.
+// If writeAt returns a successful partial write, SyncDirty will call it
+// repeatedly until all bytes have been written. max is the true size of the
+// cached object; offsets beyond max will not be passed to writeAt, even if
+// they are marked dirty.
+func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+ var changedDirty bool
+ defer func() {
+ if changedDirty {
+ // Merge segments split by Isolate to reduce cost of iteration.
+ dirty.MergeRange(mr)
+ }
+ }()
+ dseg := dirty.LowerBoundSegment(mr.Start)
+ for dseg.Ok() && dseg.Start() < mr.End {
+ var dr memmap.MappableRange
+ if dseg.Value().Keep {
+ dr = dseg.Range().Intersect(mr)
+ } else {
+ changedDirty = true
+ dseg = dirty.Isolate(dseg, mr)
+ dr = dseg.Range()
+ }
+ if err := syncDirtyRange(ctx, dr, cache, max, mem, writeAt); err != nil {
+ return err
+ }
+ if dseg.Value().Keep {
+ dseg = dseg.NextSegment()
+ } else {
+ dseg = dirty.Remove(dseg).NextSegment()
+ }
+ }
+ return nil
+}
+
+// SyncDirtyAll passes all pages stored in cache identified as dirty to
+// writeAt, updating dirty to reflect successful writes. If writeAt returns a
+// successful partial write, SyncDirtyAll will call it repeatedly until all
+// bytes have been written. max is the true size of the cached object; offsets
+// beyond max will not be passed to writeAt, even if they are marked dirty.
+func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+ dseg := dirty.FirstSegment()
+ for dseg.Ok() {
+ if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil {
+ return err
+ }
+ if dseg.Value().Keep {
+ dseg = dseg.NextSegment()
+ } else {
+ dseg = dirty.Remove(dseg).NextSegment()
+ }
+ }
+ return nil
+}
+
+// Preconditions: mr must be page-aligned.
+func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+ for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() {
+ wbr := cseg.Range().Intersect(mr)
+ if max < wbr.Start {
+ break
+ }
+ ims, err := mem.MapInternal(cseg.FileRangeOf(wbr), usermem.Read)
+ if err != nil {
+ return err
+ }
+ if max < wbr.End {
+ ims = ims.TakeFirst64(max - wbr.Start)
+ }
+ offset := wbr.Start
+ for !ims.IsEmpty() {
+ n, err := writeAt(ctx, ims, offset)
+ if err != nil {
+ return err
+ }
+ offset += n
+ ims = ims.DropFirst64(n)
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/fsutil/dirty_set_impl.go b/pkg/sentry/fs/fsutil/dirty_set_impl.go
new file mode 100755
index 000000000..5f25068a1
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/dirty_set_impl.go
@@ -0,0 +1,1274 @@
+package fsutil
+
+import (
+ __generics_imported0 "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+)
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ DirtyminDegree = 3
+
+ DirtymaxDegree = 2 * DirtyminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type DirtySet struct {
+ root Dirtynode `state:".(*DirtySegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *DirtySet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *DirtySet) IsEmptyRange(r __generics_imported0.MappableRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *DirtySet) Span() uint64 {
+ var sz uint64
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *DirtySet) SpanRange(r __generics_imported0.MappableRange) uint64 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint64
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *DirtySet) FirstSegment() DirtyIterator {
+ if s.root.nrSegments == 0 {
+ return DirtyIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *DirtySet) LastSegment() DirtyIterator {
+ if s.root.nrSegments == 0 {
+ return DirtyIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *DirtySet) FirstGap() DirtyGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return DirtyGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *DirtySet) LastGap() DirtyGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return DirtyGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *DirtySet) Find(key uint64) (DirtyIterator, DirtyGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return DirtyIterator{n, i}, DirtyGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return DirtyIterator{}, DirtyGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *DirtySet) FindSegment(key uint64) DirtyIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *DirtySet) LowerBoundSegment(min uint64) DirtyIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *DirtySet) UpperBoundSegment(max uint64) DirtyIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *DirtySet) FindGap(key uint64) DirtyGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *DirtySet) LowerBoundGap(min uint64) DirtyGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *DirtySet) UpperBoundGap(max uint64) DirtyGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *DirtySet) Add(r __generics_imported0.MappableRange, val DirtyInfo) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *DirtySet) AddWithoutMerging(r __generics_imported0.MappableRange, val DirtyInfo) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *DirtySet) Insert(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (dirtySetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (dirtySetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (dirtySetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *DirtySet) InsertWithoutMerging(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *DirtySet) InsertWithoutMergingUnchecked(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return DirtyIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *DirtySet) Remove(seg DirtyIterator) DirtyGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ dirtySetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(DirtyGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *DirtySet) RemoveAll() {
+ s.root = Dirtynode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *DirtySet) RemoveRange(r __generics_imported0.MappableRange) DirtyGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *DirtySet) Merge(first, second DirtyIterator) DirtyIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *DirtySet) MergeUnchecked(first, second DirtyIterator) DirtyIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (dirtySetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return DirtyIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *DirtySet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *DirtySet) MergeRange(r __generics_imported0.MappableRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *DirtySet) MergeAdjacent(r __generics_imported0.MappableRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *DirtySet) Split(seg DirtyIterator, split uint64) (DirtyIterator, DirtyIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *DirtySet) SplitUnchecked(seg DirtyIterator, split uint64) (DirtyIterator, DirtyIterator) {
+ val1, val2 := (dirtySetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.MappableRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *DirtySet) SplitAt(split uint64) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *DirtySet) Isolate(seg DirtyIterator, r __generics_imported0.MappableRange) DirtyIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *DirtySet) ApplyContiguous(r __generics_imported0.MappableRange, fn func(seg DirtyIterator)) DirtyGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return DirtyGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return DirtyGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type Dirtynode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *Dirtynode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [DirtymaxDegree - 1]__generics_imported0.MappableRange
+ values [DirtymaxDegree - 1]DirtyInfo
+ children [DirtymaxDegree]*Dirtynode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *Dirtynode) firstSegment() DirtyIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return DirtyIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *Dirtynode) lastSegment() DirtyIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return DirtyIterator{n, n.nrSegments - 1}
+}
+
+func (n *Dirtynode) prevSibling() *Dirtynode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *Dirtynode) nextSibling() *Dirtynode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *Dirtynode) rebalanceBeforeInsert(gap DirtyGapIterator) DirtyGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < DirtymaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &Dirtynode{
+ nrSegments: DirtyminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &Dirtynode{
+ nrSegments: DirtyminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:DirtyminDegree-1], n.keys[:DirtyminDegree-1])
+ copy(left.values[:DirtyminDegree-1], n.values[:DirtyminDegree-1])
+ copy(right.keys[:DirtyminDegree-1], n.keys[DirtyminDegree:])
+ copy(right.values[:DirtyminDegree-1], n.values[DirtyminDegree:])
+ n.keys[0], n.values[0] = n.keys[DirtyminDegree-1], n.values[DirtyminDegree-1]
+ DirtyzeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:DirtyminDegree], n.children[:DirtyminDegree])
+ copy(right.children[:DirtyminDegree], n.children[DirtyminDegree:])
+ DirtyzeroNodeSlice(n.children[2:])
+ for i := 0; i < DirtyminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < DirtyminDegree {
+ return DirtyGapIterator{left, gap.index}
+ }
+ return DirtyGapIterator{right, gap.index - DirtyminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[DirtyminDegree-1], n.values[DirtyminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &Dirtynode{
+ nrSegments: DirtyminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:DirtyminDegree-1], n.keys[DirtyminDegree:])
+ copy(sibling.values[:DirtyminDegree-1], n.values[DirtyminDegree:])
+ DirtyzeroValueSlice(n.values[DirtyminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:DirtyminDegree], n.children[DirtyminDegree:])
+ DirtyzeroNodeSlice(n.children[DirtyminDegree:])
+ for i := 0; i < DirtyminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = DirtyminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < DirtyminDegree {
+ return gap
+ }
+ return DirtyGapIterator{sibling, gap.index - DirtyminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *Dirtynode) rebalanceAfterRemove(gap DirtyGapIterator) DirtyGapIterator {
+ for {
+ if n.nrSegments >= DirtyminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= DirtyminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ dirtySetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return DirtyGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return DirtyGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= DirtyminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ dirtySetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return DirtyGapIterator{n, n.nrSegments}
+ }
+ return DirtyGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return DirtyGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return DirtyGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *Dirtynode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = DirtyGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ dirtySetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type DirtyIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *Dirtynode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg DirtyIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg DirtyIterator) Range() __generics_imported0.MappableRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg DirtyIterator) Start() uint64 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg DirtyIterator) End() uint64 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg DirtyIterator) SetRangeUnchecked(r __generics_imported0.MappableRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg DirtyIterator) SetRange(r __generics_imported0.MappableRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg DirtyIterator) SetStartUnchecked(start uint64) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg DirtyIterator) SetStart(start uint64) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg DirtyIterator) SetEndUnchecked(end uint64) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg DirtyIterator) SetEnd(end uint64) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg DirtyIterator) Value() DirtyInfo {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg DirtyIterator) ValuePtr() *DirtyInfo {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg DirtyIterator) SetValue(val DirtyInfo) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg DirtyIterator) PrevSegment() DirtyIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return DirtyIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return DirtyIterator{}
+ }
+ return DirtysegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg DirtyIterator) NextSegment() DirtyIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return DirtyIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return DirtyIterator{}
+ }
+ return DirtysegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg DirtyIterator) PrevGap() DirtyGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return DirtyGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg DirtyIterator) NextGap() DirtyGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return DirtyGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg DirtyIterator) PrevNonEmpty() (DirtyIterator, DirtyGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return DirtyIterator{}, gap
+ }
+ return gap.PrevSegment(), DirtyGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg DirtyIterator) NextNonEmpty() (DirtyIterator, DirtyGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return DirtyIterator{}, gap
+ }
+ return gap.NextSegment(), DirtyGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type DirtyGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *Dirtynode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap DirtyGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap DirtyGapIterator) Range() __generics_imported0.MappableRange {
+ return __generics_imported0.MappableRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap DirtyGapIterator) Start() uint64 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return dirtySetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap DirtyGapIterator) End() uint64 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return dirtySetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap DirtyGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap DirtyGapIterator) PrevSegment() DirtyIterator {
+ return DirtysegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap DirtyGapIterator) NextSegment() DirtyIterator {
+ return DirtysegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap DirtyGapIterator) PrevGap() DirtyGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return DirtyGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap DirtyGapIterator) NextGap() DirtyGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return DirtyGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func DirtysegmentBeforePosition(n *Dirtynode, i int) DirtyIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return DirtyIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return DirtyIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func DirtysegmentAfterPosition(n *Dirtynode, i int) DirtyIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return DirtyIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return DirtyIterator{n, i}
+}
+
+func DirtyzeroValueSlice(slice []DirtyInfo) {
+
+ for i := range slice {
+ dirtySetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func DirtyzeroNodeSlice(slice []*Dirtynode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *DirtySet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *Dirtynode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *Dirtynode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type DirtySegmentDataSlices struct {
+ Start []uint64
+ End []uint64
+ Values []DirtyInfo
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *DirtySet) ExportSortedSlices() *DirtySegmentDataSlices {
+ var sds DirtySegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *DirtySet) ImportSortedSlices(sds *DirtySegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := __generics_imported0.MappableRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *DirtySet) saveRoot() *DirtySegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *DirtySet) loadRoot(sds *DirtySegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
new file mode 100644
index 000000000..9381963d0
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -0,0 +1,394 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// FileNoopRelease implements fs.FileOperations.Release for files that have no
+// resources to release.
+type FileNoopRelease struct{}
+
+// Release is a no-op.
+func (FileNoopRelease) Release() {}
+
+// SeekWithDirCursor is used to implement fs.FileOperations.Seek. If dirCursor
+// is not nil and the seek was on a directory, the cursor will be updated.
+//
+// Currently only seeking to 0 on a directory is supported.
+//
+// FIXME(b/33075855): Lift directory seeking limitations.
+func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64, dirCursor *string) (int64, error) {
+ inode := file.Dirent.Inode
+ current := file.Offset()
+
+ // Does the Inode represents a non-seekable type?
+ if fs.IsPipe(inode.StableAttr) || fs.IsSocket(inode.StableAttr) {
+ return current, syserror.ESPIPE
+ }
+
+ // Does the Inode represent a character device?
+ if fs.IsCharDevice(inode.StableAttr) {
+ // Ignore seek requests.
+ //
+ // FIXME(b/34716638): This preserves existing
+ // behavior but is not universally correct.
+ return 0, nil
+ }
+
+ // Otherwise compute the new offset.
+ switch whence {
+ case fs.SeekSet:
+ switch inode.StableAttr.Type {
+ case fs.RegularFile, fs.SpecialFile, fs.BlockDevice:
+ if offset < 0 {
+ return current, syserror.EINVAL
+ }
+ return offset, nil
+ case fs.Directory, fs.SpecialDirectory:
+ if offset != 0 {
+ return current, syserror.EINVAL
+ }
+ // SEEK_SET to 0 moves the directory "cursor" to the beginning.
+ if dirCursor != nil {
+ *dirCursor = ""
+ }
+ return 0, nil
+ default:
+ return current, syserror.EINVAL
+ }
+ case fs.SeekCurrent:
+ switch inode.StableAttr.Type {
+ case fs.RegularFile, fs.SpecialFile, fs.BlockDevice:
+ if current+offset < 0 {
+ return current, syserror.EINVAL
+ }
+ return current + offset, nil
+ case fs.Directory, fs.SpecialDirectory:
+ if offset != 0 {
+ return current, syserror.EINVAL
+ }
+ return current, nil
+ default:
+ return current, syserror.EINVAL
+ }
+ case fs.SeekEnd:
+ switch inode.StableAttr.Type {
+ case fs.RegularFile, fs.BlockDevice:
+ // Allow the file to determine the end.
+ uattr, err := inode.UnstableAttr(ctx)
+ if err != nil {
+ return current, err
+ }
+ sz := uattr.Size
+ if sz+offset < 0 {
+ return current, syserror.EINVAL
+ }
+ return sz + offset, nil
+ // FIXME(b/34778850): This is not universally correct.
+ // Remove SpecialDirectory.
+ case fs.SpecialDirectory:
+ if offset != 0 {
+ return current, syserror.EINVAL
+ }
+ // SEEK_END to 0 moves the directory "cursor" to the end.
+ //
+ // FIXME(b/35442290): The ensures that after the seek,
+ // reading on the directory will get EOF. But it is not
+ // correct in general because the directory can grow in
+ // size; attempting to read those new entries will be
+ // futile (EOF will always be the result).
+ return fs.FileMaxOffset, nil
+ default:
+ return current, syserror.EINVAL
+ }
+ }
+
+ // Not a valid seek request.
+ return current, syserror.EINVAL
+}
+
+// FileGenericSeek implements fs.FileOperations.Seek for files that use a
+// generic seek implementation.
+type FileGenericSeek struct{}
+
+// Seek implements fs.FileOperations.Seek.
+func (FileGenericSeek) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ return SeekWithDirCursor(ctx, file, whence, offset, nil)
+}
+
+// FileZeroSeek implements fs.FileOperations.Seek for files that maintain a
+// constant zero-value offset and require a no-op Seek.
+type FileZeroSeek struct{}
+
+// Seek implements fs.FileOperations.Seek.
+func (FileZeroSeek) Seek(context.Context, *fs.File, fs.SeekWhence, int64) (int64, error) {
+ return 0, nil
+}
+
+// FileNoSeek implements fs.FileOperations.Seek to return EINVAL.
+type FileNoSeek struct{}
+
+// Seek implements fs.FileOperations.Seek.
+func (FileNoSeek) Seek(context.Context, *fs.File, fs.SeekWhence, int64) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// FilePipeSeek implements fs.FileOperations.Seek and can be used for files
+// that behave like pipes (seeking is not supported).
+type FilePipeSeek struct{}
+
+// Seek implements fs.FileOperations.Seek.
+func (FilePipeSeek) Seek(context.Context, *fs.File, fs.SeekWhence, int64) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// FileNotDirReaddir implements fs.FileOperations.Readdir for non-directories.
+type FileNotDirReaddir struct{}
+
+// Readdir implements fs.FileOperations.FileNotDirReaddir.
+func (FileNotDirReaddir) Readdir(context.Context, *fs.File, fs.DentrySerializer) (int64, error) {
+ return 0, syserror.ENOTDIR
+}
+
+// FileNoFsync implements fs.FileOperations.Fsync for files that don't support
+// syncing.
+type FileNoFsync struct{}
+
+// Fsync implements fs.FileOperations.Fsync.
+func (FileNoFsync) Fsync(context.Context, *fs.File, int64, int64, fs.SyncType) error {
+ return syserror.EINVAL
+}
+
+// FileNoopFsync implements fs.FileOperations.Fsync for files that don't need
+// to synced.
+type FileNoopFsync struct{}
+
+// Fsync implements fs.FileOperations.Fsync.
+func (FileNoopFsync) Fsync(context.Context, *fs.File, int64, int64, fs.SyncType) error {
+ return nil
+}
+
+// FileNoopFlush implements fs.FileOperations.Flush as a no-op.
+type FileNoopFlush struct{}
+
+// Flush implements fs.FileOperations.Flush.
+func (FileNoopFlush) Flush(context.Context, *fs.File) error {
+ return nil
+}
+
+// FileNoMMap implements fs.FileOperations.Mappable for files that cannot
+// be memory mapped.
+type FileNoMMap struct{}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (FileNoMMap) ConfigureMMap(context.Context, *fs.File, *memmap.MMapOpts) error {
+ return syserror.ENODEV
+}
+
+// GenericConfigureMMap implements fs.FileOperations.ConfigureMMap for most
+// filesystems that support memory mapping.
+func GenericConfigureMMap(file *fs.File, m memmap.Mappable, opts *memmap.MMapOpts) error {
+ opts.Mappable = m
+ opts.MappingIdentity = file
+ file.IncRef()
+ return nil
+}
+
+// FileNoIoctl implements fs.FileOperations.Ioctl for files that don't
+// implement the ioctl syscall.
+type FileNoIoctl struct{}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (FileNoIoctl) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return 0, syserror.ENOTTY
+}
+
+// FileNoSplice implements fs.FileOperations.ReadFrom and
+// fs.FileOperations.WriteTo for files that don't support splice.
+type FileNoSplice struct{}
+
+// WriteTo implements fs.FileOperations.WriteTo.
+func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// DirFileOperations implements most of fs.FileOperations for directories,
+// except for Readdir and UnstableAttr which the embedding type must implement.
+type DirFileOperations struct {
+ waiter.AlwaysReady
+ FileGenericSeek
+ FileNoIoctl
+ FileNoMMap
+ FileNoopFlush
+ FileNoopFsync
+ FileNoopRelease
+ FileNoSplice
+}
+
+// Read implements fs.FileOperations.Read
+func (*DirFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// Write implements fs.FileOperations.Write.
+func (*DirFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// StaticDirFileOperations implements fs.FileOperations for directories with
+// static children.
+//
+// +stateify savable
+type StaticDirFileOperations struct {
+ DirFileOperations `state:"nosave"`
+ FileUseInodeUnstableAttr `state:"nosave"`
+
+ // dentryMap is a SortedDentryMap used to implement Readdir.
+ dentryMap *fs.SortedDentryMap
+
+ // dirCursor contains the name of the last directory entry that was
+ // serialized.
+ dirCursor string
+}
+
+// NewStaticDirFileOperations returns a new StaticDirFileOperations that will
+// iterate the given denty map.
+func NewStaticDirFileOperations(dentries *fs.SortedDentryMap) *StaticDirFileOperations {
+ return &StaticDirFileOperations{
+ dentryMap: dentries,
+ }
+}
+
+// IterateDir implements DirIterator.IterateDir.
+func (sdfo *StaticDirFileOperations) IterateDir(ctx context.Context, dirCtx *fs.DirCtx, offset int) (int, error) {
+ n, err := fs.GenericReaddir(dirCtx, sdfo.dentryMap)
+ return offset + n, err
+}
+
+// Readdir implements fs.FileOperations.Readdir.
+func (sdfo *StaticDirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &sdfo.dirCursor,
+ }
+ return fs.DirentReaddir(ctx, file.Dirent, sdfo, root, dirCtx, file.Offset())
+}
+
+// NoReadWriteFile is a file that does not support reading or writing.
+//
+// +stateify savable
+type NoReadWriteFile struct {
+ waiter.AlwaysReady `state:"nosave"`
+ FileGenericSeek `state:"nosave"`
+ FileNoIoctl `state:"nosave"`
+ FileNoMMap `state:"nosave"`
+ FileNoopFsync `state:"nosave"`
+ FileNoopFlush `state:"nosave"`
+ FileNoopRelease `state:"nosave"`
+ FileNoRead `state:"nosave"`
+ FileNoWrite `state:"nosave"`
+ FileNotDirReaddir `state:"nosave"`
+ FileUseInodeUnstableAttr `state:"nosave"`
+ FileNoSplice `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*NoReadWriteFile)(nil)
+
+// FileStaticContentReader is a helper to implement fs.FileOperations.Read with
+// static content.
+//
+// +stateify savable
+type FileStaticContentReader struct {
+ // content is immutable.
+ content []byte
+}
+
+// NewFileStaticContentReader initializes a FileStaticContentReader with the
+// given content.
+func NewFileStaticContentReader(b []byte) FileStaticContentReader {
+ return FileStaticContentReader{
+ content: b,
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (scr *FileStaticContentReader) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset >= int64(len(scr.content)) {
+ return 0, nil
+ }
+ n, err := dst.CopyOut(ctx, scr.content[offset:])
+ return int64(n), err
+}
+
+// FileNoopWrite implements fs.FileOperations.Write as a noop.
+type FileNoopWrite struct{}
+
+// Write implements fs.FileOperations.Write.
+func (FileNoopWrite) Write(_ context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// FileNoRead implements fs.FileOperations.Read to return EINVAL.
+type FileNoRead struct{}
+
+// Read implements fs.FileOperations.Read.
+func (FileNoRead) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// FileNoWrite implements fs.FileOperations.Write to return EINVAL.
+type FileNoWrite struct{}
+
+// Write implements fs.FileOperations.Write.
+func (FileNoWrite) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// FileNoopRead implement fs.FileOperations.Read as a noop.
+type FileNoopRead struct{}
+
+// Read implements fs.FileOperations.Read.
+func (FileNoopRead) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, nil
+}
+
+// FileUseInodeUnstableAttr implements fs.FileOperations.UnstableAttr by calling
+// InodeOperations.UnstableAttr.
+type FileUseInodeUnstableAttr struct{}
+
+// UnstableAttr implements fs.FileOperations.UnstableAttr.
+func (FileUseInodeUnstableAttr) UnstableAttr(ctx context.Context, file *fs.File) (fs.UnstableAttr, error) {
+ return file.Dirent.Inode.UnstableAttr(ctx)
+}
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
new file mode 100644
index 000000000..b5ac6c71c
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -0,0 +1,209 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "fmt"
+ "io"
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// FileRangeSet maps offsets into a memmap.Mappable to offsets into a
+// platform.File. It is used to implement Mappables that store data in
+// sparsely-allocated memory.
+//
+// type FileRangeSet <generated by go_generics>
+
+// fileRangeSetFunctions implements segment.Functions for FileRangeSet.
+type fileRangeSetFunctions struct{}
+
+// MinKey implements segment.Functions.MinKey.
+func (fileRangeSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+// MaxKey implements segment.Functions.MaxKey.
+func (fileRangeSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+// ClearValue implements segment.Functions.ClearValue.
+func (fileRangeSetFunctions) ClearValue(_ *uint64) {
+}
+
+// Merge implements segment.Functions.Merge.
+func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
+ if frstart1+mr1.Length() != frstart2 {
+ return 0, false
+ }
+ return frstart1, true
+}
+
+// Split implements segment.Functions.Split.
+func (fileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
+ return frstart, frstart + (split - mr.Start)
+}
+
+// FileRange returns the FileRange mapped by seg.
+func (seg FileRangeIterator) FileRange() platform.FileRange {
+ return seg.FileRangeOf(seg.Range())
+}
+
+// FileRangeOf returns the FileRange mapped by mr.
+//
+// Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0.
+func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange {
+ frstart := seg.Value() + (mr.Start - seg.Start())
+ return platform.FileRange{frstart, frstart + mr.Length()}
+}
+
+// Fill attempts to ensure that all memmap.Mappable offsets in required are
+// mapped to a platform.File offset, by allocating from mf with the given
+// memory usage kind and invoking readAt to store data into memory. (If readAt
+// returns a successful partial read, Fill will call it repeatedly until all
+// bytes have been read.) EOF is handled consistently with the requirements of
+// mmap(2): bytes after EOF on the same page are zeroed; pages after EOF are
+// invalid.
+//
+// Fill may read offsets outside of required, but will never read offsets
+// outside of optional. It returns a non-nil error if any error occurs, even
+// if the error only affects offsets in optional, but not in required.
+//
+// Preconditions: required.Length() > 0. optional.IsSupersetOf(required).
+// required and optional must be page-aligned.
+func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.MappableRange, mf *pgalloc.MemoryFile, kind usage.MemoryKind, readAt func(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)) error {
+ gap := frs.LowerBoundGap(required.Start)
+ for gap.Ok() && gap.Start() < required.End {
+ if gap.Range().Length() == 0 {
+ gap = gap.NextGap()
+ continue
+ }
+ gr := gap.Range().Intersect(optional)
+
+ // Read data into the gap.
+ fr, err := mf.AllocateAndFill(gr.Length(), kind, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ var done uint64
+ for !dsts.IsEmpty() {
+ n, err := readAt(ctx, dsts, gr.Start+done)
+ done += n
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ if err == io.EOF {
+ // MemoryFile.AllocateAndFill truncates down to a page
+ // boundary, but FileRangeSet.Fill is supposed to
+ // zero-fill to the end of the page in this case.
+ donepgaddr, ok := usermem.Addr(done).RoundUp()
+ if donepg := uint64(donepgaddr); ok && donepg != done {
+ dsts.DropFirst64(donepg - done)
+ done = donepg
+ if dsts.IsEmpty() {
+ return done, nil
+ }
+ }
+ }
+ return done, err
+ }
+ }
+ return done, nil
+ }))
+
+ // Store anything we managed to read into the cache.
+ if done := fr.Length(); done != 0 {
+ gr.End = gr.Start + done
+ gap = frs.Insert(gap, gr, fr.Start).NextGap()
+ }
+
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Drop removes segments for memmap.Mappable offsets in mr, freeing the
+// corresponding platform.FileRanges.
+//
+// Preconditions: mr must be page-aligned.
+func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
+ seg := frs.LowerBoundSegment(mr.Start)
+ for seg.Ok() && seg.Start() < mr.End {
+ seg = frs.Isolate(seg, mr)
+ mf.DecRef(seg.FileRange())
+ seg = frs.Remove(seg).NextSegment()
+ }
+}
+
+// DropAll removes all segments in mr, freeing the corresponding
+// platform.FileRanges.
+func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) {
+ for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ mf.DecRef(seg.FileRange())
+ }
+ frs.RemoveAll()
+}
+
+// Truncate updates frs to reflect Mappable truncation to the given length:
+// bytes after the new EOF on the same page are zeroed, and pages after the new
+// EOF are freed.
+func (frs *FileRangeSet) Truncate(end uint64, mf *pgalloc.MemoryFile) {
+ pgendaddr, ok := usermem.Addr(end).RoundUp()
+ if ok {
+ pgend := uint64(pgendaddr)
+
+ // Free truncated pages.
+ frs.SplitAt(pgend)
+ seg := frs.LowerBoundSegment(pgend)
+ for seg.Ok() {
+ mf.DecRef(seg.FileRange())
+ seg = frs.Remove(seg).NextSegment()
+ }
+
+ if end == pgend {
+ return
+ }
+ }
+
+ // Here we know end < end.RoundUp(). If the new EOF lands in the
+ // middle of a page that we have, zero out its contents beyond the new
+ // length.
+ seg := frs.FindSegment(end)
+ if seg.Ok() {
+ fr := seg.FileRange()
+ fr.Start += end - seg.Start()
+ ims, err := mf.MapInternal(fr, usermem.Write)
+ if err != nil {
+ // There's no good recourse from here. This means
+ // that we can't keep cached memory consistent with
+ // the new end of file. The caller may have already
+ // updated the file size on their backing file system.
+ //
+ // We don't want to risk blindly continuing onward,
+ // so in the extremely rare cases this does happen,
+ // we abandon ship.
+ panic(fmt.Sprintf("Failed to map %v: %v", fr, err))
+ }
+ if _, err := safemem.ZeroSeq(ims); err != nil {
+ panic(fmt.Sprintf("Zeroing %v failed: %v", fr, err))
+ }
+ }
+}
diff --git a/pkg/sentry/fs/fsutil/file_range_set_impl.go b/pkg/sentry/fs/fsutil/file_range_set_impl.go
new file mode 100755
index 000000000..a0ab61628
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/file_range_set_impl.go
@@ -0,0 +1,1274 @@
+package fsutil
+
+import (
+ __generics_imported0 "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+)
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ FileRangeminDegree = 3
+
+ FileRangemaxDegree = 2 * FileRangeminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type FileRangeSet struct {
+ root FileRangenode `state:".(*FileRangeSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *FileRangeSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *FileRangeSet) IsEmptyRange(r __generics_imported0.MappableRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *FileRangeSet) Span() uint64 {
+ var sz uint64
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *FileRangeSet) SpanRange(r __generics_imported0.MappableRange) uint64 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint64
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *FileRangeSet) FirstSegment() FileRangeIterator {
+ if s.root.nrSegments == 0 {
+ return FileRangeIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *FileRangeSet) LastSegment() FileRangeIterator {
+ if s.root.nrSegments == 0 {
+ return FileRangeIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *FileRangeSet) FirstGap() FileRangeGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return FileRangeGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *FileRangeSet) LastGap() FileRangeGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return FileRangeGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *FileRangeSet) Find(key uint64) (FileRangeIterator, FileRangeGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return FileRangeIterator{n, i}, FileRangeGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return FileRangeIterator{}, FileRangeGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *FileRangeSet) FindSegment(key uint64) FileRangeIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *FileRangeSet) LowerBoundSegment(min uint64) FileRangeIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *FileRangeSet) UpperBoundSegment(max uint64) FileRangeIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *FileRangeSet) FindGap(key uint64) FileRangeGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *FileRangeSet) LowerBoundGap(min uint64) FileRangeGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *FileRangeSet) UpperBoundGap(max uint64) FileRangeGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *FileRangeSet) Add(r __generics_imported0.MappableRange, val uint64) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *FileRangeSet) AddWithoutMerging(r __generics_imported0.MappableRange, val uint64) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *FileRangeSet) Insert(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (fileRangeSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (fileRangeSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (fileRangeSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *FileRangeSet) InsertWithoutMerging(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *FileRangeSet) InsertWithoutMergingUnchecked(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return FileRangeIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *FileRangeSet) Remove(seg FileRangeIterator) FileRangeGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ fileRangeSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(FileRangeGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *FileRangeSet) RemoveAll() {
+ s.root = FileRangenode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *FileRangeSet) RemoveRange(r __generics_imported0.MappableRange) FileRangeGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *FileRangeSet) Merge(first, second FileRangeIterator) FileRangeIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *FileRangeSet) MergeUnchecked(first, second FileRangeIterator) FileRangeIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (fileRangeSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return FileRangeIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *FileRangeSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *FileRangeSet) MergeRange(r __generics_imported0.MappableRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *FileRangeSet) MergeAdjacent(r __generics_imported0.MappableRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *FileRangeSet) Split(seg FileRangeIterator, split uint64) (FileRangeIterator, FileRangeIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *FileRangeSet) SplitUnchecked(seg FileRangeIterator, split uint64) (FileRangeIterator, FileRangeIterator) {
+ val1, val2 := (fileRangeSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.MappableRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *FileRangeSet) SplitAt(split uint64) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *FileRangeSet) Isolate(seg FileRangeIterator, r __generics_imported0.MappableRange) FileRangeIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *FileRangeSet) ApplyContiguous(r __generics_imported0.MappableRange, fn func(seg FileRangeIterator)) FileRangeGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return FileRangeGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return FileRangeGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type FileRangenode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *FileRangenode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [FileRangemaxDegree - 1]__generics_imported0.MappableRange
+ values [FileRangemaxDegree - 1]uint64
+ children [FileRangemaxDegree]*FileRangenode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *FileRangenode) firstSegment() FileRangeIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return FileRangeIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *FileRangenode) lastSegment() FileRangeIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return FileRangeIterator{n, n.nrSegments - 1}
+}
+
+func (n *FileRangenode) prevSibling() *FileRangenode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *FileRangenode) nextSibling() *FileRangenode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *FileRangenode) rebalanceBeforeInsert(gap FileRangeGapIterator) FileRangeGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < FileRangemaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &FileRangenode{
+ nrSegments: FileRangeminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &FileRangenode{
+ nrSegments: FileRangeminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:FileRangeminDegree-1], n.keys[:FileRangeminDegree-1])
+ copy(left.values[:FileRangeminDegree-1], n.values[:FileRangeminDegree-1])
+ copy(right.keys[:FileRangeminDegree-1], n.keys[FileRangeminDegree:])
+ copy(right.values[:FileRangeminDegree-1], n.values[FileRangeminDegree:])
+ n.keys[0], n.values[0] = n.keys[FileRangeminDegree-1], n.values[FileRangeminDegree-1]
+ FileRangezeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:FileRangeminDegree], n.children[:FileRangeminDegree])
+ copy(right.children[:FileRangeminDegree], n.children[FileRangeminDegree:])
+ FileRangezeroNodeSlice(n.children[2:])
+ for i := 0; i < FileRangeminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < FileRangeminDegree {
+ return FileRangeGapIterator{left, gap.index}
+ }
+ return FileRangeGapIterator{right, gap.index - FileRangeminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[FileRangeminDegree-1], n.values[FileRangeminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &FileRangenode{
+ nrSegments: FileRangeminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:FileRangeminDegree-1], n.keys[FileRangeminDegree:])
+ copy(sibling.values[:FileRangeminDegree-1], n.values[FileRangeminDegree:])
+ FileRangezeroValueSlice(n.values[FileRangeminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:FileRangeminDegree], n.children[FileRangeminDegree:])
+ FileRangezeroNodeSlice(n.children[FileRangeminDegree:])
+ for i := 0; i < FileRangeminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = FileRangeminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < FileRangeminDegree {
+ return gap
+ }
+ return FileRangeGapIterator{sibling, gap.index - FileRangeminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *FileRangenode) rebalanceAfterRemove(gap FileRangeGapIterator) FileRangeGapIterator {
+ for {
+ if n.nrSegments >= FileRangeminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= FileRangeminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ fileRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return FileRangeGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return FileRangeGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= FileRangeminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ fileRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return FileRangeGapIterator{n, n.nrSegments}
+ }
+ return FileRangeGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return FileRangeGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return FileRangeGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *FileRangenode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = FileRangeGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ fileRangeSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type FileRangeIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *FileRangenode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg FileRangeIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg FileRangeIterator) Range() __generics_imported0.MappableRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg FileRangeIterator) Start() uint64 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg FileRangeIterator) End() uint64 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg FileRangeIterator) SetRangeUnchecked(r __generics_imported0.MappableRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg FileRangeIterator) SetRange(r __generics_imported0.MappableRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg FileRangeIterator) SetStartUnchecked(start uint64) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg FileRangeIterator) SetStart(start uint64) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg FileRangeIterator) SetEndUnchecked(end uint64) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg FileRangeIterator) SetEnd(end uint64) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg FileRangeIterator) Value() uint64 {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg FileRangeIterator) ValuePtr() *uint64 {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg FileRangeIterator) SetValue(val uint64) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg FileRangeIterator) PrevSegment() FileRangeIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return FileRangeIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return FileRangeIterator{}
+ }
+ return FileRangesegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg FileRangeIterator) NextSegment() FileRangeIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return FileRangeIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return FileRangeIterator{}
+ }
+ return FileRangesegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg FileRangeIterator) PrevGap() FileRangeGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return FileRangeGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg FileRangeIterator) NextGap() FileRangeGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return FileRangeGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg FileRangeIterator) PrevNonEmpty() (FileRangeIterator, FileRangeGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return FileRangeIterator{}, gap
+ }
+ return gap.PrevSegment(), FileRangeGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg FileRangeIterator) NextNonEmpty() (FileRangeIterator, FileRangeGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return FileRangeIterator{}, gap
+ }
+ return gap.NextSegment(), FileRangeGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type FileRangeGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *FileRangenode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap FileRangeGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap FileRangeGapIterator) Range() __generics_imported0.MappableRange {
+ return __generics_imported0.MappableRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap FileRangeGapIterator) Start() uint64 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return fileRangeSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap FileRangeGapIterator) End() uint64 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return fileRangeSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap FileRangeGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap FileRangeGapIterator) PrevSegment() FileRangeIterator {
+ return FileRangesegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap FileRangeGapIterator) NextSegment() FileRangeIterator {
+ return FileRangesegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap FileRangeGapIterator) PrevGap() FileRangeGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return FileRangeGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap FileRangeGapIterator) NextGap() FileRangeGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return FileRangeGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func FileRangesegmentBeforePosition(n *FileRangenode, i int) FileRangeIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return FileRangeIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return FileRangeIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func FileRangesegmentAfterPosition(n *FileRangenode, i int) FileRangeIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return FileRangeIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return FileRangeIterator{n, i}
+}
+
+func FileRangezeroValueSlice(slice []uint64) {
+
+ for i := range slice {
+ fileRangeSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func FileRangezeroNodeSlice(slice []*FileRangenode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *FileRangeSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *FileRangenode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *FileRangenode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type FileRangeSegmentDataSlices struct {
+ Start []uint64
+ End []uint64
+ Values []uint64
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *FileRangeSet) ExportSortedSlices() *FileRangeSegmentDataSlices {
+ var sds FileRangeSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *FileRangeSet) ImportSortedSlices(sds *FileRangeSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := __generics_imported0.MappableRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *FileRangeSet) saveRoot() *FileRangeSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *FileRangeSet) loadRoot(sds *FileRangeSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go
new file mode 100644
index 000000000..6565c28c8
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/frame_ref_set.go
@@ -0,0 +1,50 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+)
+
+type frameRefSetFunctions struct{}
+
+// MinKey implements segment.Functions.MinKey.
+func (frameRefSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+// MaxKey implements segment.Functions.MaxKey.
+func (frameRefSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+// ClearValue implements segment.Functions.ClearValue.
+func (frameRefSetFunctions) ClearValue(val *uint64) {
+}
+
+// Merge implements segment.Functions.Merge.
+func (frameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) {
+ if val1 != val2 {
+ return 0, false
+ }
+ return val1, true
+}
+
+// Split implements segment.Functions.Split.
+func (frameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) {
+ return val, val
+}
diff --git a/pkg/sentry/fs/fsutil/frame_ref_set_impl.go b/pkg/sentry/fs/fsutil/frame_ref_set_impl.go
new file mode 100755
index 000000000..2f858f419
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/frame_ref_set_impl.go
@@ -0,0 +1,1274 @@
+package fsutil
+
+import (
+ __generics_imported0 "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+)
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ frameRefminDegree = 3
+
+ frameRefmaxDegree = 2 * frameRefminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type frameRefSet struct {
+ root frameRefnode `state:".(*frameRefSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *frameRefSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *frameRefSet) IsEmptyRange(r __generics_imported0.FileRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *frameRefSet) Span() uint64 {
+ var sz uint64
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *frameRefSet) SpanRange(r __generics_imported0.FileRange) uint64 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint64
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *frameRefSet) FirstSegment() frameRefIterator {
+ if s.root.nrSegments == 0 {
+ return frameRefIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *frameRefSet) LastSegment() frameRefIterator {
+ if s.root.nrSegments == 0 {
+ return frameRefIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *frameRefSet) FirstGap() frameRefGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return frameRefGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *frameRefSet) LastGap() frameRefGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return frameRefGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *frameRefSet) Find(key uint64) (frameRefIterator, frameRefGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return frameRefIterator{n, i}, frameRefGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return frameRefIterator{}, frameRefGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *frameRefSet) FindSegment(key uint64) frameRefIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *frameRefSet) LowerBoundSegment(min uint64) frameRefIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *frameRefSet) UpperBoundSegment(max uint64) frameRefIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *frameRefSet) FindGap(key uint64) frameRefGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *frameRefSet) LowerBoundGap(min uint64) frameRefGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *frameRefSet) UpperBoundGap(max uint64) frameRefGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *frameRefSet) Add(r __generics_imported0.FileRange, val uint64) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *frameRefSet) AddWithoutMerging(r __generics_imported0.FileRange, val uint64) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *frameRefSet) Insert(gap frameRefGapIterator, r __generics_imported0.FileRange, val uint64) frameRefIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (frameRefSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (frameRefSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (frameRefSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *frameRefSet) InsertWithoutMerging(gap frameRefGapIterator, r __generics_imported0.FileRange, val uint64) frameRefIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *frameRefSet) InsertWithoutMergingUnchecked(gap frameRefGapIterator, r __generics_imported0.FileRange, val uint64) frameRefIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return frameRefIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *frameRefSet) Remove(seg frameRefIterator) frameRefGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ frameRefSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(frameRefGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *frameRefSet) RemoveAll() {
+ s.root = frameRefnode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *frameRefSet) RemoveRange(r __generics_imported0.FileRange) frameRefGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *frameRefSet) Merge(first, second frameRefIterator) frameRefIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *frameRefSet) MergeUnchecked(first, second frameRefIterator) frameRefIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (frameRefSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return frameRefIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *frameRefSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *frameRefSet) MergeRange(r __generics_imported0.FileRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *frameRefSet) MergeAdjacent(r __generics_imported0.FileRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *frameRefSet) Split(seg frameRefIterator, split uint64) (frameRefIterator, frameRefIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *frameRefSet) SplitUnchecked(seg frameRefIterator, split uint64) (frameRefIterator, frameRefIterator) {
+ val1, val2 := (frameRefSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *frameRefSet) SplitAt(split uint64) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *frameRefSet) Isolate(seg frameRefIterator, r __generics_imported0.FileRange) frameRefIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *frameRefSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg frameRefIterator)) frameRefGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return frameRefGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return frameRefGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type frameRefnode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *frameRefnode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [frameRefmaxDegree - 1]__generics_imported0.FileRange
+ values [frameRefmaxDegree - 1]uint64
+ children [frameRefmaxDegree]*frameRefnode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *frameRefnode) firstSegment() frameRefIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return frameRefIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *frameRefnode) lastSegment() frameRefIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return frameRefIterator{n, n.nrSegments - 1}
+}
+
+func (n *frameRefnode) prevSibling() *frameRefnode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *frameRefnode) nextSibling() *frameRefnode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *frameRefnode) rebalanceBeforeInsert(gap frameRefGapIterator) frameRefGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < frameRefmaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &frameRefnode{
+ nrSegments: frameRefminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &frameRefnode{
+ nrSegments: frameRefminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:frameRefminDegree-1], n.keys[:frameRefminDegree-1])
+ copy(left.values[:frameRefminDegree-1], n.values[:frameRefminDegree-1])
+ copy(right.keys[:frameRefminDegree-1], n.keys[frameRefminDegree:])
+ copy(right.values[:frameRefminDegree-1], n.values[frameRefminDegree:])
+ n.keys[0], n.values[0] = n.keys[frameRefminDegree-1], n.values[frameRefminDegree-1]
+ frameRefzeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:frameRefminDegree], n.children[:frameRefminDegree])
+ copy(right.children[:frameRefminDegree], n.children[frameRefminDegree:])
+ frameRefzeroNodeSlice(n.children[2:])
+ for i := 0; i < frameRefminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < frameRefminDegree {
+ return frameRefGapIterator{left, gap.index}
+ }
+ return frameRefGapIterator{right, gap.index - frameRefminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[frameRefminDegree-1], n.values[frameRefminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &frameRefnode{
+ nrSegments: frameRefminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:frameRefminDegree-1], n.keys[frameRefminDegree:])
+ copy(sibling.values[:frameRefminDegree-1], n.values[frameRefminDegree:])
+ frameRefzeroValueSlice(n.values[frameRefminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:frameRefminDegree], n.children[frameRefminDegree:])
+ frameRefzeroNodeSlice(n.children[frameRefminDegree:])
+ for i := 0; i < frameRefminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = frameRefminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < frameRefminDegree {
+ return gap
+ }
+ return frameRefGapIterator{sibling, gap.index - frameRefminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *frameRefnode) rebalanceAfterRemove(gap frameRefGapIterator) frameRefGapIterator {
+ for {
+ if n.nrSegments >= frameRefminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= frameRefminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ frameRefSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return frameRefGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return frameRefGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= frameRefminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ frameRefSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return frameRefGapIterator{n, n.nrSegments}
+ }
+ return frameRefGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return frameRefGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return frameRefGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *frameRefnode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = frameRefGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ frameRefSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type frameRefIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *frameRefnode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg frameRefIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg frameRefIterator) Range() __generics_imported0.FileRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg frameRefIterator) Start() uint64 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg frameRefIterator) End() uint64 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg frameRefIterator) SetRangeUnchecked(r __generics_imported0.FileRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg frameRefIterator) SetRange(r __generics_imported0.FileRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg frameRefIterator) SetStartUnchecked(start uint64) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg frameRefIterator) SetStart(start uint64) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg frameRefIterator) SetEndUnchecked(end uint64) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg frameRefIterator) SetEnd(end uint64) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg frameRefIterator) Value() uint64 {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg frameRefIterator) ValuePtr() *uint64 {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg frameRefIterator) SetValue(val uint64) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg frameRefIterator) PrevSegment() frameRefIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return frameRefIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return frameRefIterator{}
+ }
+ return frameRefsegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg frameRefIterator) NextSegment() frameRefIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return frameRefIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return frameRefIterator{}
+ }
+ return frameRefsegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg frameRefIterator) PrevGap() frameRefGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return frameRefGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg frameRefIterator) NextGap() frameRefGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return frameRefGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg frameRefIterator) PrevNonEmpty() (frameRefIterator, frameRefGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return frameRefIterator{}, gap
+ }
+ return gap.PrevSegment(), frameRefGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg frameRefIterator) NextNonEmpty() (frameRefIterator, frameRefGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return frameRefIterator{}, gap
+ }
+ return gap.NextSegment(), frameRefGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type frameRefGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *frameRefnode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap frameRefGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap frameRefGapIterator) Range() __generics_imported0.FileRange {
+ return __generics_imported0.FileRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap frameRefGapIterator) Start() uint64 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return frameRefSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap frameRefGapIterator) End() uint64 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return frameRefSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap frameRefGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap frameRefGapIterator) PrevSegment() frameRefIterator {
+ return frameRefsegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap frameRefGapIterator) NextSegment() frameRefIterator {
+ return frameRefsegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap frameRefGapIterator) PrevGap() frameRefGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return frameRefGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap frameRefGapIterator) NextGap() frameRefGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return frameRefGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func frameRefsegmentBeforePosition(n *frameRefnode, i int) frameRefIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return frameRefIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return frameRefIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func frameRefsegmentAfterPosition(n *frameRefnode, i int) frameRefIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return frameRefIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return frameRefIterator{n, i}
+}
+
+func frameRefzeroValueSlice(slice []uint64) {
+
+ for i := range slice {
+ frameRefSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func frameRefzeroNodeSlice(slice []*frameRefnode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *frameRefSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *frameRefnode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *frameRefnode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type frameRefSegmentDataSlices struct {
+ Start []uint64
+ End []uint64
+ Values []uint64
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *frameRefSet) ExportSortedSlices() *frameRefSegmentDataSlices {
+ var sds frameRefSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *frameRefSet) ImportSortedSlices(sds *frameRefSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *frameRefSet) saveRoot() *frameRefSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *frameRefSet) loadRoot(sds *frameRefSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/fs/fsutil/fsutil.go b/pkg/sentry/fs/fsutil/fsutil.go
new file mode 100644
index 000000000..c9587b1d9
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/fsutil.go
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsutil provides utilities for implementing fs.InodeOperations
+// and fs.FileOperations:
+//
+// - For embeddable utilities, see inode.go and file.go.
+//
+// - For fs.Inodes that require a page cache to be memory mapped, see
+// inode_cache.go.
+//
+// - For anon fs.Inodes, see anon.go.
+package fsutil
diff --git a/pkg/sentry/fs/fsutil/fsutil_state_autogen.go b/pkg/sentry/fs/fsutil/fsutil_state_autogen.go
new file mode 100755
index 000000000..5783b151d
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/fsutil_state_autogen.go
@@ -0,0 +1,349 @@
+// automatically generated by stateify.
+
+package fsutil
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *DirtyInfo) beforeSave() {}
+func (x *DirtyInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Keep", &x.Keep)
+}
+
+func (x *DirtyInfo) afterLoad() {}
+func (x *DirtyInfo) load(m state.Map) {
+ m.Load("Keep", &x.Keep)
+}
+
+func (x *DirtySet) beforeSave() {}
+func (x *DirtySet) save(m state.Map) {
+ x.beforeSave()
+ var root *DirtySegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *DirtySet) afterLoad() {}
+func (x *DirtySet) load(m state.Map) {
+ m.LoadValue("root", new(*DirtySegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*DirtySegmentDataSlices)) })
+}
+
+func (x *Dirtynode) beforeSave() {}
+func (x *Dirtynode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *Dirtynode) afterLoad() {}
+func (x *Dirtynode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *DirtySegmentDataSlices) beforeSave() {}
+func (x *DirtySegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *DirtySegmentDataSlices) afterLoad() {}
+func (x *DirtySegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func (x *StaticDirFileOperations) beforeSave() {}
+func (x *StaticDirFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("dentryMap", &x.dentryMap)
+ m.Save("dirCursor", &x.dirCursor)
+}
+
+func (x *StaticDirFileOperations) afterLoad() {}
+func (x *StaticDirFileOperations) load(m state.Map) {
+ m.Load("dentryMap", &x.dentryMap)
+ m.Load("dirCursor", &x.dirCursor)
+}
+
+func (x *NoReadWriteFile) beforeSave() {}
+func (x *NoReadWriteFile) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *NoReadWriteFile) afterLoad() {}
+func (x *NoReadWriteFile) load(m state.Map) {
+}
+
+func (x *FileStaticContentReader) beforeSave() {}
+func (x *FileStaticContentReader) save(m state.Map) {
+ x.beforeSave()
+ m.Save("content", &x.content)
+}
+
+func (x *FileStaticContentReader) afterLoad() {}
+func (x *FileStaticContentReader) load(m state.Map) {
+ m.Load("content", &x.content)
+}
+
+func (x *FileRangeSet) beforeSave() {}
+func (x *FileRangeSet) save(m state.Map) {
+ x.beforeSave()
+ var root *FileRangeSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *FileRangeSet) afterLoad() {}
+func (x *FileRangeSet) load(m state.Map) {
+ m.LoadValue("root", new(*FileRangeSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*FileRangeSegmentDataSlices)) })
+}
+
+func (x *FileRangenode) beforeSave() {}
+func (x *FileRangenode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *FileRangenode) afterLoad() {}
+func (x *FileRangenode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *FileRangeSegmentDataSlices) beforeSave() {}
+func (x *FileRangeSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *FileRangeSegmentDataSlices) afterLoad() {}
+func (x *FileRangeSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func (x *frameRefSet) beforeSave() {}
+func (x *frameRefSet) save(m state.Map) {
+ x.beforeSave()
+ var root *frameRefSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *frameRefSet) afterLoad() {}
+func (x *frameRefSet) load(m state.Map) {
+ m.LoadValue("root", new(*frameRefSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*frameRefSegmentDataSlices)) })
+}
+
+func (x *frameRefnode) beforeSave() {}
+func (x *frameRefnode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *frameRefnode) afterLoad() {}
+func (x *frameRefnode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *frameRefSegmentDataSlices) beforeSave() {}
+func (x *frameRefSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *frameRefSegmentDataSlices) afterLoad() {}
+func (x *frameRefSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func (x *HostFileMapper) beforeSave() {}
+func (x *HostFileMapper) save(m state.Map) {
+ x.beforeSave()
+ m.Save("refs", &x.refs)
+}
+
+func (x *HostFileMapper) load(m state.Map) {
+ m.Load("refs", &x.refs)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *HostMappable) beforeSave() {}
+func (x *HostMappable) save(m state.Map) {
+ x.beforeSave()
+ m.Save("hostFileMapper", &x.hostFileMapper)
+ m.Save("backingFile", &x.backingFile)
+ m.Save("mappings", &x.mappings)
+}
+
+func (x *HostMappable) afterLoad() {}
+func (x *HostMappable) load(m state.Map) {
+ m.Load("hostFileMapper", &x.hostFileMapper)
+ m.Load("backingFile", &x.backingFile)
+ m.Load("mappings", &x.mappings)
+}
+
+func (x *SimpleFileInode) beforeSave() {}
+func (x *SimpleFileInode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *SimpleFileInode) afterLoad() {}
+func (x *SimpleFileInode) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *NoReadWriteFileInode) beforeSave() {}
+func (x *NoReadWriteFileInode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *NoReadWriteFileInode) afterLoad() {}
+func (x *NoReadWriteFileInode) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+}
+
+func (x *InodeSimpleAttributes) beforeSave() {}
+func (x *InodeSimpleAttributes) save(m state.Map) {
+ x.beforeSave()
+ m.Save("fsType", &x.fsType)
+ m.Save("unstable", &x.unstable)
+}
+
+func (x *InodeSimpleAttributes) afterLoad() {}
+func (x *InodeSimpleAttributes) load(m state.Map) {
+ m.Load("fsType", &x.fsType)
+ m.Load("unstable", &x.unstable)
+}
+
+func (x *InodeSimpleExtendedAttributes) beforeSave() {}
+func (x *InodeSimpleExtendedAttributes) save(m state.Map) {
+ x.beforeSave()
+ m.Save("xattrs", &x.xattrs)
+}
+
+func (x *InodeSimpleExtendedAttributes) afterLoad() {}
+func (x *InodeSimpleExtendedAttributes) load(m state.Map) {
+ m.Load("xattrs", &x.xattrs)
+}
+
+func (x *staticFile) beforeSave() {}
+func (x *staticFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("FileStaticContentReader", &x.FileStaticContentReader)
+}
+
+func (x *staticFile) afterLoad() {}
+func (x *staticFile) load(m state.Map) {
+ m.Load("FileStaticContentReader", &x.FileStaticContentReader)
+}
+
+func (x *InodeStaticFileGetter) beforeSave() {}
+func (x *InodeStaticFileGetter) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Contents", &x.Contents)
+}
+
+func (x *InodeStaticFileGetter) afterLoad() {}
+func (x *InodeStaticFileGetter) load(m state.Map) {
+ m.Load("Contents", &x.Contents)
+}
+
+func (x *CachingInodeOperations) beforeSave() {}
+func (x *CachingInodeOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("backingFile", &x.backingFile)
+ m.Save("mfp", &x.mfp)
+ m.Save("forcePageCache", &x.forcePageCache)
+ m.Save("attr", &x.attr)
+ m.Save("dirtyAttr", &x.dirtyAttr)
+ m.Save("mappings", &x.mappings)
+ m.Save("cache", &x.cache)
+ m.Save("dirty", &x.dirty)
+ m.Save("hostFileMapper", &x.hostFileMapper)
+ m.Save("refs", &x.refs)
+}
+
+func (x *CachingInodeOperations) afterLoad() {}
+func (x *CachingInodeOperations) load(m state.Map) {
+ m.Load("backingFile", &x.backingFile)
+ m.Load("mfp", &x.mfp)
+ m.Load("forcePageCache", &x.forcePageCache)
+ m.Load("attr", &x.attr)
+ m.Load("dirtyAttr", &x.dirtyAttr)
+ m.Load("mappings", &x.mappings)
+ m.Load("cache", &x.cache)
+ m.Load("dirty", &x.dirty)
+ m.Load("hostFileMapper", &x.hostFileMapper)
+ m.Load("refs", &x.refs)
+}
+
+func init() {
+ state.Register("fsutil.DirtyInfo", (*DirtyInfo)(nil), state.Fns{Save: (*DirtyInfo).save, Load: (*DirtyInfo).load})
+ state.Register("fsutil.DirtySet", (*DirtySet)(nil), state.Fns{Save: (*DirtySet).save, Load: (*DirtySet).load})
+ state.Register("fsutil.Dirtynode", (*Dirtynode)(nil), state.Fns{Save: (*Dirtynode).save, Load: (*Dirtynode).load})
+ state.Register("fsutil.DirtySegmentDataSlices", (*DirtySegmentDataSlices)(nil), state.Fns{Save: (*DirtySegmentDataSlices).save, Load: (*DirtySegmentDataSlices).load})
+ state.Register("fsutil.StaticDirFileOperations", (*StaticDirFileOperations)(nil), state.Fns{Save: (*StaticDirFileOperations).save, Load: (*StaticDirFileOperations).load})
+ state.Register("fsutil.NoReadWriteFile", (*NoReadWriteFile)(nil), state.Fns{Save: (*NoReadWriteFile).save, Load: (*NoReadWriteFile).load})
+ state.Register("fsutil.FileStaticContentReader", (*FileStaticContentReader)(nil), state.Fns{Save: (*FileStaticContentReader).save, Load: (*FileStaticContentReader).load})
+ state.Register("fsutil.FileRangeSet", (*FileRangeSet)(nil), state.Fns{Save: (*FileRangeSet).save, Load: (*FileRangeSet).load})
+ state.Register("fsutil.FileRangenode", (*FileRangenode)(nil), state.Fns{Save: (*FileRangenode).save, Load: (*FileRangenode).load})
+ state.Register("fsutil.FileRangeSegmentDataSlices", (*FileRangeSegmentDataSlices)(nil), state.Fns{Save: (*FileRangeSegmentDataSlices).save, Load: (*FileRangeSegmentDataSlices).load})
+ state.Register("fsutil.frameRefSet", (*frameRefSet)(nil), state.Fns{Save: (*frameRefSet).save, Load: (*frameRefSet).load})
+ state.Register("fsutil.frameRefnode", (*frameRefnode)(nil), state.Fns{Save: (*frameRefnode).save, Load: (*frameRefnode).load})
+ state.Register("fsutil.frameRefSegmentDataSlices", (*frameRefSegmentDataSlices)(nil), state.Fns{Save: (*frameRefSegmentDataSlices).save, Load: (*frameRefSegmentDataSlices).load})
+ state.Register("fsutil.HostFileMapper", (*HostFileMapper)(nil), state.Fns{Save: (*HostFileMapper).save, Load: (*HostFileMapper).load})
+ state.Register("fsutil.HostMappable", (*HostMappable)(nil), state.Fns{Save: (*HostMappable).save, Load: (*HostMappable).load})
+ state.Register("fsutil.SimpleFileInode", (*SimpleFileInode)(nil), state.Fns{Save: (*SimpleFileInode).save, Load: (*SimpleFileInode).load})
+ state.Register("fsutil.NoReadWriteFileInode", (*NoReadWriteFileInode)(nil), state.Fns{Save: (*NoReadWriteFileInode).save, Load: (*NoReadWriteFileInode).load})
+ state.Register("fsutil.InodeSimpleAttributes", (*InodeSimpleAttributes)(nil), state.Fns{Save: (*InodeSimpleAttributes).save, Load: (*InodeSimpleAttributes).load})
+ state.Register("fsutil.InodeSimpleExtendedAttributes", (*InodeSimpleExtendedAttributes)(nil), state.Fns{Save: (*InodeSimpleExtendedAttributes).save, Load: (*InodeSimpleExtendedAttributes).load})
+ state.Register("fsutil.staticFile", (*staticFile)(nil), state.Fns{Save: (*staticFile).save, Load: (*staticFile).load})
+ state.Register("fsutil.InodeStaticFileGetter", (*InodeStaticFileGetter)(nil), state.Fns{Save: (*InodeStaticFileGetter).save, Load: (*InodeStaticFileGetter).load})
+ state.Register("fsutil.CachingInodeOperations", (*CachingInodeOperations)(nil), state.Fns{Save: (*CachingInodeOperations).save, Load: (*CachingInodeOperations).load})
+}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
new file mode 100644
index 000000000..2bdfc0db6
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -0,0 +1,211 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "fmt"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// HostFileMapper caches mappings of an arbitrary host file descriptor. It is
+// used by implementations of memmap.Mappable that represent a host file
+// descriptor.
+//
+// +stateify savable
+type HostFileMapper struct {
+ // HostFile conceptually breaks the file into pieces called chunks, of
+ // size and alignment chunkSize, and caches mappings of the file on a chunk
+ // granularity.
+
+ refsMu sync.Mutex `state:"nosave"`
+
+ // refs maps chunk start offsets to the sum of reference counts for all
+ // pages in that chunk. refs is protected by refsMu.
+ refs map[uint64]int32
+
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings maps chunk start offsets to mappings of those chunks,
+ // obtained by calling syscall.Mmap. mappings is protected by
+ // mapsMu.
+ mappings map[uint64]mapping `state:"nosave"`
+}
+
+const (
+ chunkShift = usermem.HugePageShift
+ chunkSize = 1 << chunkShift
+ chunkMask = chunkSize - 1
+)
+
+func pagesInChunk(mr memmap.MappableRange, chunkStart uint64) int32 {
+ return int32(mr.Intersect(memmap.MappableRange{chunkStart, chunkStart + chunkSize}).Length() / usermem.PageSize)
+}
+
+type mapping struct {
+ addr uintptr
+ writable bool
+}
+
+// NewHostFileMapper returns a HostFileMapper with no references or cached
+// mappings.
+func NewHostFileMapper() *HostFileMapper {
+ return &HostFileMapper{
+ refs: make(map[uint64]int32),
+ mappings: make(map[uint64]mapping),
+ }
+}
+
+// IncRefOn increments the reference count on all offsets in mr.
+//
+// Preconditions: mr.Length() != 0. mr.Start and mr.End must be page-aligned.
+func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) {
+ f.refsMu.Lock()
+ defer f.refsMu.Unlock()
+ for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize {
+ refs := f.refs[chunkStart]
+ pgs := pagesInChunk(mr, chunkStart)
+ if refs+pgs < refs {
+ // Would overflow.
+ panic(fmt.Sprintf("HostFileMapper.IncRefOn(%v): adding %d page references to chunk %#x, which has %d page references", mr, pgs, chunkStart, refs))
+ }
+ f.refs[chunkStart] = refs + pgs
+ }
+}
+
+// DecRefOn decrements the reference count on all offsets in mr.
+//
+// Preconditions: mr.Length() != 0. mr.Start and mr.End must be page-aligned.
+func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
+ f.refsMu.Lock()
+ defer f.refsMu.Unlock()
+ for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize {
+ refs := f.refs[chunkStart]
+ pgs := pagesInChunk(mr, chunkStart)
+ switch {
+ case refs > pgs:
+ f.refs[chunkStart] = refs - pgs
+ case refs == pgs:
+ f.mapsMu.Lock()
+ delete(f.refs, chunkStart)
+ if m, ok := f.mappings[chunkStart]; ok {
+ f.unmapAndRemoveLocked(chunkStart, m)
+ }
+ f.mapsMu.Unlock()
+ case refs < pgs:
+ panic(fmt.Sprintf("HostFileMapper.DecRefOn(%v): removing %d page references from chunk %#x, which has %d page references", mr, pgs, chunkStart, refs))
+ }
+ }
+}
+
+// MapInternal returns a mapping of offsets in fr from fd. The returned
+// safemem.BlockSeq is valid as long as at least one reference is held on all
+// offsets in fr or until the next call to UnmapAll.
+//
+// Preconditions: The caller must hold a reference on all offsets in fr.
+func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
+ chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift)
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+ if chunks == 1 {
+ // Avoid an unnecessary slice allocation.
+ var seq safemem.BlockSeq
+ err := f.forEachMappingBlockLocked(fr, fd, write, func(b safemem.Block) {
+ seq = safemem.BlockSeqOf(b)
+ })
+ return seq, err
+ }
+ blocks := make([]safemem.Block, 0, chunks)
+ err := f.forEachMappingBlockLocked(fr, fd, write, func(b safemem.Block) {
+ blocks = append(blocks, b)
+ })
+ return safemem.BlockSeqFromSlice(blocks), err
+}
+
+// Preconditions: f.mapsMu must be locked.
+func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error {
+ prot := syscall.PROT_READ
+ if write {
+ prot |= syscall.PROT_WRITE
+ }
+ for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
+ m, ok := f.mappings[chunkStart]
+ if !ok {
+ addr, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ 0,
+ chunkSize,
+ uintptr(prot),
+ syscall.MAP_SHARED,
+ uintptr(fd),
+ uintptr(chunkStart))
+ if errno != 0 {
+ return errno
+ }
+ m = mapping{addr, write}
+ f.mappings[chunkStart] = m
+ } else if write && !m.writable {
+ addr, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ m.addr,
+ chunkSize,
+ uintptr(prot),
+ syscall.MAP_SHARED|syscall.MAP_FIXED,
+ uintptr(fd),
+ uintptr(chunkStart))
+ if errno != 0 {
+ return errno
+ }
+ m = mapping{addr, write}
+ f.mappings[chunkStart] = m
+ }
+ var startOff uint64
+ if chunkStart < fr.Start {
+ startOff = fr.Start - chunkStart
+ }
+ endOff := uint64(chunkSize)
+ if chunkStart+chunkSize > fr.End {
+ endOff = fr.End - chunkStart
+ }
+ fn(f.unsafeBlockFromChunkMapping(m.addr).TakeFirst64(endOff).DropFirst64(startOff))
+ }
+ return nil
+}
+
+// UnmapAll unmaps all cached mappings. Callers are responsible for
+// synchronization with mappings returned by previous calls to MapInternal.
+func (f *HostFileMapper) UnmapAll() {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+ for chunkStart, m := range f.mappings {
+ f.unmapAndRemoveLocked(chunkStart, m)
+ }
+}
+
+// Preconditions: f.mapsMu must be locked. f.mappings[chunkStart] == m.
+func (f *HostFileMapper) unmapAndRemoveLocked(chunkStart uint64, m mapping) {
+ if _, _, errno := syscall.Syscall(syscall.SYS_MUNMAP, m.addr, chunkSize, 0); errno != 0 {
+ // This leaks address space and is unexpected, but is otherwise
+ // harmless, so complain but don't panic.
+ log.Warningf("HostFileMapper: failed to unmap mapping %#x for chunk %#x: %v", m.addr, chunkStart, errno)
+ }
+ delete(f.mappings, chunkStart)
+}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper_state.go b/pkg/sentry/fs/fsutil/host_file_mapper_state.go
new file mode 100644
index 000000000..576d2a3df
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/host_file_mapper_state.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+// afterLoad is invoked by stateify.
+func (f *HostFileMapper) afterLoad() {
+ f.mappings = make(map[uint64]mapping)
+}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go b/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go
new file mode 100644
index 000000000..7167be263
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+)
+
+func (*HostFileMapper) unsafeBlockFromChunkMapping(addr uintptr) safemem.Block {
+ // We don't control the host file's length, so touching its mappings may
+ // raise SIGBUS. Thus accesses to it must use safecopy.
+ return safemem.BlockFromUnsafePointer((unsafe.Pointer)(addr), chunkSize)
+}
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
new file mode 100644
index 000000000..ad0518b8f
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -0,0 +1,197 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "math"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// HostMappable implements memmap.Mappable and platform.File over a
+// CachedFileObject.
+//
+// Lock order (compare the lock order model in mm/mm.go):
+// truncateMu ("fs locks")
+// mu ("memmap.Mappable locks not taken by Translate")
+// ("platform.File locks")
+// backingFile ("CachedFileObject locks")
+//
+// +stateify savable
+type HostMappable struct {
+ hostFileMapper *HostFileMapper
+
+ backingFile CachedFileObject
+
+ mu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the cached file object into
+ // memmap.MappingSpaces so it can invalidated upon save. Protected by mu.
+ mappings memmap.MappingSet
+
+ // truncateMu protects writes and truncations. See Truncate() for details.
+ truncateMu sync.RWMutex `state:"nosave"`
+}
+
+// NewHostMappable creates a new mappable that maps directly to host FD.
+func NewHostMappable(backingFile CachedFileObject) *HostMappable {
+ return &HostMappable{
+ hostFileMapper: NewHostFileMapper(),
+ backingFile: backingFile,
+ }
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (h *HostMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ // Hot path. Avoid defers.
+ h.mu.Lock()
+ mapped := h.mappings.AddMapping(ms, ar, offset, writable)
+ for _, r := range mapped {
+ h.hostFileMapper.IncRefOn(r)
+ }
+ h.mu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (h *HostMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ // Hot path. Avoid defers.
+ h.mu.Lock()
+ unmapped := h.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ h.hostFileMapper.DecRefOn(r)
+ }
+ h.mu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (h *HostMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return h.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (h *HostMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ return []memmap.Translation{
+ {
+ Source: optional,
+ File: h,
+ Offset: optional.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (h *HostMappable) InvalidateUnsavable(ctx context.Context) error {
+ h.mu.Lock()
+ h.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ h.mu.Unlock()
+ return nil
+}
+
+// MapInternal implements platform.File.MapInternal.
+func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write)
+}
+
+// FD implements platform.File.FD.
+func (h *HostMappable) FD() int {
+ return h.backingFile.FD()
+}
+
+// IncRef implements platform.File.IncRef.
+func (h *HostMappable) IncRef(fr platform.FileRange) {
+ mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
+ h.hostFileMapper.IncRefOn(mr)
+}
+
+// DecRef implements platform.File.DecRef.
+func (h *HostMappable) DecRef(fr platform.FileRange) {
+ mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
+ h.hostFileMapper.DecRefOn(mr)
+}
+
+// Truncate truncates the file, invalidating any mapping that may have been
+// removed after the size change.
+//
+// Truncation and writes are synchronized to prevent races where writes make the
+// file grow between truncation and invalidation below:
+// T1: Calls SetMaskedAttributes and stalls
+// T2: Appends to file causing it to grow
+// T2: Writes to mapped pages and COW happens
+// T1: Continues and wronly invalidates the page mapped in step above.
+func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error {
+ h.truncateMu.Lock()
+ defer h.truncateMu.Unlock()
+
+ mask := fs.AttrMask{Size: true}
+ attr := fs.UnstableAttr{Size: newSize}
+ if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr); err != nil {
+ return err
+ }
+
+ // Invalidate COW mappings that may exist beyond the new size in case the file
+ // is being shrunk. Other mappings don't need to be invalidated because
+ // translate will just return identical mappings after invalidation anyway,
+ // and SIGBUS will be raised and handled when the mappings are touched.
+ //
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ mr := memmap.MappableRange{
+ Start: fs.OffsetPageEnd(newSize),
+ End: fs.OffsetPageEnd(math.MaxInt64),
+ }
+ h.mappings.Invalidate(mr, memmap.InvalidateOpts{InvalidatePrivate: true})
+
+ return nil
+}
+
+// Allocate reserves space in the backing file.
+func (h *HostMappable) Allocate(ctx context.Context, offset int64, length int64) error {
+ h.truncateMu.RLock()
+ err := h.backingFile.Allocate(ctx, offset, length)
+ h.truncateMu.RUnlock()
+ return err
+}
+
+// Write writes to the file backing this mappable.
+func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ h.truncateMu.RLock()
+ n, err := src.CopyInTo(ctx, &writer{ctx: ctx, hostMappable: h, off: offset})
+ h.truncateMu.RUnlock()
+ return n, err
+}
+
+type writer struct {
+ ctx context.Context
+ hostMappable *HostMappable
+ off int64
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (w *writer) WriteFromBlocks(src safemem.BlockSeq) (uint64, error) {
+ n, err := w.hostMappable.backingFile.WriteFromBlocksAt(w.ctx, src, uint64(w.off))
+ w.off += int64(n)
+ return n, err
+}
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
new file mode 100644
index 000000000..925887335
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -0,0 +1,503 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// SimpleFileInode is a simple implementation of InodeOperations.
+//
+// +stateify savable
+type SimpleFileInode struct {
+ InodeGenericChecker `state:"nosave"`
+ InodeNoExtendedAttributes `state:"nosave"`
+ InodeNoopRelease `state:"nosave"`
+ InodeNoopWriteOut `state:"nosave"`
+ InodeNotAllocatable `state:"nosave"`
+ InodeNotDirectory `state:"nosave"`
+ InodeNotMappable `state:"nosave"`
+ InodeNotOpenable `state:"nosave"`
+ InodeNotSocket `state:"nosave"`
+ InodeNotSymlink `state:"nosave"`
+ InodeNotTruncatable `state:"nosave"`
+ InodeNotVirtual `state:"nosave"`
+
+ InodeSimpleAttributes
+}
+
+// NewSimpleFileInode returns a new SimpleFileInode.
+func NewSimpleFileInode(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, typ uint64) *SimpleFileInode {
+ return &SimpleFileInode{
+ InodeSimpleAttributes: NewInodeSimpleAttributes(ctx, owner, perms, typ),
+ }
+}
+
+// NoReadWriteFileInode is an implementation of InodeOperations that supports
+// opening files that are not readable or writeable.
+//
+// +stateify savable
+type NoReadWriteFileInode struct {
+ InodeGenericChecker `state:"nosave"`
+ InodeNoExtendedAttributes `state:"nosave"`
+ InodeNoopRelease `state:"nosave"`
+ InodeNoopWriteOut `state:"nosave"`
+ InodeNotAllocatable `state:"nosave"`
+ InodeNotDirectory `state:"nosave"`
+ InodeNotMappable `state:"nosave"`
+ InodeNotSocket `state:"nosave"`
+ InodeNotSymlink `state:"nosave"`
+ InodeNotTruncatable `state:"nosave"`
+ InodeNotVirtual `state:"nosave"`
+
+ InodeSimpleAttributes
+}
+
+// NewNoReadWriteFileInode returns a new NoReadWriteFileInode.
+func NewNoReadWriteFileInode(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, typ uint64) *NoReadWriteFileInode {
+ return &NoReadWriteFileInode{
+ InodeSimpleAttributes: NewInodeSimpleAttributes(ctx, owner, perms, typ),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (*NoReadWriteFileInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &NoReadWriteFile{}), nil
+}
+
+// InodeSimpleAttributes implements methods for updating in-memory unstable
+// attributes.
+//
+// +stateify savable
+type InodeSimpleAttributes struct {
+ // fsType is the immutable filesystem type that will be returned by
+ // StatFS.
+ fsType uint64
+
+ // mu protects unstable.
+ mu sync.RWMutex `state:"nosave"`
+ unstable fs.UnstableAttr
+}
+
+// NewInodeSimpleAttributes returns a new InodeSimpleAttributes with the given
+// owner and permissions, and all timestamps set to the current time.
+func NewInodeSimpleAttributes(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, typ uint64) InodeSimpleAttributes {
+ return NewInodeSimpleAttributesWithUnstable(fs.WithCurrentTime(ctx, fs.UnstableAttr{
+ Owner: owner,
+ Perms: perms,
+ }), typ)
+}
+
+// NewInodeSimpleAttributesWithUnstable returns a new InodeSimpleAttributes
+// with the given unstable attributes.
+func NewInodeSimpleAttributesWithUnstable(uattr fs.UnstableAttr, typ uint64) InodeSimpleAttributes {
+ return InodeSimpleAttributes{
+ fsType: typ,
+ unstable: uattr,
+ }
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (i *InodeSimpleAttributes) UnstableAttr(ctx context.Context, _ *fs.Inode) (fs.UnstableAttr, error) {
+ i.mu.RLock()
+ u := i.unstable
+ i.mu.RUnlock()
+ return u, nil
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (i *InodeSimpleAttributes) SetPermissions(ctx context.Context, _ *fs.Inode, p fs.FilePermissions) bool {
+ i.mu.Lock()
+ i.unstable.SetPermissions(ctx, p)
+ i.mu.Unlock()
+ return true
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (i *InodeSimpleAttributes) SetOwner(ctx context.Context, _ *fs.Inode, owner fs.FileOwner) error {
+ i.mu.Lock()
+ i.unstable.SetOwner(ctx, owner)
+ i.mu.Unlock()
+ return nil
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (i *InodeSimpleAttributes) SetTimestamps(ctx context.Context, _ *fs.Inode, ts fs.TimeSpec) error {
+ i.mu.Lock()
+ i.unstable.SetTimestamps(ctx, ts)
+ i.mu.Unlock()
+ return nil
+}
+
+// AddLink implements fs.InodeOperations.AddLink.
+func (i *InodeSimpleAttributes) AddLink() {
+ i.mu.Lock()
+ i.unstable.Links++
+ i.mu.Unlock()
+}
+
+// DropLink implements fs.InodeOperations.DropLink.
+func (i *InodeSimpleAttributes) DropLink() {
+ i.mu.Lock()
+ i.unstable.Links--
+ i.mu.Unlock()
+}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (i *InodeSimpleAttributes) StatFS(context.Context) (fs.Info, error) {
+ if i.fsType == 0 {
+ return fs.Info{}, syserror.ENOSYS
+ }
+ return fs.Info{Type: i.fsType}, nil
+}
+
+// NotifyAccess updates the access time.
+func (i *InodeSimpleAttributes) NotifyAccess(ctx context.Context) {
+ i.mu.Lock()
+ i.unstable.AccessTime = ktime.NowFromContext(ctx)
+ i.mu.Unlock()
+}
+
+// NotifyModification updates the modification time.
+func (i *InodeSimpleAttributes) NotifyModification(ctx context.Context) {
+ i.mu.Lock()
+ i.unstable.ModificationTime = ktime.NowFromContext(ctx)
+ i.mu.Unlock()
+}
+
+// NotifyStatusChange updates the status change time.
+func (i *InodeSimpleAttributes) NotifyStatusChange(ctx context.Context) {
+ i.mu.Lock()
+ i.unstable.StatusChangeTime = ktime.NowFromContext(ctx)
+ i.mu.Unlock()
+}
+
+// NotifyModificationAndStatusChange updates the modification and status change
+// times.
+func (i *InodeSimpleAttributes) NotifyModificationAndStatusChange(ctx context.Context) {
+ i.mu.Lock()
+ now := ktime.NowFromContext(ctx)
+ i.unstable.ModificationTime = now
+ i.unstable.StatusChangeTime = now
+ i.mu.Unlock()
+}
+
+// InodeSimpleExtendedAttributes implements
+// fs.InodeOperations.{Get,Set,List}xattr.
+//
+// +stateify savable
+type InodeSimpleExtendedAttributes struct {
+ // mu protects xattrs.
+ mu sync.RWMutex `state:"nosave"`
+ xattrs map[string]string
+}
+
+// Getxattr implements fs.InodeOperations.Getxattr.
+func (i *InodeSimpleExtendedAttributes) Getxattr(_ *fs.Inode, name string) (string, error) {
+ i.mu.RLock()
+ value, ok := i.xattrs[name]
+ i.mu.RUnlock()
+ if !ok {
+ return "", syserror.ENOATTR
+ }
+ return value, nil
+}
+
+// Setxattr implements fs.InodeOperations.Setxattr.
+func (i *InodeSimpleExtendedAttributes) Setxattr(_ *fs.Inode, name, value string) error {
+ i.mu.Lock()
+ if i.xattrs == nil {
+ i.xattrs = make(map[string]string)
+ }
+ i.xattrs[name] = value
+ i.mu.Unlock()
+ return nil
+}
+
+// Listxattr implements fs.InodeOperations.Listxattr.
+func (i *InodeSimpleExtendedAttributes) Listxattr(_ *fs.Inode) (map[string]struct{}, error) {
+ i.mu.RLock()
+ names := make(map[string]struct{}, len(i.xattrs))
+ for name := range i.xattrs {
+ names[name] = struct{}{}
+ }
+ i.mu.RUnlock()
+ return names, nil
+}
+
+// staticFile is a file with static contents. It is returned by
+// InodeStaticFileGetter.GetFile.
+//
+// +stateify savable
+type staticFile struct {
+ FileGenericSeek `state:"nosave"`
+ FileNoIoctl `state:"nosave"`
+ FileNoMMap `state:"nosave"`
+ FileNoSplice `state:"nosave"`
+ FileNoopFsync `state:"nosave"`
+ FileNoopFlush `state:"nosave"`
+ FileNoopRelease `state:"nosave"`
+ FileNoopWrite `state:"nosave"`
+ FileNotDirReaddir `state:"nosave"`
+ FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ FileStaticContentReader
+}
+
+// InodeNoStatFS implement StatFS by retuning ENOSYS.
+type InodeNoStatFS struct{}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (InodeNoStatFS) StatFS(context.Context) (fs.Info, error) {
+ return fs.Info{}, syserror.ENOSYS
+}
+
+// InodeStaticFileGetter implements GetFile for a file with static contents.
+//
+// +stateify savable
+type InodeStaticFileGetter struct {
+ Contents []byte
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (i *InodeStaticFileGetter) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &staticFile{
+ FileStaticContentReader: NewFileStaticContentReader(i.Contents),
+ }), nil
+}
+
+// InodeNotMappable returns a nil memmap.Mappable.
+type InodeNotMappable struct{}
+
+// Mappable implements fs.InodeOperations.Mappable.
+func (InodeNotMappable) Mappable(*fs.Inode) memmap.Mappable {
+ return nil
+}
+
+// InodeNoopWriteOut is a no-op implementation of fs.InodeOperations.WriteOut.
+type InodeNoopWriteOut struct{}
+
+// WriteOut is a no-op.
+func (InodeNoopWriteOut) WriteOut(context.Context, *fs.Inode) error {
+ return nil
+}
+
+// InodeNotDirectory can be used by Inodes that are not directories.
+type InodeNotDirectory struct{}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (InodeNotDirectory) Lookup(context.Context, *fs.Inode, string) (*fs.Dirent, error) {
+ return nil, syserror.ENOTDIR
+}
+
+// Create implements fs.InodeOperations.Create.
+func (InodeNotDirectory) Create(context.Context, *fs.Inode, string, fs.FileFlags, fs.FilePermissions) (*fs.File, error) {
+ return nil, syserror.ENOTDIR
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+func (InodeNotDirectory) CreateLink(context.Context, *fs.Inode, string, string) error {
+ return syserror.ENOTDIR
+}
+
+// CreateHardLink implements fs.InodeOperations.CreateHardLink.
+func (InodeNotDirectory) CreateHardLink(context.Context, *fs.Inode, *fs.Inode, string) error {
+ return syserror.ENOTDIR
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+func (InodeNotDirectory) CreateDirectory(context.Context, *fs.Inode, string, fs.FilePermissions) error {
+ return syserror.ENOTDIR
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (InodeNotDirectory) Bind(context.Context, *fs.Inode, string, transport.BoundEndpoint, fs.FilePermissions) (*fs.Dirent, error) {
+ return nil, syserror.ENOTDIR
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (InodeNotDirectory) CreateFifo(context.Context, *fs.Inode, string, fs.FilePermissions) error {
+ return syserror.ENOTDIR
+}
+
+// Remove implements fs.InodeOperations.Remove.
+func (InodeNotDirectory) Remove(context.Context, *fs.Inode, string) error {
+ return syserror.ENOTDIR
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+func (InodeNotDirectory) RemoveDirectory(context.Context, *fs.Inode, string) error {
+ return syserror.ENOTDIR
+}
+
+// Rename implements fs.FileOperations.Rename.
+func (InodeNotDirectory) Rename(context.Context, *fs.Inode, *fs.Inode, string, *fs.Inode, string, bool) error {
+ return syserror.EINVAL
+}
+
+// InodeNotSocket can be used by Inodes that are not sockets.
+type InodeNotSocket struct{}
+
+// BoundEndpoint implements fs.InodeOperations.BoundEndpoint.
+func (InodeNotSocket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint {
+ return nil
+}
+
+// InodeNotTruncatable can be used by Inodes that cannot be truncated.
+type InodeNotTruncatable struct{}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (InodeNotTruncatable) Truncate(context.Context, *fs.Inode, int64) error {
+ return syserror.EINVAL
+}
+
+// InodeIsDirTruncate implements fs.InodeOperations.Truncate for directories.
+type InodeIsDirTruncate struct{}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (InodeIsDirTruncate) Truncate(context.Context, *fs.Inode, int64) error {
+ return syserror.EISDIR
+}
+
+// InodeNoopTruncate implements fs.InodeOperations.Truncate as a noop.
+type InodeNoopTruncate struct{}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (InodeNoopTruncate) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// InodeNotRenameable can be used by Inodes that cannot be truncated.
+type InodeNotRenameable struct{}
+
+// Rename implements fs.InodeOperations.Rename.
+func (InodeNotRenameable) Rename(context.Context, *fs.Inode, *fs.Inode, string, *fs.Inode, string, bool) error {
+ return syserror.EINVAL
+}
+
+// InodeNotOpenable can be used by Inodes that cannot be opened.
+type InodeNotOpenable struct{}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (InodeNotOpenable) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error) {
+ return nil, syserror.EIO
+}
+
+// InodeNotVirtual can be used by Inodes that are not virtual.
+type InodeNotVirtual struct{}
+
+// IsVirtual implements fs.InodeOperations.IsVirtual.
+func (InodeNotVirtual) IsVirtual() bool {
+ return false
+}
+
+// InodeVirtual can be used by Inodes that are virtual.
+type InodeVirtual struct{}
+
+// IsVirtual implements fs.InodeOperations.IsVirtual.
+func (InodeVirtual) IsVirtual() bool {
+ return true
+}
+
+// InodeNotSymlink can be used by Inodes that are not symlinks.
+type InodeNotSymlink struct{}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (InodeNotSymlink) Readlink(context.Context, *fs.Inode) (string, error) {
+ return "", syserror.ENOLINK
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (InodeNotSymlink) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ return nil, syserror.ENOLINK
+}
+
+// InodeNoExtendedAttributes can be used by Inodes that do not support
+// extended attributes.
+type InodeNoExtendedAttributes struct{}
+
+// Getxattr implements fs.InodeOperations.Getxattr.
+func (InodeNoExtendedAttributes) Getxattr(*fs.Inode, string) (string, error) {
+ return "", syserror.EOPNOTSUPP
+}
+
+// Setxattr implements fs.InodeOperations.Setxattr.
+func (InodeNoExtendedAttributes) Setxattr(*fs.Inode, string, string) error {
+ return syserror.EOPNOTSUPP
+}
+
+// Listxattr implements fs.InodeOperations.Listxattr.
+func (InodeNoExtendedAttributes) Listxattr(*fs.Inode) (map[string]struct{}, error) {
+ return nil, syserror.EOPNOTSUPP
+}
+
+// InodeNoopRelease implements fs.InodeOperations.Release as a noop.
+type InodeNoopRelease struct{}
+
+// Release implements fs.InodeOperations.Release.
+func (InodeNoopRelease) Release(context.Context) {}
+
+// InodeGenericChecker implements fs.InodeOperations.Check with a generic
+// implementation.
+type InodeGenericChecker struct{}
+
+// Check implements fs.InodeOperations.Check.
+func (InodeGenericChecker) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+// InodeDenyWriteChecker implements fs.InodeOperations.Check which denies all
+// write operations.
+type InodeDenyWriteChecker struct{}
+
+// Check implements fs.InodeOperations.Check.
+func (InodeDenyWriteChecker) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ if p.Write {
+ return false
+ }
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+//InodeNotAllocatable can be used by Inodes that do not support Allocate().
+type InodeNotAllocatable struct{}
+
+func (InodeNotAllocatable) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
+ return syserror.EOPNOTSUPP
+}
+
+// InodeNoopAllocate implements fs.InodeOperations.Allocate as a noop.
+type InodeNoopAllocate struct{}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (InodeNoopAllocate) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
+ return nil
+}
+
+// InodeIsDirAllocate implements fs.InodeOperations.Allocate for directories.
+type InodeIsDirAllocate struct{}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (InodeIsDirAllocate) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
+ return syserror.EISDIR
+}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
new file mode 100644
index 000000000..7bee2eb5f
--- /dev/null
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -0,0 +1,1004 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsutil
+
+import (
+ "fmt"
+ "io"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// Lock order (compare the lock order model in mm/mm.go):
+//
+// CachingInodeOperations.attrMu ("fs locks")
+// CachingInodeOperations.mapsMu ("memmap.Mappable locks not taken by Translate")
+// CachingInodeOperations.dataMu ("memmap.Mappable locks taken by Translate")
+// CachedFileObject locks
+
+// CachingInodeOperations caches the metadata and content of a CachedFileObject.
+// It implements a subset of InodeOperations. As a utility it can be used to
+// implement the full set of InodeOperations. Generally it should not be
+// embedded to avoid unexpected inherited behavior.
+//
+// CachingInodeOperations implements Mappable for the CachedFileObject:
+//
+// - If CachedFileObject.FD returns a value >= 0 then the file descriptor
+// will be memory mapped on the host.
+//
+// - Otherwise, the contents of CachedFileObject are buffered into memory
+// managed by the CachingInodeOperations.
+//
+// Implementations of FileOperations for a CachedFileObject must read and
+// write through CachingInodeOperations using Read and Write respectively.
+//
+// Implementations of InodeOperations.WriteOut must call Sync to write out
+// in-memory modifications of data and metadata to the CachedFileObject.
+//
+// +stateify savable
+type CachingInodeOperations struct {
+ // backingFile is a handle to a cached file object.
+ backingFile CachedFileObject
+
+ // mfp is used to allocate memory that caches backingFile's contents.
+ mfp pgalloc.MemoryFileProvider
+
+ // forcePageCache indicates the sentry page cache should be used regardless
+ // of whether the platform supports host mapped I/O or not. This must not be
+ // modified after inode creation.
+ forcePageCache bool
+
+ attrMu sync.Mutex `state:"nosave"`
+
+ // attr is unstable cached metadata.
+ //
+ // attr is protected by attrMu. attr.Size is protected by both attrMu and
+ // dataMu; reading it requires locking either mutex, while mutating it
+ // requires locking both.
+ attr fs.UnstableAttr
+
+ // dirtyAttr is metadata that was updated in-place but hasn't yet
+ // been successfully written out.
+ //
+ // dirtyAttr is protected by attrMu.
+ dirtyAttr fs.AttrMask
+
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the cached file object into
+ // memmap.MappingSpaces.
+ //
+ // mappings is protected by mapsMu.
+ mappings memmap.MappingSet
+
+ dataMu sync.RWMutex `state:"nosave"`
+
+ // cache maps offsets into the cached file to offsets into
+ // mfp.MemoryFile() that store the file's data.
+ //
+ // cache is protected by dataMu.
+ cache FileRangeSet
+
+ // dirty tracks dirty segments in cache.
+ //
+ // dirty is protected by dataMu.
+ dirty DirtySet
+
+ // hostFileMapper caches internal mappings of backingFile.FD().
+ hostFileMapper *HostFileMapper
+
+ // refs tracks active references to data in the cache.
+ //
+ // refs is protected by dataMu.
+ refs frameRefSet
+}
+
+// CachedFileObject is a file that may require caching.
+type CachedFileObject interface {
+ // ReadToBlocksAt reads up to dsts.NumBytes() bytes from the file to dsts,
+ // starting at offset, and returns the number of bytes read. ReadToBlocksAt
+ // may return a partial read without an error.
+ ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)
+
+ // WriteFromBlocksAt writes up to srcs.NumBytes() bytes from srcs to the
+ // file, starting at offset, and returns the number of bytes written.
+ // WriteFromBlocksAt may return a partial write without an error.
+ WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)
+
+ // SetMaskedAttributes sets the attributes in attr that are true in mask
+ // on the backing file.
+ //
+ // SetMaskedAttributes may be called at any point, regardless of whether
+ // the file was opened.
+ SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error
+
+ // Allocate allows the caller to reserve disk space for the inode.
+ // It's equivalent to fallocate(2) with 'mode=0'.
+ Allocate(ctx context.Context, offset int64, length int64) error
+
+ // Sync instructs the remote filesystem to sync the file to stable storage.
+ Sync(ctx context.Context) error
+
+ // FD returns a host file descriptor. If it is possible for
+ // CachingInodeOperations.AddMapping to have ever been called with writable
+ // = true, the FD must have been opened O_RDWR; otherwise, it may have been
+ // opened O_RDONLY or O_RDWR. (mmap unconditionally requires that mapped
+ // files are readable.) If no host file descriptor is available, FD returns
+ // a negative number.
+ //
+ // For any given CachedFileObject, if FD() ever succeeds (returns a
+ // non-negative number), it must always succeed.
+ //
+ // FD is called iff the file has been memory mapped. This implies that
+ // the file was opened (see fs.InodeOperations.GetFile).
+ FD() int
+}
+
+// NewCachingInodeOperations returns a new CachingInodeOperations backed by
+// a CachedFileObject and its initial unstable attributes.
+func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, forcePageCache bool) *CachingInodeOperations {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
+ panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
+ }
+ return &CachingInodeOperations{
+ backingFile: backingFile,
+ mfp: mfp,
+ forcePageCache: forcePageCache,
+ attr: uattr,
+ hostFileMapper: NewHostFileMapper(),
+ }
+}
+
+// Release implements fs.InodeOperations.Release.
+func (c *CachingInodeOperations) Release() {
+ c.mapsMu.Lock()
+ defer c.mapsMu.Unlock()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+
+ // Something has gone terribly wrong if we're releasing an inode that is
+ // still memory-mapped.
+ if !c.mappings.IsEmpty() {
+ panic(fmt.Sprintf("Releasing CachingInodeOperations with mappings:\n%s", &c.mappings))
+ }
+
+ // Drop any cached pages that are still awaiting MemoryFile eviction. (This
+ // means that MemoryFile no longer needs to evict them.)
+ mf := c.mfp.MemoryFile()
+ mf.MarkAllUnevictable(c)
+ if err := SyncDirtyAll(context.Background(), &c.cache, &c.dirty, uint64(c.attr.Size), mf, c.backingFile.WriteFromBlocksAt); err != nil {
+ panic(fmt.Sprintf("Failed to writeback cached data: %v", err))
+ }
+ c.cache.DropAll(mf)
+ c.dirty.RemoveAll()
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (c *CachingInodeOperations) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ c.attrMu.Lock()
+ attr := c.attr
+ c.attrMu.Unlock()
+ return attr, nil
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.Inode, perms fs.FilePermissions) bool {
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+
+ now := ktime.NowFromContext(ctx)
+ masked := fs.AttrMask{Perms: true}
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}); err != nil {
+ return false
+ }
+ c.attr.Perms = perms
+ c.touchStatusChangeTimeLocked(now)
+ return true
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode, owner fs.FileOwner) error {
+ if !owner.UID.Ok() && !owner.GID.Ok() {
+ return nil
+ }
+
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+
+ now := ktime.NowFromContext(ctx)
+ masked := fs.AttrMask{
+ UID: owner.UID.Ok(),
+ GID: owner.GID.Ok(),
+ }
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}); err != nil {
+ return err
+ }
+ if owner.UID.Ok() {
+ c.attr.Owner.UID = owner.UID
+ }
+ if owner.GID.Ok() {
+ c.attr.Owner.GID = owner.GID
+ }
+ c.touchStatusChangeTimeLocked(now)
+ return nil
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.Inode, ts fs.TimeSpec) error {
+ if ts.ATimeOmit && ts.MTimeOmit {
+ return nil
+ }
+
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+
+ // Replace requests to use the "system time" with the current time to
+ // ensure that cached timestamps remain consistent with the remote
+ // filesystem.
+ now := ktime.NowFromContext(ctx)
+ if ts.ATimeSetSystemTime {
+ ts.ATime = now
+ }
+ if ts.MTimeSetSystemTime {
+ ts.MTime = now
+ }
+ masked := fs.AttrMask{
+ AccessTime: !ts.ATimeOmit,
+ ModificationTime: !ts.MTimeOmit,
+ }
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}); err != nil {
+ return err
+ }
+ if !ts.ATimeOmit {
+ c.attr.AccessTime = ts.ATime
+ }
+ if !ts.MTimeOmit {
+ c.attr.ModificationTime = ts.MTime
+ }
+ c.touchStatusChangeTimeLocked(now)
+ return nil
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode, size int64) error {
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+
+ // c.attr.Size is protected by both c.attrMu and c.dataMu.
+ c.dataMu.Lock()
+ now := ktime.NowFromContext(ctx)
+ masked := fs.AttrMask{Size: true}
+ attr := fs.UnstableAttr{Size: size}
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr); err != nil {
+ c.dataMu.Unlock()
+ return err
+ }
+ oldSize := c.attr.Size
+ c.attr.Size = size
+ c.touchModificationAndStatusChangeTimeLocked(now)
+
+ // We drop c.dataMu here so that we can lock c.mapsMu and invalidate
+ // mappings below. This allows concurrent calls to Read/Translate/etc.
+ // These functions synchronize with an in-progress Truncate by refusing to
+ // use cache contents beyond the new c.attr.Size. (We are still holding
+ // c.attrMu, so we can't race with Truncate/Write.)
+ c.dataMu.Unlock()
+
+ // Nothing left to do unless shrinking the file.
+ if size >= oldSize {
+ return nil
+ }
+
+ oldpgend := fs.OffsetPageEnd(oldSize)
+ newpgend := fs.OffsetPageEnd(size)
+
+ // Invalidate past translations of truncated pages.
+ if newpgend != oldpgend {
+ c.mapsMu.Lock()
+ c.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ c.mapsMu.Unlock()
+ }
+
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them from the cache. Since truncated pages have been
+ // removed from the backing file, they should be dropped without being
+ // written back.
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+ c.cache.Truncate(uint64(size), c.mfp.MemoryFile())
+ c.dirty.KeepClean(memmap.MappableRange{uint64(size), oldpgend})
+
+ return nil
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (c *CachingInodeOperations) Allocate(ctx context.Context, offset, length int64) error {
+ newSize := offset + length
+
+ // c.attr.Size is protected by both c.attrMu and c.dataMu.
+ c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+
+ if newSize <= c.attr.Size {
+ return nil
+ }
+
+ now := ktime.NowFromContext(ctx)
+ if err := c.backingFile.Allocate(ctx, offset, length); err != nil {
+ return err
+ }
+
+ c.attr.Size = newSize
+ c.touchModificationAndStatusChangeTimeLocked(now)
+ return nil
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ c.attrMu.Lock()
+
+ // Write dirty pages back.
+ c.dataMu.Lock()
+ err := SyncDirtyAll(ctx, &c.cache, &c.dirty, uint64(c.attr.Size), c.mfp.MemoryFile(), c.backingFile.WriteFromBlocksAt)
+ c.dataMu.Unlock()
+ if err != nil {
+ c.attrMu.Unlock()
+ return err
+ }
+
+ // SyncDirtyAll above would have grown if needed. On shrinks, the backing
+ // file is called directly, so size is never needs to be updated.
+ c.dirtyAttr.Size = false
+
+ // Write out cached attributes.
+ if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr); err != nil {
+ c.attrMu.Unlock()
+ return err
+ }
+ c.dirtyAttr = fs.AttrMask{}
+
+ c.attrMu.Unlock()
+
+ // Fsync the remote file.
+ return c.backingFile.Sync(ctx)
+}
+
+// IncLinks increases the link count and updates cached modification time.
+func (c *CachingInodeOperations) IncLinks(ctx context.Context) {
+ c.attrMu.Lock()
+ c.attr.Links++
+ c.touchModificationAndStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// DecLinks decreases the link count and updates cached modification time.
+func (c *CachingInodeOperations) DecLinks(ctx context.Context) {
+ c.attrMu.Lock()
+ c.attr.Links--
+ c.touchModificationAndStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// TouchAccessTime updates the cached access time in-place to the
+// current time. It does not update status change time in-place. See
+// mm/filemap.c:do_generic_file_read -> include/linux/h:file_accessed.
+func (c *CachingInodeOperations) TouchAccessTime(ctx context.Context, inode *fs.Inode) {
+ if inode.MountSource.Flags.NoAtime {
+ return
+ }
+
+ c.attrMu.Lock()
+ c.touchAccessTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// touchAccesstimeLocked updates the cached access time in-place to the current
+// time.
+//
+// Preconditions: c.attrMu is locked for writing.
+func (c *CachingInodeOperations) touchAccessTimeLocked(now time.Time) {
+ c.attr.AccessTime = now
+ c.dirtyAttr.AccessTime = true
+}
+
+// TouchModificationAndStatusChangeTime updates the cached modification and
+// status change times in-place to the current time.
+func (c *CachingInodeOperations) TouchModificationAndStatusChangeTime(ctx context.Context) {
+ c.attrMu.Lock()
+ c.touchModificationAndStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// touchModificationAndStatusChangeTimeLocked updates the cached modification
+// and status change times in-place to the current time.
+//
+// Preconditions: c.attrMu is locked for writing.
+func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now time.Time) {
+ c.attr.ModificationTime = now
+ c.dirtyAttr.ModificationTime = true
+ c.attr.StatusChangeTime = now
+ c.dirtyAttr.StatusChangeTime = true
+}
+
+// TouchStatusChangeTime updates the cached status change time in-place to the
+// current time.
+func (c *CachingInodeOperations) TouchStatusChangeTime(ctx context.Context) {
+ c.attrMu.Lock()
+ c.touchStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ c.attrMu.Unlock()
+}
+
+// touchStatusChangeTimeLocked updates the cached status change time
+// in-place to the current time.
+//
+// Preconditions: c.attrMu is locked for writing.
+func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now time.Time) {
+ c.attr.StatusChangeTime = now
+ c.dirtyAttr.StatusChangeTime = true
+}
+
+// UpdateUnstable updates the cached unstable attributes. Only non-dirty
+// attributes are updated.
+func (c *CachingInodeOperations) UpdateUnstable(attr fs.UnstableAttr) {
+ // All attributes are protected by attrMu.
+ c.attrMu.Lock()
+
+ if !c.dirtyAttr.Usage {
+ c.attr.Usage = attr.Usage
+ }
+ if !c.dirtyAttr.Perms {
+ c.attr.Perms = attr.Perms
+ }
+ if !c.dirtyAttr.UID {
+ c.attr.Owner.UID = attr.Owner.UID
+ }
+ if !c.dirtyAttr.GID {
+ c.attr.Owner.GID = attr.Owner.GID
+ }
+ if !c.dirtyAttr.AccessTime {
+ c.attr.AccessTime = attr.AccessTime
+ }
+ if !c.dirtyAttr.ModificationTime {
+ c.attr.ModificationTime = attr.ModificationTime
+ }
+ if !c.dirtyAttr.StatusChangeTime {
+ c.attr.StatusChangeTime = attr.StatusChangeTime
+ }
+ if !c.dirtyAttr.Links {
+ c.attr.Links = attr.Links
+ }
+
+ // Size requires holding attrMu and dataMu.
+ c.dataMu.Lock()
+ if !c.dirtyAttr.Size {
+ c.attr.Size = attr.Size
+ }
+ c.dataMu.Unlock()
+
+ c.attrMu.Unlock()
+}
+
+// Read reads from frames and otherwise directly from the backing file
+// into dst starting at offset until dst is full, EOF is reached, or an
+// error is encountered.
+//
+// Read may partially fill dst and return a nil error.
+func (c *CachingInodeOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Have we reached EOF? We check for this again in
+ // inodeReadWriter.ReadToBlocks to avoid holding c.attrMu (which would
+ // serialize reads) or c.dataMu (which would violate lock ordering), but
+ // check here first (before calling into MM) since reading at EOF is
+ // common: getting a return value of 0 from a read syscall is the only way
+ // to detect EOF.
+ //
+ // TODO(jamieliu): Separate out c.attr.Size and use atomics instead of
+ // c.dataMu.
+ c.dataMu.RLock()
+ size := c.attr.Size
+ c.dataMu.RUnlock()
+ if offset >= size {
+ return 0, io.EOF
+ }
+
+ n, err := dst.CopyOutFrom(ctx, &inodeReadWriter{ctx, c, offset})
+ // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed().
+ c.TouchAccessTime(ctx, file.Dirent.Inode)
+ return n, err
+}
+
+// Write writes to frames and otherwise directly to the backing file
+// from src starting at offset and until src is empty or an error is
+// encountered.
+//
+// If Write partially fills src, a non-nil error is returned.
+func (c *CachingInodeOperations) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // Hot path. Avoid defers.
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ c.attrMu.Lock()
+ // Compare Linux's mm/filemap.c:__generic_file_write_iter() => file_update_time().
+ c.touchModificationAndStatusChangeTimeLocked(ktime.NowFromContext(ctx))
+ n, err := src.CopyInTo(ctx, &inodeReadWriter{ctx, c, offset})
+ c.attrMu.Unlock()
+ return n, err
+}
+
+type inodeReadWriter struct {
+ ctx context.Context
+ c *CachingInodeOperations
+ offset int64
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ // Hot path. Avoid defers.
+ rw.c.dataMu.RLock()
+
+ // Compute the range to read.
+ if rw.offset >= rw.c.attr.Size {
+ rw.c.dataMu.RUnlock()
+ return 0, io.EOF
+ }
+ end := fs.ReadEndOffset(rw.offset, int64(dsts.NumBytes()), rw.c.attr.Size)
+ if end == rw.offset { // dsts.NumBytes() == 0?
+ rw.c.dataMu.RUnlock()
+ return 0, nil
+ }
+
+ mem := rw.c.mfp.MemoryFile()
+ var done uint64
+ seg, gap := rw.c.cache.Find(uint64(rw.offset))
+ for rw.offset < end {
+ mr := memmap.MappableRange{uint64(rw.offset), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings from the cache.
+ ims, err := mem.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ rw.c.dataMu.RUnlock()
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.offset += int64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ rw.c.dataMu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Read directly from the backing file.
+ gapmr := gap.Range().Intersect(mr)
+ dst := dsts.TakeFirst64(gapmr.Length())
+ n, err := rw.c.backingFile.ReadToBlocksAt(rw.ctx, dst, gapmr.Start)
+ done += n
+ rw.offset += int64(n)
+ dsts = dsts.DropFirst64(n)
+ // Partial reads are fine. But we must stop reading.
+ if n != dst.NumBytes() || err != nil {
+ rw.c.dataMu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), FileRangeGapIterator{}
+
+ default:
+ break
+ }
+ }
+ rw.c.dataMu.RUnlock()
+ return done, nil
+}
+
+// maybeGrowFile grows the file's size if data has been written past the old
+// size.
+//
+// Preconditions: rw.c.attrMu and rw.c.dataMu bust be locked.
+func (rw *inodeReadWriter) maybeGrowFile() {
+ // If the write ends beyond the file's previous size, it causes the
+ // file to grow.
+ if rw.offset > rw.c.attr.Size {
+ rw.c.attr.Size = rw.offset
+ rw.c.dirtyAttr.Size = true
+ }
+ if rw.offset > rw.c.attr.Usage {
+ // This is incorrect if CachingInodeOperations is caching a sparse
+ // file. (In Linux, keeping inode::i_blocks up to date is the
+ // filesystem's responsibility.)
+ rw.c.attr.Usage = rw.offset
+ rw.c.dirtyAttr.Usage = true
+ }
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: rw.c.attrMu must be locked.
+func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ // Hot path. Avoid defers.
+ rw.c.dataMu.Lock()
+
+ // Compute the range to write.
+ end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes()))
+ if end == rw.offset { // srcs.NumBytes() == 0?
+ rw.c.dataMu.Unlock()
+ return 0, nil
+ }
+
+ mf := rw.c.mfp.MemoryFile()
+ var done uint64
+ seg, gap := rw.c.cache.Find(uint64(rw.offset))
+ for rw.offset < end {
+ mr := memmap.MappableRange{uint64(rw.offset), uint64(end)}
+ switch {
+ case seg.Ok() && seg.Start() < mr.End:
+ // Get internal mappings from the cache.
+ segMR := seg.Range().Intersect(mr)
+ ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write)
+ if err != nil {
+ rw.maybeGrowFile()
+ rw.c.dataMu.Unlock()
+ return done, err
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.offset += int64(n)
+ srcs = srcs.DropFirst64(n)
+ rw.c.dirty.MarkDirty(segMR)
+ if err != nil {
+ rw.maybeGrowFile()
+ rw.c.dataMu.Unlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok() && gap.Start() < mr.End:
+ // Write directly to the backing file.
+ gapmr := gap.Range().Intersect(mr)
+ src := srcs.TakeFirst64(gapmr.Length())
+ n, err := rw.c.backingFile.WriteFromBlocksAt(rw.ctx, src, gapmr.Start)
+ done += n
+ rw.offset += int64(n)
+ srcs = srcs.DropFirst64(n)
+ // Partial writes are fine. But we must stop writing.
+ if n != src.NumBytes() || err != nil {
+ rw.maybeGrowFile()
+ rw.c.dataMu.Unlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), FileRangeGapIterator{}
+
+ default:
+ break
+ }
+ }
+ rw.maybeGrowFile()
+ rw.c.dataMu.Unlock()
+ return done, nil
+}
+
+// useHostPageCache returns true if c uses c.backingFile.FD() for all file I/O
+// and memory mappings, and false if c.cache may contain data cached from
+// c.backingFile.
+func (c *CachingInodeOperations) useHostPageCache() bool {
+ return !c.forcePageCache && c.backingFile.FD() >= 0
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ // Hot path. Avoid defers.
+ c.mapsMu.Lock()
+ mapped := c.mappings.AddMapping(ms, ar, offset, writable)
+ // Do this unconditionally since whether we have c.backingFile.FD() >= 0
+ // can change across save/restore.
+ for _, r := range mapped {
+ c.hostFileMapper.IncRefOn(r)
+ }
+ if !c.useHostPageCache() {
+ // c.Evict() will refuse to evict memory-mapped pages, so tell the
+ // MemoryFile to not bother trying.
+ mf := c.mfp.MemoryFile()
+ for _, r := range mapped {
+ mf.MarkUnevictable(c, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ }
+ if c.useHostPageCache() && !usage.IncrementalMappedAccounting {
+ for _, r := range mapped {
+ usage.MemoryAccounting.Inc(r.Length(), usage.Mapped)
+ }
+ }
+ c.mapsMu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ // Hot path. Avoid defers.
+ c.mapsMu.Lock()
+ unmapped := c.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ c.hostFileMapper.DecRefOn(r)
+ }
+ if c.useHostPageCache() {
+ if !usage.IncrementalMappedAccounting {
+ for _, r := range unmapped {
+ usage.MemoryAccounting.Dec(r.Length(), usage.Mapped)
+ }
+ }
+ c.mapsMu.Unlock()
+ return
+ }
+
+ // Pages that are no longer referenced by any application memory mappings
+ // are now considered unused; allow MemoryFile to evict them when
+ // necessary.
+ mf := c.mfp.MemoryFile()
+ c.dataMu.Lock()
+ for _, r := range unmapped {
+ // Since these pages are no longer mapped, they are no longer
+ // concurrently dirtyable by a writable memory mapping.
+ c.dirty.AllowClean(r)
+ mf.MarkEvictable(c, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ c.dataMu.Unlock()
+ c.mapsMu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return c.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ // Hot path. Avoid defer.
+ if c.useHostPageCache() {
+ return []memmap.Translation{
+ {
+ Source: optional,
+ File: c,
+ Offset: optional.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+ }
+
+ c.dataMu.Lock()
+
+ // Constrain translations to c.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(c.attr.Size)
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ c.dataMu.Unlock()
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ mf := c.mfp.MemoryFile()
+ cerr := c.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, c.backingFile.ReadToBlocksAt)
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := c.cache.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ // TODO(jamieliu): Make Translations writable even if writability is
+ // not required if already kept-dirty by another writable translation.
+ perms := usermem.AccessType{
+ Read: true,
+ Execute: true,
+ }
+ if at.Write {
+ // From this point forward, this memory can be dirtied through the
+ // mapping at any time.
+ c.dirty.KeepDirty(segMR)
+ perms.Write = true
+ }
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: mf,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: perms,
+ })
+ translatedEnd = segMR.End
+ }
+
+ c.dataMu.Unlock()
+
+ // Don't return the error returned by c.cache.Fill if it occurred outside
+ // of required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange {
+ const maxReadahead = 64 << 10 // 64 KB, chosen arbitrarily
+ if required.Length() >= maxReadahead {
+ return required
+ }
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.Start = required.Start
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.End = optional.Start + maxReadahead
+ return optional
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error {
+ // Whether we have a host fd (and consequently what platform.File is
+ // mapped) can change across save/restore, so invalidate all translations
+ // unconditionally.
+ c.mapsMu.Lock()
+ defer c.mapsMu.Unlock()
+ c.mappings.InvalidateAll(memmap.InvalidateOpts{})
+
+ // Sync the cache's contents so that if we have a host fd after restore,
+ // the remote file's contents are coherent.
+ mf := c.mfp.MemoryFile()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+ if err := SyncDirtyAll(ctx, &c.cache, &c.dirty, uint64(c.attr.Size), mf, c.backingFile.WriteFromBlocksAt); err != nil {
+ return err
+ }
+
+ // Discard the cache so that it's not stored in saved state. This is safe
+ // because per InvalidateUnsavable invariants, no new translations can have
+ // been returned after we invalidated all existing translations above.
+ c.cache.DropAll(mf)
+ c.dirty.RemoveAll()
+
+ return nil
+}
+
+// Evict implements pgalloc.EvictableMemoryUser.Evict.
+func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.EvictableRange) {
+ c.mapsMu.Lock()
+ defer c.mapsMu.Unlock()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
+
+ mr := memmap.MappableRange{er.Start, er.End}
+ mf := c.mfp.MemoryFile()
+ // Only allow pages that are no longer memory-mapped to be evicted.
+ for mgap := c.mappings.LowerBoundGap(mr.Start); mgap.Ok() && mgap.Start() < mr.End; mgap = mgap.NextGap() {
+ mgapMR := mgap.Range().Intersect(mr)
+ if mgapMR.Length() == 0 {
+ continue
+ }
+ if err := SyncDirty(ctx, mgapMR, &c.cache, &c.dirty, uint64(c.attr.Size), mf, c.backingFile.WriteFromBlocksAt); err != nil {
+ log.Warningf("Failed to writeback cached data %v: %v", mgapMR, err)
+ }
+ c.cache.Drop(mgapMR, mf)
+ c.dirty.KeepClean(mgapMR)
+ }
+}
+
+// IncRef implements platform.File.IncRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the platform.File
+// during translation.
+func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
+ // Hot path. Avoid defers.
+ c.dataMu.Lock()
+ seg, gap := c.refs.Find(fr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < fr.End:
+ seg = c.refs.Isolate(seg, fr)
+ seg.SetValue(seg.Value() + 1)
+ seg, gap = seg.NextNonEmpty()
+ case gap.Ok() && gap.Start() < fr.End:
+ newRange := gap.Range().Intersect(fr)
+ if usage.IncrementalMappedAccounting {
+ usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
+ }
+ seg, gap = c.refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
+ default:
+ c.refs.MergeAdjacent(fr)
+ c.dataMu.Unlock()
+ return
+ }
+ }
+}
+
+// DecRef implements platform.File.DecRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the platform.File
+// during translation.
+func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
+ // Hot path. Avoid defers.
+ c.dataMu.Lock()
+ seg := c.refs.FindSegment(fr.Start)
+
+ for seg.Ok() && seg.Start() < fr.End {
+ seg = c.refs.Isolate(seg, fr)
+ if old := seg.Value(); old == 1 {
+ if usage.IncrementalMappedAccounting {
+ usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
+ }
+ seg = c.refs.Remove(seg).NextSegment()
+ } else {
+ seg.SetValue(old - 1)
+ seg = seg.NextSegment()
+ }
+ }
+ c.refs.MergeAdjacent(fr)
+ c.dataMu.Unlock()
+
+}
+
+// MapInternal implements platform.File.MapInternal. This is used when we
+// directly map an underlying host fd and CachingInodeOperations is used as the
+// platform.File during translation.
+func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write)
+}
+
+// FD implements platform.File.FD. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the platform.File
+// during translation.
+func (c *CachingInodeOperations) FD() int {
+ return c.backingFile.FD()
+}
diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go
new file mode 100644
index 000000000..c572f3396
--- /dev/null
+++ b/pkg/sentry/fs/gofer/attr.go
@@ -0,0 +1,162 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// getattr returns the 9p attributes of the p9.File. On success, Mode, Size, and RDev
+// are guaranteed to be masked as valid.
+func getattr(ctx context.Context, file contextFile) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ // Retrieve attributes over the wire.
+ qid, valid, attr, err := file.getAttr(ctx, p9.AttrMaskAll())
+ if err != nil {
+ return qid, valid, attr, err
+ }
+
+ // Require mode, size, and raw device id.
+ if !valid.Mode || !valid.Size || !valid.RDev {
+ return qid, valid, attr, syscall.EIO
+ }
+
+ return qid, valid, attr, nil
+}
+
+func unstable(ctx context.Context, valid p9.AttrMask, pattr p9.Attr, mounter fs.FileOwner, client *p9.Client) fs.UnstableAttr {
+ return fs.UnstableAttr{
+ Size: int64(pattr.Size),
+ Usage: int64(pattr.Size),
+ Perms: perms(valid, pattr, client),
+ Owner: owner(mounter, valid, pattr),
+ AccessTime: atime(ctx, valid, pattr),
+ ModificationTime: mtime(ctx, valid, pattr),
+ StatusChangeTime: ctime(ctx, valid, pattr),
+ Links: links(valid, pattr),
+ }
+}
+
+func perms(valid p9.AttrMask, pattr p9.Attr, client *p9.Client) fs.FilePermissions {
+ if pattr.Mode.IsDir() && !p9.VersionSupportsMultiUser(client.Version()) {
+ // If user and group permissions bits are not supplied, use
+ // "other" bits to supplement them.
+ //
+ // Older Gofer's fake directories only have "other" permission,
+ // but will often be accessed via user or group permissions.
+ if pattr.Mode&0770 == 0 {
+ other := pattr.Mode & 07
+ pattr.Mode = pattr.Mode | other<<3 | other<<6
+ }
+ }
+ return fs.FilePermsFromP9(pattr.Mode)
+}
+
+func owner(mounter fs.FileOwner, valid p9.AttrMask, pattr p9.Attr) fs.FileOwner {
+ // Unless the file returned its UID and GID, it belongs to the mounting
+ // task's EUID/EGID.
+ owner := mounter
+ if valid.UID {
+ owner.UID = auth.KUID(pattr.UID)
+ }
+ if valid.GID {
+ owner.GID = auth.KGID(pattr.GID)
+ }
+ return owner
+}
+
+// bsize returns a block size from 9p attributes.
+func bsize(pattr p9.Attr) int64 {
+ if pattr.BlockSize > 0 {
+ return int64(pattr.BlockSize)
+ }
+ // Some files may have no clue of their block size. Better not to report
+ // something misleading or buggy and have a safe default.
+ return usermem.PageSize
+}
+
+// ntype returns an fs.InodeType from 9p attributes.
+func ntype(pattr p9.Attr) fs.InodeType {
+ switch {
+ case pattr.Mode.IsNamedPipe():
+ return fs.Pipe
+ case pattr.Mode.IsDir():
+ return fs.Directory
+ case pattr.Mode.IsSymlink():
+ return fs.Symlink
+ case pattr.Mode.IsCharacterDevice():
+ return fs.CharacterDevice
+ case pattr.Mode.IsBlockDevice():
+ return fs.BlockDevice
+ case pattr.Mode.IsSocket():
+ return fs.Socket
+ case pattr.Mode.IsRegular():
+ fallthrough
+ default:
+ return fs.RegularFile
+ }
+}
+
+// ctime returns a change time from 9p attributes.
+func ctime(ctx context.Context, valid p9.AttrMask, pattr p9.Attr) ktime.Time {
+ if valid.CTime {
+ return ktime.FromUnix(int64(pattr.CTimeSeconds), int64(pattr.CTimeNanoSeconds))
+ }
+ // Approximate ctime with mtime if ctime isn't available.
+ return mtime(ctx, valid, pattr)
+}
+
+// atime returns an access time from 9p attributes.
+func atime(ctx context.Context, valid p9.AttrMask, pattr p9.Attr) ktime.Time {
+ if valid.ATime {
+ return ktime.FromUnix(int64(pattr.ATimeSeconds), int64(pattr.ATimeNanoSeconds))
+ }
+ return ktime.NowFromContext(ctx)
+}
+
+// mtime returns a modification time from 9p attributes.
+func mtime(ctx context.Context, valid p9.AttrMask, pattr p9.Attr) ktime.Time {
+ if valid.MTime {
+ return ktime.FromUnix(int64(pattr.MTimeSeconds), int64(pattr.MTimeNanoSeconds))
+ }
+ return ktime.NowFromContext(ctx)
+}
+
+// links returns a hard link count from 9p attributes.
+func links(valid p9.AttrMask, pattr p9.Attr) uint64 {
+ // For gofer file systems that support link count (such as a local file gofer),
+ // we return the link count reported by the underlying file system.
+ if valid.NLink {
+ return pattr.NLink
+ }
+
+ // This node is likely backed by a file system that doesn't support links.
+ // We could readdir() and count children directories to provide an accurate
+ // link count. However this may be expensive since the gofer may be backed by remote
+ // storage. Instead, simply return 2 links for directories and 1 for everything else
+ // since no one relies on an accurate link count for gofer-based file systems.
+ switch ntype(pattr) {
+ case fs.Directory:
+ return 2
+ default:
+ return 1
+ }
+}
diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go
new file mode 100644
index 000000000..c59344589
--- /dev/null
+++ b/pkg/sentry/fs/gofer/cache_policy.go
@@ -0,0 +1,183 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// cachePolicy is a 9p cache policy. It has methods that determine what to
+// cache (if anything) for a given inode.
+type cachePolicy int
+
+const (
+ // Cache nothing.
+ cacheNone cachePolicy = iota
+
+ // Use virtual file system cache for everything.
+ cacheAll
+
+ // Use virtual file system cache for everything, but send writes to the
+ // fs agent immediately.
+ cacheAllWritethrough
+
+ // Use the (host) page cache for reads/writes, but don't cache anything
+ // else. This allows the sandbox filesystem to stay in sync with any
+ // changes to the remote filesystem.
+ //
+ // This policy should *only* be used with remote filesystems that
+ // donate their host FDs to the sandbox and thus use the host page
+ // cache, otherwise the dirent state will be inconsistent.
+ cacheRemoteRevalidating
+)
+
+// String returns the string name of the cache policy.
+func (cp cachePolicy) String() string {
+ switch cp {
+ case cacheNone:
+ return "cacheNone"
+ case cacheAll:
+ return "cacheAll"
+ case cacheAllWritethrough:
+ return "cacheAllWritethrough"
+ case cacheRemoteRevalidating:
+ return "cacheRemoteRevalidating"
+ default:
+ return "unknown"
+ }
+}
+
+func parseCachePolicy(policy string) (cachePolicy, error) {
+ switch policy {
+ case "fscache":
+ return cacheAll, nil
+ case "none":
+ return cacheNone, nil
+ case "fscache_writethrough":
+ return cacheAllWritethrough, nil
+ case "remote_revalidating":
+ return cacheRemoteRevalidating, nil
+ }
+ return cacheNone, fmt.Errorf("unsupported cache mode: %s", policy)
+}
+
+// cacheUAtters determines whether unstable attributes should be cached for the
+// given inode.
+func (cp cachePolicy) cacheUAttrs(inode *fs.Inode) bool {
+ if !fs.IsFile(inode.StableAttr) && !fs.IsDir(inode.StableAttr) {
+ return false
+ }
+ return cp == cacheAll || cp == cacheAllWritethrough
+}
+
+// cacheReaddir determines whether readdir results should be cached.
+func (cp cachePolicy) cacheReaddir() bool {
+ return cp == cacheAll || cp == cacheAllWritethrough
+}
+
+// useCachingInodeOps determines whether the page cache should be used for the
+// given inode. If the remote filesystem donates host FDs to the sentry, then
+// the host kernel's page cache will be used, otherwise we will use a
+// sentry-internal page cache.
+func (cp cachePolicy) useCachingInodeOps(inode *fs.Inode) bool {
+ // Do cached IO for regular files only. Some "character devices" expect
+ // no caching.
+ if !fs.IsFile(inode.StableAttr) {
+ return false
+ }
+ return cp == cacheAll || cp == cacheAllWritethrough
+}
+
+// writeThough indicates whether writes to the file should be synced to the
+// gofer immediately.
+func (cp cachePolicy) writeThrough(inode *fs.Inode) bool {
+ return cp == cacheNone || cp == cacheAllWritethrough
+}
+
+// revalidate revalidates the child Inode if the cache policy allows it.
+//
+// Depending on the cache policy, revalidate will walk from the parent to the
+// child inode, and if any unstable attributes have changed, will update the
+// cached attributes on the child inode. If the walk fails, or the returned
+// inode id is different from the one being revalidated, then the entire Dirent
+// must be reloaded.
+func (cp cachePolicy) revalidate(ctx context.Context, name string, parent, child *fs.Inode) bool {
+ if cp == cacheAll || cp == cacheAllWritethrough {
+ return false
+ }
+
+ if cp == cacheNone {
+ return true
+ }
+
+ childIops, ok := child.InodeOperations.(*inodeOperations)
+ if !ok {
+ panic(fmt.Sprintf("revalidating inode operations of unknown type %T", child.InodeOperations))
+ }
+ parentIops, ok := parent.InodeOperations.(*inodeOperations)
+ if !ok {
+ panic(fmt.Sprintf("revalidating inode operations with parent of unknown type %T", parent.InodeOperations))
+ }
+
+ // Walk from parent to child again.
+ //
+ // TODO(b/112031682): If we have a directory FD in the parent
+ // inodeOperations, then we can use fstatat(2) to get the inode
+ // attributes instead of making this RPC.
+ qids, f, mask, attr, err := parentIops.fileState.file.walkGetAttr(ctx, []string{name})
+ if err != nil {
+ // Can't look up the name. Trigger reload.
+ return true
+ }
+ f.close(ctx)
+
+ // If the Path has changed, then we are not looking at the file file.
+ // We must reload.
+ if qids[0].Path != childIops.fileState.key.Inode {
+ return true
+ }
+
+ // If we are not caching unstable attrs, then there is nothing to
+ // update on this inode.
+ if !cp.cacheUAttrs(child) {
+ return false
+ }
+
+ // Update the inode's cached unstable attrs.
+ s := childIops.session()
+ childIops.cachingInodeOps.UpdateUnstable(unstable(ctx, mask, attr, s.mounter, s.client))
+
+ return false
+}
+
+// keep indicates that dirents should be kept pinned in the dirent tree even if
+// there are no application references on the file.
+func (cp cachePolicy) keep(d *fs.Dirent) bool {
+ if cp == cacheNone {
+ return false
+ }
+ sattr := d.Inode.StableAttr
+ // NOTE(b/31979197): Only cache files, directories, and symlinks.
+ return fs.IsFile(sattr) || fs.IsDir(sattr) || fs.IsSymlink(sattr)
+}
+
+// cacheNegativeDirents indicates that negative dirents should be held in the
+// dirent tree.
+func (cp cachePolicy) cacheNegativeDirents() bool {
+ return cp == cacheAll || cp == cacheAllWritethrough
+}
diff --git a/pkg/sentry/fs/gofer/context_file.go b/pkg/sentry/fs/gofer/context_file.go
new file mode 100644
index 000000000..be53ac4d9
--- /dev/null
+++ b/pkg/sentry/fs/gofer/context_file.go
@@ -0,0 +1,190 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// contextFile is a wrapper around p9.File that notifies the context that
+// it's about to sleep before calling the Gofer over P9.
+type contextFile struct {
+ file p9.File
+}
+
+func (c *contextFile) walk(ctx context.Context, names []string) ([]p9.QID, contextFile, error) {
+ ctx.UninterruptibleSleepStart(false)
+
+ q, f, err := c.file.Walk(names)
+ if err != nil {
+ ctx.UninterruptibleSleepFinish(false)
+ return nil, contextFile{}, err
+ }
+ ctx.UninterruptibleSleepFinish(false)
+ return q, contextFile{file: f}, nil
+}
+
+func (c *contextFile) statFS(ctx context.Context) (p9.FSStat, error) {
+ ctx.UninterruptibleSleepStart(false)
+ s, err := c.file.StatFS()
+ ctx.UninterruptibleSleepFinish(false)
+ return s, err
+}
+
+func (c *contextFile) getAttr(ctx context.Context, req p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, m, a, err := c.file.GetAttr(req)
+ ctx.UninterruptibleSleepFinish(false)
+ return q, m, a, err
+}
+
+func (c *contextFile) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.SetAttr(valid, attr)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Allocate(mode, offset, length)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) rename(ctx context.Context, directory contextFile, name string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Rename(directory.file, name)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) close(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Close()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) open(ctx context.Context, mode p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
+ ctx.UninterruptibleSleepStart(false)
+ f, q, u, err := c.file.Open(mode)
+ ctx.UninterruptibleSleepFinish(false)
+ return f, q, u, err
+}
+
+func (c *contextFile) readAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := c.file.ReadAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (c *contextFile) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := c.file.WriteAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (c *contextFile) fsync(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.FSync()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) create(ctx context.Context, name string, flags p9.OpenFlags, permissions p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fd, _, _, _, err := c.file.Create(name, flags, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return fd, err
+}
+
+func (c *contextFile) mkdir(ctx context.Context, name string, permissions p9.FileMode, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, err := c.file.Mkdir(name, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return q, err
+}
+
+func (c *contextFile) symlink(ctx context.Context, oldName string, newName string, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, err := c.file.Symlink(oldName, newName, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return q, err
+}
+
+func (c *contextFile) link(ctx context.Context, target *contextFile, newName string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Link(target.file, newName)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) mknod(ctx context.Context, name string, permissions p9.FileMode, major uint32, minor uint32, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, err := c.file.Mknod(name, permissions, major, minor, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return q, err
+}
+
+func (c *contextFile) unlinkAt(ctx context.Context, name string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.UnlinkAt(name, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) readdir(ctx context.Context, offset uint64, count uint32) ([]p9.Dirent, error) {
+ ctx.UninterruptibleSleepStart(false)
+ d, err := c.file.Readdir(offset, count)
+ ctx.UninterruptibleSleepFinish(false)
+ return d, err
+}
+
+func (c *contextFile) readlink(ctx context.Context) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ s, err := c.file.Readlink()
+ ctx.UninterruptibleSleepFinish(false)
+ return s, err
+}
+
+func (c *contextFile) flush(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.Flush()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) walkGetAttr(ctx context.Context, names []string) ([]p9.QID, contextFile, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ q, f, m, a, err := c.file.WalkGetAttr(names)
+ if err != nil {
+ ctx.UninterruptibleSleepFinish(false)
+ return nil, contextFile{}, p9.AttrMask{}, p9.Attr{}, err
+ }
+ ctx.UninterruptibleSleepFinish(false)
+ return q, contextFile{file: f}, m, a, nil
+}
+
+func (c *contextFile) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, error) {
+ ctx.UninterruptibleSleepStart(false)
+ f, err := c.file.Connect(flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return f, err
+}
diff --git a/pkg/sentry/fs/gofer/device.go b/pkg/sentry/fs/gofer/device.go
new file mode 100644
index 000000000..1de6c247c
--- /dev/null
+++ b/pkg/sentry/fs/gofer/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// goferDevice is the gofer virtual device.
+var goferDevice = device.NewAnonMultiDevice()
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
new file mode 100644
index 000000000..fb4f50113
--- /dev/null
+++ b/pkg/sentry/fs/gofer/file.go
@@ -0,0 +1,333 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/metric"
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+var (
+ opensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a writable+executable file was opened from a gofer.")
+ opens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a 9P file was opened from a gofer.")
+ opensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a host file was opened from a gofer.")
+ reads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.")
+ readWait9P = metric.MustCreateNewUint64Metric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.")
+ readsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.")
+ readWaitHost = metric.MustCreateNewUint64Metric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.")
+)
+
+// fileOperations implements fs.FileOperations for a remote file system.
+//
+// +stateify savable
+type fileOperations struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosplice"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ // inodeOperations is the inodeOperations backing the file. It is protected
+ // by a reference held by File.Dirent.Inode which is stable until
+ // FileOperations.Release is called.
+ inodeOperations *inodeOperations `state:"wait"`
+
+ // dirCursor is the directory cursor.
+ dirCursor string
+
+ // handles are the opened remote file system handles, which may
+ // be shared with other files.
+ handles *handles `state:"nosave"`
+
+ // flags are the flags used to open handles.
+ flags fs.FileFlags `state:"wait"`
+}
+
+// fileOperations implements fs.FileOperations.
+var _ fs.FileOperations = (*fileOperations)(nil)
+
+// NewFile returns a file. NewFile is not appropriate with host pipes and sockets.
+//
+// The `name` argument is only used to log a warning if we are returning a
+// writeable+executable file. (A metric counter is incremented in this case as
+// well.) Note that we cannot call d.BaseName() directly in this function,
+// because that would lead to a lock order violation, since this is called in
+// d.Create which holds d.mu, while d.BaseName() takes d.parent.mu, and the two
+// locks must be taken in the opposite order.
+func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileFlags, i *inodeOperations, handles *handles) *fs.File {
+ // Remote file systems enforce readability/writability at an offset,
+ // see fs/9p/vfs_inode.c:v9fs_vfs_atomic_open -> fs/open.c:finish_open.
+ flags.Pread = true
+ flags.Pwrite = true
+
+ if fs.IsFile(dirent.Inode.StableAttr) {
+ // If cache policy is "remote revalidating", then we must
+ // ensure that we have a host FD. Otherwise, the
+ // sentry-internal page cache will be used, and we can end up
+ // in an inconsistent state if the remote file changes.
+ cp := dirent.Inode.InodeOperations.(*inodeOperations).session().cachePolicy
+ if cp == cacheRemoteRevalidating && handles.Host == nil {
+ panic(fmt.Sprintf("remote-revalidating cache policy requires gofer to donate host FD, but file %q did not have host FD", name))
+ }
+ }
+
+ f := &fileOperations{
+ inodeOperations: i,
+ handles: handles,
+ flags: flags,
+ }
+ if flags.Write {
+ if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil {
+ opensWX.Increment()
+ log.Warningf("Opened a writable executable: %q", name)
+ }
+ }
+ if handles.Host != nil {
+ opensHost.Increment()
+ } else {
+ opens9P.Increment()
+ }
+ return fs.NewFile(ctx, dirent, flags, f)
+}
+
+// Release implements fs.FileOpeations.Release.
+func (f *fileOperations) Release() {
+ f.handles.DecRef()
+}
+
+// Readdir implements fs.FileOperations.Readdir.
+func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &f.dirCursor,
+ }
+ n, err := fs.DirentReaddir(ctx, file.Dirent, f, root, dirCtx, file.Offset())
+ if f.inodeOperations.session().cachePolicy.cacheUAttrs(file.Dirent.Inode) {
+ f.inodeOperations.cachingInodeOps.TouchAccessTime(ctx, file.Dirent.Inode)
+ }
+ return n, err
+}
+
+// IterateDir implements fs.DirIterator.IterateDir.
+func (f *fileOperations) IterateDir(ctx context.Context, dirCtx *fs.DirCtx, offset int) (int, error) {
+ f.inodeOperations.readdirMu.Lock()
+ defer f.inodeOperations.readdirMu.Unlock()
+
+ // Fetch directory entries if needed.
+ if !f.inodeOperations.session().cachePolicy.cacheReaddir() || f.inodeOperations.readdirCache == nil {
+ entries, err := f.readdirAll(ctx)
+ if err != nil {
+ return offset, err
+ }
+
+ // Cache the readdir result.
+ f.inodeOperations.readdirCache = fs.NewSortedDentryMap(entries)
+ }
+
+ // Serialize the entries.
+ n, err := fs.GenericReaddir(dirCtx, f.inodeOperations.readdirCache)
+ return offset + n, err
+}
+
+// readdirAll fetches fs.DentAttrs for f, using the attributes of g.
+func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr, error) {
+ entries := make(map[string]fs.DentAttr)
+ var readOffset uint64
+ for {
+ // We choose some arbitrary high number of directory entries (64k) and call
+ // Readdir until we've exhausted them all.
+ dirents, err := f.handles.File.readdir(ctx, readOffset, 64*1024)
+ if err != nil {
+ return nil, err
+ }
+ if len(dirents) == 0 {
+ // We're done, we reached EOF.
+ break
+ }
+
+ // The last dirent contains the offset into the next set of dirents. The gofer
+ // returns the offset as an index into directories, not as a byte offset, because
+ // converting a byte offset to an index into directories entries is a huge pain.
+ // But everything is fine if we're consistent.
+ readOffset = dirents[len(dirents)-1].Offset
+
+ for _, dirent := range dirents {
+ if dirent.Name == "." || dirent.Name == ".." {
+ // These must not be included in Readdir results.
+ continue
+ }
+
+ // Find a best approximation of the type.
+ var nt fs.InodeType
+ switch dirent.Type {
+ case p9.TypeDir:
+ nt = fs.Directory
+ case p9.TypeSymlink:
+ nt = fs.Symlink
+ default:
+ nt = fs.RegularFile
+ }
+
+ // Install the DentAttr.
+ entries[dirent.Name] = fs.DentAttr{
+ Type: nt,
+ // Construct the key to find the virtual inode.
+ // Directory entries reside on the same Device
+ // and SecondaryDevice as their parent.
+ InodeID: goferDevice.Map(device.MultiDeviceKey{
+ Device: f.inodeOperations.fileState.key.Device,
+ SecondaryDevice: f.inodeOperations.fileState.key.SecondaryDevice,
+ Inode: dirent.QID.Path,
+ }),
+ }
+ }
+ }
+
+ return entries, nil
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if fs.IsDir(file.Dirent.Inode.StableAttr) {
+ // Not all remote file systems enforce this so this client does.
+ return 0, syserror.EISDIR
+ }
+ cp := f.inodeOperations.session().cachePolicy
+ if cp.useCachingInodeOps(file.Dirent.Inode) {
+ n, err := f.inodeOperations.cachingInodeOps.Write(ctx, src, offset)
+ if err != nil {
+ return n, err
+ }
+ if cp.writeThrough(file.Dirent.Inode) {
+ // Write out the file.
+ err = f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode)
+ }
+ return n, err
+ }
+ if f.inodeOperations.fileState.hostMappable != nil {
+ return f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset)
+ }
+ return src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
+}
+
+// incrementReadCounters increments the read counters for the read starting at the given time. We
+// use this function rather than using a defer in Read() to avoid the performance hit of defer.
+func (f *fileOperations) incrementReadCounters(start time.Time) {
+ if f.handles.Host != nil {
+ readsHost.Increment()
+ fs.IncrementWait(readWaitHost, start)
+ } else {
+ reads9P.Increment()
+ fs.IncrementWait(readWait9P, start)
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ var start time.Time
+ if fs.RecordWaitTime {
+ start = time.Now()
+ }
+ if fs.IsDir(file.Dirent.Inode.StableAttr) {
+ // Not all remote file systems enforce this so this client does.
+ f.incrementReadCounters(start)
+ return 0, syserror.EISDIR
+ }
+
+ if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
+ n, err := f.inodeOperations.cachingInodeOps.Read(ctx, file, dst, offset)
+ f.incrementReadCounters(start)
+ return n, err
+ }
+ n, err := dst.CopyOutFrom(ctx, f.handles.readWriterAt(ctx, offset))
+ f.incrementReadCounters(start)
+ return n, err
+}
+
+// Fsync implements fs.FileOperations.Fsync.
+func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start int64, end int64, syncType fs.SyncType) error {
+ switch syncType {
+ case fs.SyncAll, fs.SyncData:
+ if err := file.Dirent.Inode.WriteOut(ctx); err != nil {
+ return err
+ }
+ fallthrough
+ case fs.SyncBackingStorage:
+ // Sync remote caches.
+ if f.handles.Host != nil {
+ // Sync the host fd directly.
+ return syscall.Fsync(f.handles.Host.FD())
+ }
+ // Otherwise sync on the p9.File handle.
+ return f.handles.File.fsync(ctx)
+ }
+ panic("invalid sync type")
+}
+
+// Flush implements fs.FileOperations.Flush.
+func (f *fileOperations) Flush(ctx context.Context, file *fs.File) error {
+ // If this file is not opened writable then there is nothing to flush.
+ // We do this because some p9 server implementations of Flush are
+ // over-zealous.
+ //
+ // FIXME(edahlgren): weaken these implementations and remove this check.
+ if !file.Flags().Write {
+ return nil
+ }
+ // Execute the flush.
+ return f.handles.File.flush(ctx)
+}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (f *fileOperations) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ return f.inodeOperations.configureMMap(file, opts)
+}
+
+// UnstableAttr implements fs.FileOperations.UnstableAttr.
+func (f *fileOperations) UnstableAttr(ctx context.Context, file *fs.File) (fs.UnstableAttr, error) {
+ s := f.inodeOperations.session()
+ if s.cachePolicy.cacheUAttrs(file.Dirent.Inode) {
+ return f.inodeOperations.cachingInodeOps.UnstableAttr(ctx, file.Dirent.Inode)
+ }
+ // Use f.handles.File, which represents 9P fids that have been opened,
+ // instead of inodeFileState.file, which represents 9P fids that have not.
+ // This may be significantly more efficient in some implementations.
+ _, valid, pattr, err := getattr(ctx, f.handles.File)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ return unstable(ctx, valid, pattr, s.mounter, s.client), nil
+}
+
+// Seek implements fs.FileOperations.Seek.
+func (f *fileOperations) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ return fsutil.SeekWithDirCursor(ctx, file, whence, offset, &f.dirCursor)
+}
diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go
new file mode 100644
index 000000000..31264e065
--- /dev/null
+++ b/pkg/sentry/fs/gofer/file_state.go
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// afterLoad is invoked by stateify.
+func (f *fileOperations) afterLoad() {
+ load := func() error {
+ f.inodeOperations.fileState.waitForLoad()
+
+ // Manually load the open handles.
+ var err error
+ // TODO(b/38173783): Context is not plumbed to save/restore.
+ f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags)
+ if err != nil {
+ return fmt.Errorf("failed to re-open handle: %v", err)
+ }
+ return nil
+ }
+ fs.Async(fs.CatchError(load))
+}
diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go
new file mode 100644
index 000000000..6ab89fcc2
--- /dev/null
+++ b/pkg/sentry/fs/gofer/fs.go
@@ -0,0 +1,247 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gofer implements a remote 9p filesystem.
+package gofer
+
+import (
+ "errors"
+ "fmt"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// The following are options defined by the Linux 9p client that we support,
+// see Documentation/filesystems/9p.txt.
+const (
+ // The transport method.
+ transportKey = "trans"
+
+ // The file tree to access when the file server
+ // is exporting several file systems. Stands for "attach name".
+ anameKey = "aname"
+
+ // The caching policy.
+ cacheKey = "cache"
+
+ // The file descriptor for reading with trans=fd.
+ readFDKey = "rfdno"
+
+ // The file descriptor for writing with trans=fd.
+ writeFDKey = "wfdno"
+
+ // The number of bytes to use for a 9p packet payload.
+ msizeKey = "msize"
+
+ // The 9p protocol version.
+ versionKey = "version"
+
+ // If set to true allows the creation of unix domain sockets inside the
+ // sandbox using files backed by the gofer. If set to false, unix sockets
+ // cannot be bound to gofer files without an overlay on top.
+ privateUnixSocketKey = "privateunixsocket"
+)
+
+// defaultAname is the default attach name.
+const defaultAname = "/"
+
+// defaultMSize is the message size used for chunking large read and write requests.
+// This has been tested to give good enough performance up to 64M.
+const defaultMSize = 1024 * 1024 // 1M
+
+// defaultVersion is the default 9p protocol version. Will negotiate downwards with
+// file server if needed.
+var defaultVersion = p9.HighestVersionString()
+
+// Number of names of non-children to cache, preventing unneeded walks. 64 is
+// plenty for nodejs, which seems to stat about 4 children on every require().
+const nonChildrenCacheSize = 64
+
+var (
+ // ErrNoTransport is returned when there is no 'trans' option.
+ ErrNoTransport = errors.New("missing required option: 'trans='")
+
+ // ErrFileNoReadFD is returned when there is no 'rfdno' option.
+ ErrFileNoReadFD = errors.New("missing required option: 'rfdno='")
+
+ // ErrFileNoWriteFD is returned when there is no 'wfdno' option.
+ ErrFileNoWriteFD = errors.New("missing required option: 'wfdno='")
+)
+
+// filesystem is a 9p client.
+//
+// +stateify savable
+type filesystem struct{}
+
+var _ fs.Filesystem = (*filesystem)(nil)
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered.
+// The name matches fs/9p/vfs_super.c:v9fs_fs_type.name.
+const FilesystemName = "9p"
+
+// Name is the name of the filesystem.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount prohibits users from using mount(2) with this file system.
+func (*filesystem) AllowUserMount() bool {
+ return false
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+//
+// The 9p Linux client returns FS_RENAME_DOES_D_MOVE, see fs/9p/vfs_super.c.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns an attached 9p client that can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // Parse and validate the mount options.
+ o, err := options(data)
+ if err != nil {
+ return nil, err
+ }
+
+ // Construct the 9p root to mount. We intentionally diverge from Linux in that
+ // the first Tversion and Tattach requests are done lazily.
+ return Root(ctx, device, f, flags, o)
+}
+
+// opts are parsed 9p mount options.
+type opts struct {
+ fd int
+ aname string
+ policy cachePolicy
+ msize uint32
+ version string
+ privateunixsocket bool
+}
+
+// options parses mount(2) data into structured options.
+func options(data string) (opts, error) {
+ var o opts
+
+ // Parse generic comma-separated key=value options, this file system expects them.
+ options := fs.GenericMountSourceOptions(data)
+
+ // Check for the required 'trans=fd' option.
+ trans, ok := options[transportKey]
+ if !ok {
+ return o, ErrNoTransport
+ }
+ if trans != "fd" {
+ return o, fmt.Errorf("unsupported transport: 'trans=%s'", trans)
+ }
+ delete(options, transportKey)
+
+ // Check for the required 'rfdno=' option.
+ srfd, ok := options[readFDKey]
+ if !ok {
+ return o, ErrFileNoReadFD
+ }
+ delete(options, readFDKey)
+
+ // Check for the required 'wfdno=' option.
+ swfd, ok := options[writeFDKey]
+ if !ok {
+ return o, ErrFileNoWriteFD
+ }
+ delete(options, writeFDKey)
+
+ // Parse the read fd.
+ rfd, err := strconv.Atoi(srfd)
+ if err != nil {
+ return o, fmt.Errorf("invalid fd for 'rfdno=%s': %v", srfd, err)
+ }
+
+ // Parse the write fd.
+ wfd, err := strconv.Atoi(swfd)
+ if err != nil {
+ return o, fmt.Errorf("invalid fd for 'wfdno=%s': %v", swfd, err)
+ }
+
+ // Require that the read and write fd are the same.
+ if rfd != wfd {
+ return o, fmt.Errorf("fd in 'rfdno=%d' and 'wfdno=%d' must match", rfd, wfd)
+ }
+ o.fd = rfd
+
+ // Parse the attach name.
+ o.aname = defaultAname
+ if an, ok := options[anameKey]; ok {
+ o.aname = an
+ delete(options, anameKey)
+ }
+
+ // Parse the cache policy. Reject unsupported policies.
+ o.policy = cacheAll
+ if policy, ok := options[cacheKey]; ok {
+ cp, err := parseCachePolicy(policy)
+ if err != nil {
+ return o, err
+ }
+ o.policy = cp
+ delete(options, cacheKey)
+ }
+
+ // Parse the message size. Reject malformed options.
+ o.msize = uint32(defaultMSize)
+ if m, ok := options[msizeKey]; ok {
+ i, err := strconv.ParseUint(m, 10, 32)
+ if err != nil {
+ return o, fmt.Errorf("invalid message size for 'msize=%s': %v", m, err)
+ }
+ o.msize = uint32(i)
+ delete(options, msizeKey)
+ }
+
+ // Parse the protocol version.
+ o.version = defaultVersion
+ if v, ok := options[versionKey]; ok {
+ o.version = v
+ delete(options, versionKey)
+ }
+
+ // Parse the unix socket policy. Reject non-booleans.
+ if v, ok := options[privateUnixSocketKey]; ok {
+ b, err := strconv.ParseBool(v)
+ if err != nil {
+ return o, fmt.Errorf("invalid boolean value for '%s=%s': %v", privateUnixSocketKey, v, err)
+ }
+ o.privateunixsocket = b
+ delete(options, privateUnixSocketKey)
+ }
+
+ // Fail to attach if the caller wanted us to do something that we
+ // don't support.
+ if len(options) > 0 {
+ return o, fmt.Errorf("unsupported mount options: %v", options)
+ }
+
+ return o, nil
+}
diff --git a/pkg/sentry/fs/gofer/gofer_state_autogen.go b/pkg/sentry/fs/gofer/gofer_state_autogen.go
new file mode 100755
index 000000000..f274d0c39
--- /dev/null
+++ b/pkg/sentry/fs/gofer/gofer_state_autogen.go
@@ -0,0 +1,113 @@
+// automatically generated by stateify.
+
+package gofer
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *fileOperations) beforeSave() {}
+func (x *fileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("inodeOperations", &x.inodeOperations)
+ m.Save("dirCursor", &x.dirCursor)
+ m.Save("flags", &x.flags)
+}
+
+func (x *fileOperations) load(m state.Map) {
+ m.LoadWait("inodeOperations", &x.inodeOperations)
+ m.Load("dirCursor", &x.dirCursor)
+ m.LoadWait("flags", &x.flags)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *filesystem) beforeSave() {}
+func (x *filesystem) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *filesystem) afterLoad() {}
+func (x *filesystem) load(m state.Map) {
+}
+
+func (x *inodeOperations) beforeSave() {}
+func (x *inodeOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("fileState", &x.fileState)
+ m.Save("cachingInodeOps", &x.cachingInodeOps)
+}
+
+func (x *inodeOperations) afterLoad() {}
+func (x *inodeOperations) load(m state.Map) {
+ m.LoadWait("fileState", &x.fileState)
+ m.Load("cachingInodeOps", &x.cachingInodeOps)
+}
+
+func (x *inodeFileState) save(m state.Map) {
+ x.beforeSave()
+ var loading struct{} = x.saveLoading()
+ m.SaveValue("loading", loading)
+ m.Save("s", &x.s)
+ m.Save("sattr", &x.sattr)
+ m.Save("savedUAttr", &x.savedUAttr)
+ m.Save("hostMappable", &x.hostMappable)
+}
+
+func (x *inodeFileState) load(m state.Map) {
+ m.LoadWait("s", &x.s)
+ m.LoadWait("sattr", &x.sattr)
+ m.Load("savedUAttr", &x.savedUAttr)
+ m.Load("hostMappable", &x.hostMappable)
+ m.LoadValue("loading", new(struct{}), func(y interface{}) { x.loadLoading(y.(struct{})) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *endpointMaps) beforeSave() {}
+func (x *endpointMaps) save(m state.Map) {
+ x.beforeSave()
+ m.Save("direntMap", &x.direntMap)
+ m.Save("pathMap", &x.pathMap)
+}
+
+func (x *endpointMaps) afterLoad() {}
+func (x *endpointMaps) load(m state.Map) {
+ m.Load("direntMap", &x.direntMap)
+ m.Load("pathMap", &x.pathMap)
+}
+
+func (x *session) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("msize", &x.msize)
+ m.Save("version", &x.version)
+ m.Save("cachePolicy", &x.cachePolicy)
+ m.Save("aname", &x.aname)
+ m.Save("superBlockFlags", &x.superBlockFlags)
+ m.Save("connID", &x.connID)
+ m.Save("inodeMappings", &x.inodeMappings)
+ m.Save("mounter", &x.mounter)
+ m.Save("endpoints", &x.endpoints)
+}
+
+func (x *session) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.LoadWait("msize", &x.msize)
+ m.LoadWait("version", &x.version)
+ m.LoadWait("cachePolicy", &x.cachePolicy)
+ m.LoadWait("aname", &x.aname)
+ m.LoadWait("superBlockFlags", &x.superBlockFlags)
+ m.LoadWait("connID", &x.connID)
+ m.LoadWait("inodeMappings", &x.inodeMappings)
+ m.LoadWait("mounter", &x.mounter)
+ m.LoadWait("endpoints", &x.endpoints)
+ m.AfterLoad(x.afterLoad)
+}
+
+func init() {
+ state.Register("gofer.fileOperations", (*fileOperations)(nil), state.Fns{Save: (*fileOperations).save, Load: (*fileOperations).load})
+ state.Register("gofer.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load})
+ state.Register("gofer.inodeOperations", (*inodeOperations)(nil), state.Fns{Save: (*inodeOperations).save, Load: (*inodeOperations).load})
+ state.Register("gofer.inodeFileState", (*inodeFileState)(nil), state.Fns{Save: (*inodeFileState).save, Load: (*inodeFileState).load})
+ state.Register("gofer.endpointMaps", (*endpointMaps)(nil), state.Fns{Save: (*endpointMaps).save, Load: (*endpointMaps).load})
+ state.Register("gofer.session", (*session)(nil), state.Fns{Save: (*session).save, Load: (*session).load})
+}
diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go
new file mode 100644
index 000000000..c7098cd36
--- /dev/null
+++ b/pkg/sentry/fs/gofer/handles.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/secio"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+)
+
+// handles are the open handles of a gofer file. They are reference counted to
+// support open handle sharing between files for read only filesystems.
+//
+// If Host != nil then it will be used exclusively over File.
+type handles struct {
+ refs.AtomicRefCount
+
+ // File is a p9.File handle. Must not be nil.
+ File contextFile
+
+ // Host is an *fd.FD handle. May be nil.
+ Host *fd.FD
+}
+
+// DecRef drops a reference on handles.
+func (h *handles) DecRef() {
+ h.DecRefWithDestructor(func() {
+ if h.Host != nil {
+ if err := h.Host.Close(); err != nil {
+ log.Warningf("error closing host file: %v", err)
+ }
+ }
+ // FIXME(b/38173783): Context is not plumbed here.
+ if err := h.File.close(context.Background()); err != nil {
+ log.Warningf("error closing p9 file: %v", err)
+ }
+ })
+}
+
+func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*handles, error) {
+ _, newFile, err := file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ var p9flags p9.OpenFlags
+ switch {
+ case flags.Read && flags.Write:
+ p9flags = p9.ReadWrite
+ case flags.Read && !flags.Write:
+ p9flags = p9.ReadOnly
+ case !flags.Read && flags.Write:
+ p9flags = p9.WriteOnly
+ default:
+ panic("impossible fs.FileFlags")
+ }
+
+ hostFile, _, _, err := newFile.open(ctx, p9flags)
+ if err != nil {
+ newFile.close(ctx)
+ return nil, err
+ }
+ h := &handles{
+ File: newFile,
+ Host: hostFile,
+ }
+ return h, nil
+}
+
+type handleReadWriter struct {
+ ctx context.Context
+ h *handles
+ off int64
+}
+
+func (h *handles) readWriterAt(ctx context.Context, offset int64) *handleReadWriter {
+ return &handleReadWriter{ctx, h, offset}
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *handleReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ var r io.Reader
+ if rw.h.Host != nil {
+ r = secio.NewOffsetReader(rw.h.Host, rw.off)
+ } else {
+ r = &p9.ReadWriterFile{File: rw.h.File.file, Offset: uint64(rw.off)}
+ }
+
+ rw.ctx.UninterruptibleSleepStart(false)
+ defer rw.ctx.UninterruptibleSleepFinish(false)
+ n, err := safemem.FromIOReader{r}.ReadToBlocks(dsts)
+ rw.off += int64(n)
+ return n, err
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *handleReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ var w io.Writer
+ if rw.h.Host != nil {
+ w = secio.NewOffsetWriter(rw.h.Host, rw.off)
+ } else {
+ w = &p9.ReadWriterFile{File: rw.h.File.file, Offset: uint64(rw.off)}
+ }
+
+ rw.ctx.UninterruptibleSleepStart(false)
+ defer rw.ctx.UninterruptibleSleepFinish(false)
+ n, err := safemem.FromIOWriter{w}.WriteFromBlocks(srcs)
+ rw.off += int64(n)
+ return n, err
+}
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
new file mode 100644
index 000000000..dcb3b2880
--- /dev/null
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -0,0 +1,606 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "errors"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fdpipe"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/host"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// inodeOperations implements fs.InodeOperations.
+//
+// +stateify savable
+type inodeOperations struct {
+ fsutil.InodeNotVirtual `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+
+ // fileState implements fs.CachedFileObject. It exists
+ // to break a circular load dependency between inodeOperations
+ // and cachingInodeOps (below).
+ fileState *inodeFileState `state:"wait"`
+
+ // cachingInodeOps implement memmap.Mappable for inodeOperations.
+ cachingInodeOps *fsutil.CachingInodeOperations
+
+ // readdirMu protects readdirCache and concurrent Readdirs.
+ readdirMu sync.Mutex `state:"nosave"`
+
+ // readdirCache is a cache of readdir results in the form of
+ // a fs.SortedDentryMap.
+ //
+ // Starts out as nil, and is initialized under readdirMu lazily;
+ // invalidating the cache means setting it to nil.
+ readdirCache *fs.SortedDentryMap `state:"nosave"`
+}
+
+// inodeFileState implements fs.CachedFileObject and otherwise fully
+// encapsulates state that needs to be manually loaded on restore for
+// this file object.
+//
+// This unfortunate structure exists because fs.CachingInodeOperations
+// defines afterLoad and therefore cannot be lazily loaded (to break a
+// circular load dependency between it and inodeOperations). Even with
+// lazy loading, this approach defines the dependencies between objects
+// and the expected load behavior more concretely.
+//
+// +stateify savable
+type inodeFileState struct {
+ // s is common file system state for Gofers.
+ s *session `state:"wait"`
+
+ // MultiDeviceKey consists of:
+ //
+ // * Device: file system device from a specific gofer.
+ // * SecondaryDevice: unique identifier of the attach point.
+ // * Inode: the inode of this resource, unique per Device.=
+ //
+ // These fields combined enable consistent hashing of virtual inodes
+ // on goferDevice.
+ key device.MultiDeviceKey `state:"nosave"`
+
+ // file is the p9 file that contains a single unopened fid.
+ file contextFile `state:"nosave"`
+
+ // sattr caches the stable attributes.
+ sattr fs.StableAttr `state:"wait"`
+
+ // handlesMu protects the below fields.
+ handlesMu sync.RWMutex `state:"nosave"`
+
+ // If readHandles is non-nil, it holds handles that are either read-only or
+ // read/write. If writeHandles is non-nil, it holds write-only handles if
+ // writeHandlesRW is false, and read/write handles if writeHandlesRW is
+ // true.
+ //
+ // Once readHandles becomes non-nil, it can't be changed until
+ // inodeFileState.Release(), because of a defect in the
+ // fsutil.CachedFileObject interface: there's no way for the caller of
+ // fsutil.CachedFileObject.FD() to keep the returned FD open, so if we
+ // racily replace readHandles after inodeFileState.FD() has returned
+ // readHandles.Host.FD(), fsutil.CachingInodeOperations may use a closed
+ // FD. writeHandles can be changed if writeHandlesRW is false, since
+ // inodeFileState.FD() can't return a write-only FD, but can't be changed
+ // if writeHandlesRW is true for the same reason.
+ readHandles *handles `state:"nosave"`
+ writeHandles *handles `state:"nosave"`
+ writeHandlesRW bool `state:"nosave"`
+
+ // loading is acquired when the inodeFileState begins an asynchronous
+ // load. It releases when the load is complete. Callers that require all
+ // state to be available should call waitForLoad() to ensure that.
+ loading sync.Mutex `state:".(struct{})"`
+
+ // savedUAttr is only allocated during S/R. It points to the save-time
+ // unstable attributes and is used to validate restore-time ones.
+ //
+ // Note that these unstable attributes are only used to detect cross-S/R
+ // external file system metadata changes. They may differ from the
+ // cached unstable attributes in cachingInodeOps, as that might differ
+ // from the external file system attributes if there had been WriteOut
+ // failures. S/R is transparent to Sentry and the latter will continue
+ // using its cached values after restore.
+ savedUAttr *fs.UnstableAttr
+
+ // hostMappable is created when using 'cacheRemoteRevalidating' to map pages
+ // directly from host.
+ hostMappable *fsutil.HostMappable
+}
+
+// Release releases file handles.
+func (i *inodeFileState) Release(ctx context.Context) {
+ i.file.close(ctx)
+ if i.readHandles != nil {
+ i.readHandles.DecRef()
+ }
+ if i.writeHandles != nil {
+ i.writeHandles.DecRef()
+ }
+}
+
+func (i *inodeFileState) canShareHandles() bool {
+ // Only share handles for regular files, since for other file types,
+ // distinct handles may have special semantics even if they represent the
+ // same file. Disable handle sharing for cache policy cacheNone, since this
+ // is legacy behavior.
+ return fs.IsFile(i.sattr) && i.s.cachePolicy != cacheNone
+}
+
+// Preconditions: i.handlesMu must be locked for writing.
+func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles) {
+ if flags.Read && i.readHandles == nil {
+ h.IncRef()
+ i.readHandles = h
+ }
+ if flags.Write {
+ if i.writeHandles == nil {
+ h.IncRef()
+ i.writeHandles = h
+ i.writeHandlesRW = flags.Read
+ } else if !i.writeHandlesRW && flags.Read {
+ // Upgrade i.writeHandles.
+ i.writeHandles.DecRef()
+ h.IncRef()
+ i.writeHandles = h
+ i.writeHandlesRW = flags.Read
+ }
+ }
+}
+
+// getHandles returns a set of handles for a new file using i opened with the
+// given flags.
+func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags) (*handles, error) {
+ if !i.canShareHandles() {
+ return newHandles(ctx, i.file, flags)
+ }
+ i.handlesMu.Lock()
+ defer i.handlesMu.Unlock()
+ // Do we already have usable shared handles?
+ if flags.Write {
+ if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) {
+ i.writeHandles.IncRef()
+ return i.writeHandles, nil
+ }
+ } else if i.readHandles != nil {
+ i.readHandles.IncRef()
+ return i.readHandles, nil
+ }
+ // No; get new handles and cache them for future sharing.
+ h, err := newHandles(ctx, i.file, flags)
+ if err != nil {
+ return nil, err
+ }
+ i.setSharedHandlesLocked(flags, h)
+ return h, nil
+}
+
+// ReadToBlocksAt implements fsutil.CachedFileObject.ReadToBlocksAt.
+func (i *inodeFileState) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+ return i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts)
+}
+
+// WriteFromBlocksAt implements fsutil.CachedFileObject.WriteFromBlocksAt.
+func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+ return i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs)
+}
+
+// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
+ if i.skipSetAttr(mask) {
+ return nil
+ }
+ as, ans := attr.AccessTime.Unix()
+ ms, mns := attr.ModificationTime.Unix()
+ // An update of status change time is implied by mask.AccessTime
+ // or mask.ModificationTime. Updating status change time to a
+ // time earlier than the system time is not possible.
+ return i.file.setAttr(
+ ctx,
+ p9.SetAttrMask{
+ Permissions: mask.Perms,
+ Size: mask.Size,
+ UID: mask.UID,
+ GID: mask.GID,
+ ATime: mask.AccessTime,
+ ATimeNotSystemTime: true,
+ MTime: mask.ModificationTime,
+ MTimeNotSystemTime: true,
+ }, p9.SetAttr{
+ Permissions: p9.FileMode(attr.Perms.LinuxMode()),
+ UID: p9.UID(attr.Owner.UID),
+ GID: p9.GID(attr.Owner.GID),
+ Size: uint64(attr.Size),
+ ATimeSeconds: uint64(as),
+ ATimeNanoSeconds: uint64(ans),
+ MTimeSeconds: uint64(ms),
+ MTimeNanoSeconds: uint64(mns),
+ })
+}
+
+// skipSetAttr checks if attribute change can be skipped. It can be skipped
+// when:
+// - Mask is empty
+// - Mask contains only attributes that cannot be set in the gofer
+// - Mask contains only atime and/or mtime, and host FD exists
+//
+// Updates to atime and mtime can be skipped because cached value will be
+// "close enough" to host value, given that operation went directly to host FD.
+// Skipping atime updates is particularly important to reduce the number of
+// operations sent to the Gofer for readonly files.
+func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
+ // First remove attributes that cannot be updated.
+ cpy := mask
+ cpy.Type = false
+ cpy.DeviceID = false
+ cpy.InodeID = false
+ cpy.BlockSize = false
+ cpy.Usage = false
+ cpy.Links = false
+ if cpy.Empty() {
+ return true
+ }
+
+ // Then check if more than just atime and mtime is being set.
+ cpy.AccessTime = false
+ cpy.ModificationTime = false
+ if !cpy.Empty() {
+ return false
+ }
+
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+ return (i.readHandles != nil && i.readHandles.Host != nil) ||
+ (i.writeHandles != nil && i.writeHandles.Host != nil)
+}
+
+// Sync implements fsutil.CachedFileObject.Sync.
+func (i *inodeFileState) Sync(ctx context.Context) error {
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+ if i.writeHandles == nil {
+ return nil
+ }
+ return i.writeHandles.File.fsync(ctx)
+}
+
+// FD implements fsutil.CachedFileObject.FD.
+func (i *inodeFileState) FD() int {
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+ if i.writeHandlesRW && i.writeHandles != nil && i.writeHandles.Host != nil {
+ return int(i.writeHandles.Host.FD())
+ }
+ if i.readHandles != nil && i.readHandles.Host != nil {
+ return int(i.readHandles.Host.FD())
+ }
+ return -1
+}
+
+// waitForLoad makes sure any restore-issued loading is done.
+func (i *inodeFileState) waitForLoad() {
+ // This is not a no-op. The loading mutex is hold upon restore until
+ // all loading actions are done.
+ i.loading.Lock()
+ i.loading.Unlock()
+}
+
+func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, error) {
+ _, valid, pattr, err := getattr(ctx, i.file)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ return unstable(ctx, valid, pattr, i.s.mounter, i.s.client), nil
+}
+
+func (i *inodeFileState) Allocate(ctx context.Context, offset, length int64) error {
+ i.handlesMu.RLock()
+ defer i.handlesMu.RUnlock()
+
+ // No options are supported for now.
+ mode := p9.AllocateMode{}
+ return i.writeHandles.File.allocate(ctx, mode, uint64(offset), uint64(length))
+}
+
+// session extracts the gofer's session from the MountSource.
+func (i *inodeOperations) session() *session {
+ return i.fileState.s
+}
+
+// Release implements fs.InodeOperations.Release.
+func (i *inodeOperations) Release(ctx context.Context) {
+ i.cachingInodeOps.Release()
+
+ // Releasing the fileState may make RPCs to the gofer. There is
+ // no need to wait for those to return, so we can do this
+ // asynchronously.
+ //
+ // We use AsyncWithContext to avoid needing to allocate an extra
+ // anonymous function on the heap.
+ fs.AsyncWithContext(ctx, i.fileState.Release)
+}
+
+// Mappable implements fs.InodeOperations.Mappable.
+func (i *inodeOperations) Mappable(inode *fs.Inode) memmap.Mappable {
+ if i.session().cachePolicy.useCachingInodeOps(inode) {
+ return i.cachingInodeOps
+ }
+ // This check is necessary because it's returning an interface type.
+ if i.fileState.hostMappable != nil {
+ return i.fileState.hostMappable
+ }
+ return nil
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (i *inodeOperations) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ return i.cachingInodeOps.UnstableAttr(ctx, inode)
+ }
+ return i.fileState.unstableAttr(ctx)
+}
+
+// Check implements fs.InodeOperations.Check.
+func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ switch d.Inode.StableAttr.Type {
+ case fs.Socket:
+ return i.getFileSocket(ctx, d, flags)
+ case fs.Pipe:
+ return i.getFilePipe(ctx, d, flags)
+ default:
+ return i.getFileDefault(ctx, d, flags)
+ }
+}
+
+func (i *inodeOperations) getFileSocket(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ f, err := i.fileState.file.connect(ctx, p9.AnonymousSocket)
+ if err != nil {
+ return nil, syscall.EIO
+ }
+ fsf, err := host.NewSocketWithDirent(ctx, d, f, flags)
+ if err != nil {
+ f.Close()
+ return nil, err
+ }
+ return fsf, nil
+}
+
+func (i *inodeOperations) getFilePipe(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ // Try to open as a host pipe; if that doesn't work, handle it normally.
+ pipeOps, err := fdpipe.Open(ctx, i, flags)
+ if err == errNotHostFile {
+ return i.getFileDefault(ctx, d, flags)
+ }
+ if err != nil {
+ return nil, err
+ }
+ return fs.NewFile(ctx, d, flags, pipeOps), nil
+}
+
+// errNotHostFile indicates that the file is not a host file.
+var errNotHostFile = errors.New("not a host file")
+
+// NonBlockingOpen implements fdpipe.NonBlockingOpener for opening host named pipes.
+func (i *inodeOperations) NonBlockingOpen(ctx context.Context, p fs.PermMask) (*fd.FD, error) {
+ i.fileState.waitForLoad()
+
+ // Get a cloned fid which we will open.
+ _, newFile, err := i.fileState.file.walk(ctx, nil)
+ if err != nil {
+ log.Warningf("Open Walk failed: %v", err)
+ return nil, err
+ }
+ defer newFile.close(ctx)
+
+ flags, err := openFlagsFromPerms(p)
+ if err != nil {
+ log.Warningf("Open flags %s parsing failed: %v", p, err)
+ return nil, err
+ }
+ hostFile, _, _, err := newFile.open(ctx, flags)
+ // If the host file returned is nil and the error is nil,
+ // then this was never a host file to begin with, and should
+ // be treated like a remote file.
+ if hostFile == nil && err == nil {
+ return nil, errNotHostFile
+ }
+ return hostFile, err
+}
+
+func (i *inodeOperations) getFileDefault(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ h, err := i.fileState.getHandles(ctx, flags)
+ if err != nil {
+ return nil, err
+ }
+ return NewFile(ctx, d, d.BaseName(), flags, i, h), nil
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (i *inodeOperations) SetPermissions(ctx context.Context, inode *fs.Inode, p fs.FilePermissions) bool {
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ return i.cachingInodeOps.SetPermissions(ctx, inode, p)
+ }
+
+ mask := p9.SetAttrMask{Permissions: true}
+ pattr := p9.SetAttr{Permissions: p9.FileMode(p.LinuxMode())}
+ // Execute the chmod.
+ return i.fileState.file.setAttr(ctx, mask, pattr) == nil
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (i *inodeOperations) SetOwner(ctx context.Context, inode *fs.Inode, owner fs.FileOwner) error {
+ // Save the roundtrip.
+ if !owner.UID.Ok() && !owner.GID.Ok() {
+ return nil
+ }
+
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ return i.cachingInodeOps.SetOwner(ctx, inode, owner)
+ }
+
+ var mask p9.SetAttrMask
+ var attr p9.SetAttr
+ if owner.UID.Ok() {
+ mask.UID = true
+ attr.UID = p9.UID(owner.UID)
+ }
+ if owner.GID.Ok() {
+ mask.GID = true
+ attr.GID = p9.GID(owner.GID)
+ }
+ return i.fileState.file.setAttr(ctx, mask, attr)
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (i *inodeOperations) SetTimestamps(ctx context.Context, inode *fs.Inode, ts fs.TimeSpec) error {
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ return i.cachingInodeOps.SetTimestamps(ctx, inode, ts)
+ }
+
+ return utimes(ctx, i.fileState.file, ts)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length int64) error {
+ // This can only be called for files anyway.
+ if i.session().cachePolicy.useCachingInodeOps(inode) {
+ return i.cachingInodeOps.Truncate(ctx, inode, length)
+ }
+ if i.session().cachePolicy == cacheRemoteRevalidating {
+ return i.fileState.hostMappable.Truncate(ctx, length)
+ }
+
+ return i.fileState.file.setAttr(ctx, p9.SetAttrMask{Size: true}, p9.SetAttr{Size: uint64(length)})
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset, length int64) error {
+ // This can only be called for files anyway.
+ if i.session().cachePolicy.useCachingInodeOps(inode) {
+ return i.cachingInodeOps.Allocate(ctx, offset, length)
+ }
+ if i.session().cachePolicy == cacheRemoteRevalidating {
+ return i.fileState.hostMappable.Allocate(ctx, offset, length)
+ }
+
+ // No options are supported for now.
+ mode := p9.AllocateMode{}
+ return i.fileState.file.allocate(ctx, mode, uint64(offset), uint64(length))
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ if !i.session().cachePolicy.cacheUAttrs(inode) {
+ return nil
+ }
+
+ return i.cachingInodeOps.WriteOut(ctx, inode)
+}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (i *inodeOperations) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if !fs.IsSymlink(inode.StableAttr) {
+ return "", syscall.ENOLINK
+ }
+ return i.fileState.file.readlink(ctx)
+}
+
+// Getlink implementfs fs.InodeOperations.Getlink.
+func (i *inodeOperations) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ if !fs.IsSymlink(i.fileState.sattr) {
+ return nil, syserror.ENOLINK
+ }
+ return nil, fs.ErrResolveViaReadlink
+}
+
+// StatFS makes a StatFS request.
+func (i *inodeOperations) StatFS(ctx context.Context) (fs.Info, error) {
+ fsstat, err := i.fileState.file.statFS(ctx)
+ if err != nil {
+ return fs.Info{}, err
+ }
+
+ info := fs.Info{
+ // This is primarily for distinguishing a gofer file system in
+ // tests. Testing is important, so instead of defining
+ // something completely random, use a standard value.
+ Type: linux.V9FS_MAGIC,
+ TotalBlocks: fsstat.Blocks,
+ FreeBlocks: fsstat.BlocksFree,
+ TotalFiles: fsstat.Files,
+ FreeFiles: fsstat.FilesFree,
+ }
+
+ // If blocks available is non-zero, prefer that.
+ if fsstat.BlocksAvailable != 0 {
+ info.FreeBlocks = fsstat.BlocksAvailable
+ }
+
+ return info, nil
+}
+
+func (i *inodeOperations) configureMMap(file *fs.File, opts *memmap.MMapOpts) error {
+ if i.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
+ return fsutil.GenericConfigureMMap(file, i.cachingInodeOps, opts)
+ }
+ if i.fileState.hostMappable != nil {
+ return fsutil.GenericConfigureMMap(file, i.fileState.hostMappable, opts)
+ }
+ return syserror.ENODEV
+}
+
+func init() {
+ syserror.AddErrorUnwrapper(func(err error) (syscall.Errno, bool) {
+ if _, ok := err.(p9.ErrSocket); ok {
+ // Treat as an I/O error.
+ return syscall.EIO, true
+ }
+ return 0, false
+ })
+}
+
+// AddLink implements InodeOperations.AddLink, but is currently a noop.
+// FIXME(b/63117438): Remove this from InodeOperations altogether.
+func (*inodeOperations) AddLink() {}
+
+// DropLink implements InodeOperations.DropLink, but is currently a noop.
+// FIXME(b/63117438): Remove this from InodeOperations altogether.
+func (*inodeOperations) DropLink() {}
+
+// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
+// FIXME(b/63117438): Remove this from InodeOperations altogether.
+func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go
new file mode 100644
index 000000000..ac22ee4b1
--- /dev/null
+++ b/pkg/sentry/fs/gofer/inode_state.go
@@ -0,0 +1,172 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "errors"
+ "fmt"
+ "path/filepath"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+)
+
+// Some fs implementations may not support atime, ctime, or mtime in getattr.
+// The unstable() logic would try to use clock time for them. However, we do not
+// want to use such time during S/R as that would cause restore timestamp
+// checking failure. Hence a dummy stable-time clock is needed.
+//
+// Note that application-visible UnstableAttrs either come from CachingInodeOps
+// (in which case they are saved), or they are requested from the gofer on each
+// stat (for non-caching), so the dummy time only affects the modification
+// timestamp check.
+type dummyClock struct {
+ time.Clock
+}
+
+// Now returns a stable dummy time.
+func (d *dummyClock) Now() time.Time {
+ return time.Time{}
+}
+
+type dummyClockContext struct {
+ context.Context
+}
+
+// Value implements context.Context
+func (d *dummyClockContext) Value(key interface{}) interface{} {
+ switch key {
+ case time.CtxRealtimeClock:
+ return &dummyClock{}
+ default:
+ return d.Context.Value(key)
+ }
+}
+
+// beforeSave is invoked by stateify.
+func (i *inodeFileState) beforeSave() {
+ if _, ok := i.s.inodeMappings[i.sattr.InodeID]; !ok {
+ panic(fmt.Sprintf("failed to find path for inode number %d. Device %s contains %s", i.sattr.InodeID, i.s.connID, fs.InodeMappings(i.s.inodeMappings)))
+ }
+ if i.sattr.Type == fs.RegularFile {
+ uattr, err := i.unstableAttr(&dummyClockContext{context.Background()})
+ if err != nil {
+ panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable atttribute of %s: %v", i.s.inodeMappings[i.sattr.InodeID], err)})
+ }
+ i.savedUAttr = &uattr
+ }
+}
+
+// saveLoading is invoked by stateify.
+func (i *inodeFileState) saveLoading() struct{} {
+ return struct{}{}
+}
+
+// splitAbsolutePath splits the path on slashes ignoring the leading slash.
+func splitAbsolutePath(path string) []string {
+ if len(path) == 0 {
+ panic("There is no path!")
+ }
+ if path != filepath.Clean(path) {
+ panic(fmt.Sprintf("path %q is not clean", path))
+ }
+ // This case is to return {} rather than {""}
+ if path == "/" {
+ return []string{}
+ }
+ if path[0] != '/' {
+ panic(fmt.Sprintf("path %q is not absolute", path))
+ }
+
+ s := strings.Split(path, "/")
+
+ // Since p is absolute, the first component of s
+ // is an empty string. We must remove that.
+ return s[1:]
+}
+
+// loadLoading is invoked by stateify.
+func (i *inodeFileState) loadLoading(_ struct{}) {
+ i.loading.Lock()
+}
+
+// afterLoad is invoked by stateify.
+func (i *inodeFileState) afterLoad() {
+ load := func() (err error) {
+ // See comment on i.loading().
+ defer func() {
+ if err == nil {
+ i.loading.Unlock()
+ }
+ }()
+
+ // Manually restore the p9.File.
+ name, ok := i.s.inodeMappings[i.sattr.InodeID]
+ if !ok {
+ // This should be impossible, see assertion in
+ // beforeSave.
+ return fmt.Errorf("failed to find path for inode number %d. Device %s contains %s", i.sattr.InodeID, i.s.connID, fs.InodeMappings(i.s.inodeMappings))
+ }
+ // TODO(b/38173783): Context is not plumbed to save/restore.
+ ctx := &dummyClockContext{context.Background()}
+
+ _, i.file, err = i.s.attach.walk(ctx, splitAbsolutePath(name))
+ if err != nil {
+ return fs.ErrCorruption{fmt.Errorf("failed to walk to %q: %v", name, err)}
+ }
+
+ // Remap the saved inode number into the gofer device using the
+ // actual device and actual inode that exists in our new
+ // environment.
+ qid, mask, attrs, err := i.file.getAttr(ctx, p9.AttrMaskAll())
+ if err != nil {
+ return fs.ErrCorruption{fmt.Errorf("failed to get file attributes of %s: %v", name, err)}
+ }
+ if !mask.RDev {
+ return fs.ErrCorruption{fmt.Errorf("file %s lacks device", name)}
+ }
+ i.key = device.MultiDeviceKey{
+ Device: attrs.RDev,
+ SecondaryDevice: i.s.connID,
+ Inode: qid.Path,
+ }
+ if !goferDevice.Load(i.key, i.sattr.InodeID) {
+ return fs.ErrCorruption{fmt.Errorf("gofer device %s -> %d conflict in gofer device mappings: %s", i.key, i.sattr.InodeID, goferDevice)}
+ }
+
+ if i.sattr.Type == fs.RegularFile {
+ env, ok := fs.CurrentRestoreEnvironment()
+ if !ok {
+ return errors.New("missing restore environment")
+ }
+ uattr := unstable(ctx, mask, attrs, i.s.mounter, i.s.client)
+ if env.ValidateFileSize && uattr.Size != i.savedUAttr.Size {
+ return fs.ErrCorruption{fmt.Errorf("file size has changed for %s: previously %d, now %d", i.s.inodeMappings[i.sattr.InodeID], i.savedUAttr.Size, uattr.Size)}
+ }
+ if env.ValidateFileTimestamp && uattr.ModificationTime != i.savedUAttr.ModificationTime {
+ return fs.ErrCorruption{fmt.Errorf("file modification time has changed for %s: previously %v, now %v", i.s.inodeMappings[i.sattr.InodeID], i.savedUAttr.ModificationTime, uattr.ModificationTime)}
+ }
+ i.savedUAttr = nil
+ }
+
+ return nil
+ }
+
+ fs.Async(fs.CatchError(load))
+}
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
new file mode 100644
index 000000000..092f8b586
--- /dev/null
+++ b/pkg/sentry/fs/gofer/path.go
@@ -0,0 +1,433 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
+// encoding of strings, which uses 2 bytes for the length prefix.
+const maxFilenameLen = (1 << 16) - 1
+
+// Lookup loads an Inode at name into a Dirent based on the session's cache
+// policy.
+func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
+ if len(name) > maxFilenameLen {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+ cp := i.session().cachePolicy
+ if cp.cacheReaddir() {
+ // Check to see if we have readdirCache that indicates the
+ // child does not exist. Avoid holding readdirMu longer than
+ // we need to.
+ i.readdirMu.Lock()
+ if i.readdirCache != nil && !i.readdirCache.Contains(name) {
+ // No such child.
+ i.readdirMu.Unlock()
+ if cp.cacheNegativeDirents() {
+ return fs.NewNegativeDirent(name), nil
+ }
+ return nil, syserror.ENOENT
+ }
+ i.readdirMu.Unlock()
+ }
+
+ // Get a p9.File for name.
+ qids, newFile, mask, p9attr, err := i.fileState.file.walkGetAttr(ctx, []string{name})
+ if err != nil {
+ if err == syscall.ENOENT {
+ if cp.cacheNegativeDirents() {
+ // Return a negative Dirent. It will stay cached until something
+ // is created over it.
+ return fs.NewNegativeDirent(name), nil
+ }
+ return nil, syserror.ENOENT
+ }
+ return nil, err
+ }
+
+ // Construct the Inode operations.
+ sattr, node := newInodeOperations(ctx, i.fileState.s, newFile, qids[0], mask, p9attr, false)
+
+ // Construct a positive Dirent.
+ return fs.NewDirent(fs.NewInode(node, dir.MountSource, sattr), name), nil
+}
+
+// Creates a new Inode at name and returns its File based on the session's cache policy.
+//
+// Ownership is currently ignored.
+func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) {
+ if len(name) > maxFilenameLen {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+ // Create replaces the directory fid with the newly created/opened
+ // file, so clone this directory so it doesn't change out from under
+ // this node.
+ _, newFile, err := i.fileState.file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // Map the FileFlags to p9 OpenFlags.
+ var openFlags p9.OpenFlags
+ switch {
+ case flags.Read && flags.Write:
+ openFlags = p9.ReadWrite
+ case flags.Read:
+ openFlags = p9.ReadOnly
+ case flags.Write:
+ openFlags = p9.WriteOnly
+ default:
+ panic(fmt.Sprintf("Create called with unknown or unset open flags: %v", flags))
+ }
+
+ owner := fs.FileOwnerFromContext(ctx)
+ hostFile, err := newFile.create(ctx, name, openFlags, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
+ if err != nil {
+ // Could not create the file.
+ newFile.close(ctx)
+ return nil, err
+ }
+
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+
+ // Get an unopened p9.File for the file we created so that it can be cloned
+ // and re-opened multiple times after creation, while also getting its
+ // attributes. Both are required for inodeOperations.
+ qids, unopened, mask, p9attr, err := i.fileState.file.walkGetAttr(ctx, []string{name})
+ if err != nil {
+ newFile.close(ctx)
+ if hostFile != nil {
+ hostFile.Close()
+ }
+ return nil, err
+ }
+ if len(qids) != 1 {
+ log.Warningf("WalkGetAttr(%s) succeeded, but returned %d QIDs (%v), wanted 1", name, len(qids), qids)
+ newFile.close(ctx)
+ if hostFile != nil {
+ hostFile.Close()
+ }
+ unopened.close(ctx)
+ return nil, syserror.EIO
+ }
+ qid := qids[0]
+
+ // Construct the InodeOperations.
+ sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, p9attr, false)
+
+ // Construct the positive Dirent.
+ d := fs.NewDirent(fs.NewInode(iops, dir.MountSource, sattr), name)
+ defer d.DecRef()
+
+ // Construct the new file, caching the handles if allowed.
+ h := &handles{
+ File: newFile,
+ Host: hostFile,
+ }
+ if iops.fileState.canShareHandles() {
+ iops.fileState.handlesMu.Lock()
+ iops.fileState.setSharedHandlesLocked(flags, h)
+ iops.fileState.handlesMu.Unlock()
+ }
+ return NewFile(ctx, d, name, flags, iops, h), nil
+}
+
+// CreateLink uses Create to create a symlink between oldname and newname.
+func (i *inodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname string, newname string) error {
+ if len(newname) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ owner := fs.FileOwnerFromContext(ctx)
+ if _, err := i.fileState.file.symlink(ctx, oldname, newname, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
+ return err
+ }
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+ return nil
+}
+
+// CreateHardLink implements InodeOperations.CreateHardLink.
+func (i *inodeOperations) CreateHardLink(ctx context.Context, inode *fs.Inode, target *fs.Inode, newName string) error {
+ if len(newName) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ targetOpts, ok := target.InodeOperations.(*inodeOperations)
+ if !ok {
+ return syscall.EXDEV
+ }
+
+ if err := i.fileState.file.link(ctx, &targetOpts.fileState.file, newName); err != nil {
+ return err
+ }
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ // Increase link count.
+ targetOpts.cachingInodeOps.IncLinks(ctx)
+ }
+ i.touchModificationAndStatusChangeTime(ctx, inode)
+ return nil
+}
+
+// CreateDirectory uses Create to create a directory named s under inodeOperations.
+func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, s string, perm fs.FilePermissions) error {
+ if len(s) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ owner := fs.FileOwnerFromContext(ctx)
+ if _, err := i.fileState.file.mkdir(ctx, s, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
+ return err
+ }
+ if i.session().cachePolicy.cacheUAttrs(dir) {
+ // Increase link count.
+ //
+ // N.B. This will update the modification time.
+ i.cachingInodeOps.IncLinks(ctx)
+ }
+ if i.session().cachePolicy.cacheReaddir() {
+ // Invalidate readdir cache.
+ i.markDirectoryDirty()
+ }
+ return nil
+}
+
+// Bind implements InodeOperations.Bind.
+func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, ep transport.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) {
+ if len(name) > maxFilenameLen {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+ if i.session().endpoints == nil {
+ return nil, syscall.EOPNOTSUPP
+ }
+
+ // Create replaces the directory fid with the newly created/opened
+ // file, so clone this directory so it doesn't change out from under
+ // this node.
+ _, newFile, err := i.fileState.file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // Stabilize the endpoint map while creation is in progress.
+ unlock := i.session().endpoints.lock()
+ defer unlock()
+
+ // Create a regular file in the gofer and then mark it as a socket by
+ // adding this inode key in the 'endpoints' map.
+ owner := fs.FileOwnerFromContext(ctx)
+ hostFile, err := newFile.create(ctx, name, p9.ReadWrite, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
+ if err != nil {
+ return nil, err
+ }
+ // We're not going to use this file.
+ hostFile.Close()
+
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+
+ // Get the attributes of the file to create inode key.
+ qid, mask, attr, err := getattr(ctx, newFile)
+ if err != nil {
+ newFile.close(ctx)
+ return nil, err
+ }
+
+ key := device.MultiDeviceKey{
+ Device: attr.RDev,
+ SecondaryDevice: i.session().connID,
+ Inode: qid.Path,
+ }
+
+ // Create child dirent.
+
+ // Get an unopened p9.File for the file we created so that it can be
+ // cloned and re-opened multiple times after creation.
+ _, unopened, err := i.fileState.file.walk(ctx, []string{name})
+ if err != nil {
+ newFile.close(ctx)
+ return nil, err
+ }
+
+ // Construct the InodeOperations.
+ sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, attr, true)
+
+ // Construct the positive Dirent.
+ childDir := fs.NewDirent(fs.NewInode(iops, dir.MountSource, sattr), name)
+ i.session().endpoints.add(key, childDir, ep)
+ return childDir, nil
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ owner := fs.FileOwnerFromContext(ctx)
+ mode := p9.FileMode(perm.LinuxMode()) | p9.ModeNamedPipe
+
+ // N.B. FIFOs use major/minor numbers 0.
+ if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
+ return err
+ }
+
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+ return nil
+}
+
+// Remove implements InodeOperations.Remove.
+func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error {
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ var key device.MultiDeviceKey
+ removeSocket := false
+ if i.session().endpoints != nil {
+ // Find out if file being deleted is a socket that needs to be
+ // removed from endpoint map.
+ if d, err := i.Lookup(ctx, dir, name); err == nil {
+ defer d.DecRef()
+ if fs.IsSocket(d.Inode.StableAttr) {
+ child := d.Inode.InodeOperations.(*inodeOperations)
+ key = child.fileState.key
+ removeSocket = true
+
+ // Stabilize the endpoint map while deletion is in progress.
+ unlock := i.session().endpoints.lock()
+ defer unlock()
+ }
+ }
+ }
+
+ if err := i.fileState.file.unlinkAt(ctx, name, 0); err != nil {
+ return err
+ }
+ if removeSocket {
+ i.session().endpoints.remove(key)
+ }
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+
+ return nil
+}
+
+// Remove implements InodeOperations.RemoveDirectory.
+func (i *inodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error {
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ // 0x200 = AT_REMOVEDIR.
+ if err := i.fileState.file.unlinkAt(ctx, name, 0x200); err != nil {
+ return err
+ }
+ if i.session().cachePolicy.cacheUAttrs(dir) {
+ // Decrease link count and updates atime.
+ i.cachingInodeOps.DecLinks(ctx)
+ }
+ if i.session().cachePolicy.cacheReaddir() {
+ // Invalidate readdir cache.
+ i.markDirectoryDirty()
+ }
+ return nil
+}
+
+// Rename renames this node.
+func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ if len(newName) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+
+ // Unwrap the new parent to a *inodeOperations.
+ newParentInodeOperations, ok := newParent.InodeOperations.(*inodeOperations)
+ if !ok {
+ return syscall.EXDEV
+ }
+
+ // Unwrap the old parent to a *inodeOperations.
+ oldParentInodeOperations, ok := oldParent.InodeOperations.(*inodeOperations)
+ if !ok {
+ return syscall.EXDEV
+ }
+
+ // Do the rename.
+ if err := i.fileState.file.rename(ctx, newParentInodeOperations.fileState.file, newName); err != nil {
+ return err
+ }
+
+ // Is the renamed entity a directory? Fix link counts.
+ if fs.IsDir(i.fileState.sattr) {
+ // Update cached state.
+ if i.session().cachePolicy.cacheUAttrs(oldParent) {
+ oldParentInodeOperations.cachingInodeOps.DecLinks(ctx)
+ }
+ if i.session().cachePolicy.cacheUAttrs(newParent) {
+ // Only IncLinks if there is a new addition to
+ // newParent. If this is replacement, then the total
+ // count remains the same.
+ if !replacement {
+ newParentInodeOperations.cachingInodeOps.IncLinks(ctx)
+ }
+ }
+ }
+ if i.session().cachePolicy.cacheReaddir() {
+ // Mark old directory dirty.
+ oldParentInodeOperations.markDirectoryDirty()
+ if oldParent != newParent {
+ // Mark new directory dirty.
+ newParentInodeOperations.markDirectoryDirty()
+ }
+ }
+
+ // Rename always updates ctime.
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ i.cachingInodeOps.TouchStatusChangeTime(ctx)
+ }
+ return nil
+}
+
+func (i *inodeOperations) touchModificationAndStatusChangeTime(ctx context.Context, inode *fs.Inode) {
+ if i.session().cachePolicy.cacheUAttrs(inode) {
+ i.cachingInodeOps.TouchModificationAndStatusChangeTime(ctx)
+ }
+ if i.session().cachePolicy.cacheReaddir() {
+ // Invalidate readdir cache.
+ i.markDirectoryDirty()
+ }
+}
+
+// markDirectoryDirty marks any cached data dirty for this directory. This is necessary in order
+// to ensure that this node does not retain stale state throughout its lifetime across multiple
+// open directory handles.
+//
+// Currently this means invalidating any readdir caches.
+func (i *inodeOperations) markDirectoryDirty() {
+ i.readdirMu.Lock()
+ defer i.readdirMu.Unlock()
+ i.readdirCache = nil
+}
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
new file mode 100644
index 000000000..085a358fe
--- /dev/null
+++ b/pkg/sentry/fs/gofer/session.go
@@ -0,0 +1,361 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+// DefaultDirentCacheSize is the default dirent cache size for 9P mounts. It can
+// be adjusted independentely from the other dirent caches.
+var DefaultDirentCacheSize uint64 = fs.DefaultDirentCacheSize
+
+// +stateify savable
+type endpointMaps struct {
+ // mu protexts the direntMap, the keyMap, and the pathMap below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // direntMap links sockets to their dirents.
+ // It is filled concurrently with the keyMap and is stored upon save.
+ // Before saving, this map is used to populate the pathMap.
+ direntMap map[transport.BoundEndpoint]*fs.Dirent
+
+ // keyMap links MultiDeviceKeys (containing inode IDs) to their sockets.
+ // It is not stored during save because the inode ID may change upon restore.
+ keyMap map[device.MultiDeviceKey]transport.BoundEndpoint `state:"nosave"`
+
+ // pathMap links the sockets to their paths.
+ // It is filled before saving from the direntMap and is stored upon save.
+ // Upon restore, this map is used to re-populate the keyMap.
+ pathMap map[transport.BoundEndpoint]string
+}
+
+// add adds the endpoint to the maps.
+// A reference is taken on the dirent argument.
+//
+// Precondition: maps must have been locked with 'lock'.
+func (e *endpointMaps) add(key device.MultiDeviceKey, d *fs.Dirent, ep transport.BoundEndpoint) {
+ e.keyMap[key] = ep
+ d.IncRef()
+ e.direntMap[ep] = d
+}
+
+// remove deletes the key from the maps.
+//
+// Precondition: maps must have been locked with 'lock'.
+func (e *endpointMaps) remove(key device.MultiDeviceKey) {
+ endpoint := e.get(key)
+ delete(e.keyMap, key)
+
+ d := e.direntMap[endpoint]
+ d.DecRef()
+ delete(e.direntMap, endpoint)
+}
+
+// lock blocks other addition and removal operations from happening while
+// the backing file is being created or deleted. Returns a function that unlocks
+// the endpoint map.
+func (e *endpointMaps) lock() func() {
+ e.mu.Lock()
+ return func() { e.mu.Unlock() }
+}
+
+// get returns the endpoint mapped to the given key.
+//
+// Precondition: maps must have been locked for reading.
+func (e *endpointMaps) get(key device.MultiDeviceKey) transport.BoundEndpoint {
+ return e.keyMap[key]
+}
+
+// session holds state for each 9p session established during sys_mount.
+//
+// +stateify savable
+type session struct {
+ refs.AtomicRefCount
+
+ // msize is the value of the msize mount option, see fs/gofer/fs.go.
+ msize uint32 `state:"wait"`
+
+ // version is the value of the version mount option, see fs/gofer/fs.go.
+ version string `state:"wait"`
+
+ // cachePolicy is the cache policy.
+ cachePolicy cachePolicy `state:"wait"`
+
+ // aname is the value of the aname mount option, see fs/gofer/fs.go.
+ aname string `state:"wait"`
+
+ // The client associated with this session. This will be initialized lazily.
+ client *p9.Client `state:"nosave"`
+
+ // The p9.File pointing to attachName via the client. This will be initialized
+ // lazily.
+ attach contextFile `state:"nosave"`
+
+ // Flags provided to the mount.
+ superBlockFlags fs.MountSourceFlags `state:"wait"`
+
+ // connID is a unique identifier for the session connection.
+ connID string `state:"wait"`
+
+ // inodeMappings contains mappings of fs.Inodes associated with this session
+ // to paths relative to the attach point, where inodeMappings is keyed by
+ // Inode.StableAttr.InodeID.
+ inodeMappings map[uint64]string `state:"wait"`
+
+ // mounter is the EUID/EGID that mounted this file system.
+ mounter fs.FileOwner `state:"wait"`
+
+ // endpoints is used to map inodes that represent socket files to their
+ // corresponding endpoint. Socket files are created as regular files in the
+ // gofer and their presence in this map indicate that they should indeed be
+ // socket files. This allows unix domain sockets to be used with paths that
+ // belong to a gofer.
+ //
+ // TODO(b/77154739): there are few possible races with someone stat'ing the
+ // file and another deleting it concurrently, where the file will not be
+ // reported as socket file.
+ endpoints *endpointMaps `state:"wait"`
+}
+
+// Destroy tears down the session.
+func (s *session) Destroy() {
+ s.client.Close()
+}
+
+// Revalidate implements MountSource.Revalidate.
+func (s *session) Revalidate(ctx context.Context, name string, parent, child *fs.Inode) bool {
+ return s.cachePolicy.revalidate(ctx, name, parent, child)
+}
+
+// Keep implements MountSource.Keep.
+func (s *session) Keep(d *fs.Dirent) bool {
+ return s.cachePolicy.keep(d)
+}
+
+// ResetInodeMappings implements fs.MountSourceOperations.ResetInodeMappings.
+func (s *session) ResetInodeMappings() {
+ s.inodeMappings = make(map[uint64]string)
+}
+
+// SaveInodeMapping implements fs.MountSourceOperations.SaveInodeMapping.
+func (s *session) SaveInodeMapping(inode *fs.Inode, path string) {
+ // This is very unintuitive. We *CANNOT* trust the inode's StableAttrs,
+ // because overlay copyUp may have changed them out from under us.
+ // So much for "immutable".
+ sattr := inode.InodeOperations.(*inodeOperations).fileState.sattr
+ s.inodeMappings[sattr.InodeID] = path
+}
+
+// newInodeOperations creates a new 9p fs.InodeOperations backed by a p9.File and attributes
+// (p9.QID, p9.AttrMask, p9.Attr).
+//
+// Endpoints lock must not be held if socket == false.
+func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p9.QID, valid p9.AttrMask, attr p9.Attr, socket bool) (fs.StableAttr, *inodeOperations) {
+ deviceKey := device.MultiDeviceKey{
+ Device: attr.RDev,
+ SecondaryDevice: s.connID,
+ Inode: qid.Path,
+ }
+
+ sattr := fs.StableAttr{
+ Type: ntype(attr),
+ DeviceID: goferDevice.DeviceID(),
+ InodeID: goferDevice.Map(deviceKey),
+ BlockSize: bsize(attr),
+ }
+
+ if s.endpoints != nil {
+ if socket {
+ sattr.Type = fs.Socket
+ } else {
+ // If unix sockets are allowed on this filesystem, check if this file is
+ // supposed to be a socket file.
+ unlock := s.endpoints.lock()
+ if s.endpoints.get(deviceKey) != nil {
+ sattr.Type = fs.Socket
+ }
+ unlock()
+ }
+ }
+
+ fileState := &inodeFileState{
+ s: s,
+ file: file,
+ sattr: sattr,
+ key: deviceKey,
+ }
+ if s.cachePolicy == cacheRemoteRevalidating && fs.IsFile(sattr) {
+ fileState.hostMappable = fsutil.NewHostMappable(fileState)
+ }
+
+ uattr := unstable(ctx, valid, attr, s.mounter, s.client)
+ return sattr, &inodeOperations{
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, s.superBlockFlags.ForcePageCache),
+ }
+}
+
+// Root returns the root of a 9p mount. This mount is bound to a 9p server
+// based on conn. Otherwise configuration parameters are:
+//
+// * dev: connection id
+// * filesystem: the filesystem backing the mount
+// * superBlockFlags: the mount flags describing general mount options
+// * opts: parsed 9p mount options
+func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockFlags fs.MountSourceFlags, o opts) (*fs.Inode, error) {
+ // The mounting EUID/EGID will be cached by this file system. This will
+ // be used to assign ownership to files that the Gofer owns.
+ mounter := fs.FileOwnerFromContext(ctx)
+
+ conn, err := unet.NewSocket(o.fd)
+ if err != nil {
+ return nil, err
+ }
+
+ // Construct the session.
+ s := &session{
+ connID: dev,
+ msize: o.msize,
+ version: o.version,
+ cachePolicy: o.policy,
+ aname: o.aname,
+ superBlockFlags: superBlockFlags,
+ mounter: mounter,
+ }
+
+ if o.privateunixsocket {
+ s.endpoints = newEndpointMaps()
+ }
+
+ // Construct the MountSource with the session and superBlockFlags.
+ m := fs.NewMountSource(s, filesystem, superBlockFlags)
+
+ // Given that gofer files can consume host FDs, restrict the number
+ // of files that can be held by the cache.
+ m.SetDirentCacheMaxSize(DefaultDirentCacheSize)
+ m.SetDirentCacheLimiter(fs.DirentCacheLimiterFromContext(ctx))
+
+ // Send the Tversion request.
+ s.client, err = p9.NewClient(conn, s.msize, s.version)
+ if err != nil {
+ // Drop our reference on the session, it needs to be torn down.
+ s.DecRef()
+ return nil, err
+ }
+
+ // Notify that we're about to call the Gofer and block.
+ ctx.UninterruptibleSleepStart(false)
+ // Send the Tattach request.
+ s.attach.file, err = s.client.Attach(s.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ // Same as above.
+ s.DecRef()
+ return nil, err
+ }
+
+ qid, valid, attr, err := s.attach.getAttr(ctx, p9.AttrMaskAll())
+ if err != nil {
+ s.attach.close(ctx)
+ // Same as above, but after we execute the Close request.
+ s.DecRef()
+ return nil, err
+ }
+
+ sattr, iops := newInodeOperations(ctx, s, s.attach, qid, valid, attr, false)
+ return fs.NewInode(iops, m, sattr), nil
+}
+
+// newEndpointMaps creates a new endpointMaps.
+func newEndpointMaps() *endpointMaps {
+ return &endpointMaps{
+ direntMap: make(map[transport.BoundEndpoint]*fs.Dirent),
+ keyMap: make(map[device.MultiDeviceKey]transport.BoundEndpoint),
+ pathMap: make(map[transport.BoundEndpoint]string),
+ }
+}
+
+// fillKeyMap populates key and dirent maps upon restore from saved
+// pathmap.
+func (s *session) fillKeyMap(ctx context.Context) error {
+ unlock := s.endpoints.lock()
+ defer unlock()
+
+ for ep, dirPath := range s.endpoints.pathMap {
+ _, file, err := s.attach.walk(ctx, splitAbsolutePath(dirPath))
+ if err != nil {
+ return fmt.Errorf("error filling endpointmaps, failed to walk to %q: %v", dirPath, err)
+ }
+
+ qid, _, attr, err := file.getAttr(ctx, p9.AttrMaskAll())
+ if err != nil {
+ return fmt.Errorf("failed to get file attributes of %s: %v", dirPath, err)
+ }
+
+ key := device.MultiDeviceKey{
+ Device: attr.RDev,
+ SecondaryDevice: s.connID,
+ Inode: qid.Path,
+ }
+
+ s.endpoints.keyMap[key] = ep
+ }
+ return nil
+}
+
+// fillPathMap populates paths for endpoints from dirents in direntMap
+// before save.
+func (s *session) fillPathMap() error {
+ unlock := s.endpoints.lock()
+ defer unlock()
+
+ for ep, dir := range s.endpoints.direntMap {
+ mountRoot := dir.MountRoot()
+ defer mountRoot.DecRef()
+ dirPath, _ := dir.FullName(mountRoot)
+ if dirPath == "" {
+ return fmt.Errorf("error getting path from dirent")
+ }
+ s.endpoints.pathMap[ep] = dirPath
+ }
+ return nil
+}
+
+// restoreEndpointMaps recreates and fills the key and dirent maps.
+func (s *session) restoreEndpointMaps(ctx context.Context) error {
+ // When restoring, only need to create the keyMap because the dirent and path
+ // maps got stored through the save.
+ s.endpoints.keyMap = make(map[device.MultiDeviceKey]transport.BoundEndpoint)
+ if err := s.fillKeyMap(ctx); err != nil {
+ return fmt.Errorf("failed to insert sockets into endpoint map: %v", err)
+ }
+
+ // Re-create pathMap because it can no longer be trusted as socket paths can
+ // change while process continues to run. Empty pathMap will be re-filled upon
+ // next save.
+ s.endpoints.pathMap = make(map[transport.BoundEndpoint]string)
+ return nil
+}
diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go
new file mode 100644
index 000000000..68fbf3417
--- /dev/null
+++ b/pkg/sentry/fs/gofer/session_state.go
@@ -0,0 +1,115 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+// beforeSave is invoked by stateify.
+func (s *session) beforeSave() {
+ if s.endpoints != nil {
+ if err := s.fillPathMap(); err != nil {
+ panic("failed to save paths to endpoint map before saving" + err.Error())
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (s *session) afterLoad() {
+ // The restore environment contains the 9p connection of this mount.
+ fsys := filesystem{}
+ env, ok := fs.CurrentRestoreEnvironment()
+ if !ok {
+ panic("failed to find restore environment")
+ }
+ mounts, ok := env.MountSources[fsys.Name()]
+ if !ok {
+ panic("failed to find mounts for filesystem type " + fsys.Name())
+ }
+ var args fs.MountArgs
+ var found bool
+ for _, mount := range mounts {
+ if mount.Dev == s.connID {
+ args = mount
+ found = true
+ }
+ }
+ if !found {
+ panic(fmt.Sprintf("no connection for connection id %q", s.connID))
+ }
+
+ // Validate the mount flags and options.
+ opts, err := options(args.DataString)
+ if err != nil {
+ panic("failed to parse mount options: " + err.Error())
+ }
+ if opts.msize != s.msize {
+ panic(fmt.Sprintf("new message size %v, want %v", opts.msize, s.msize))
+ }
+ if opts.version != s.version {
+ panic(fmt.Sprintf("new version %v, want %v", opts.version, s.version))
+ }
+ if opts.policy != s.cachePolicy {
+ panic(fmt.Sprintf("new cache policy %v, want %v", opts.policy, s.cachePolicy))
+ }
+ if opts.aname != s.aname {
+ panic(fmt.Sprintf("new attach name %v, want %v", opts.aname, s.aname))
+ }
+
+ // Check if endpointMaps exist when uds sockets are enabled
+ // (only pathmap will actualy have been saved).
+ if opts.privateunixsocket != (s.endpoints != nil) {
+ panic(fmt.Sprintf("new privateunixsocket option %v, want %v", opts.privateunixsocket, s.endpoints != nil))
+ }
+ if args.Flags != s.superBlockFlags {
+ panic(fmt.Sprintf("new mount flags %v, want %v", args.Flags, s.superBlockFlags))
+ }
+
+ // Manually restore the connection.
+ conn, err := unet.NewSocket(opts.fd)
+ if err != nil {
+ panic(fmt.Sprintf("failed to create Socket for FD %d: %v", opts.fd, err))
+ }
+
+ // Manually restore the client.
+ s.client, err = p9.NewClient(conn, s.msize, s.version)
+ if err != nil {
+ panic(fmt.Sprintf("failed to connect client to server: %v", err))
+ }
+
+ // Manually restore the attach point.
+ s.attach.file, err = s.client.Attach(s.aname)
+ if err != nil {
+ panic(fmt.Sprintf("failed to attach to aname: %v", err))
+ }
+
+ // If private unix sockets are enabled, create and fill the session's endpoint
+ // maps.
+ if opts.privateunixsocket {
+ // TODO(b/38173783): Context is not plumbed to save/restore.
+ ctx := &dummyClockContext{context.Background()}
+
+ if err = s.restoreEndpointMaps(ctx); err != nil {
+ panic("failed to restore endpoint maps: " + err.Error())
+ }
+ }
+
+}
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go
new file mode 100644
index 000000000..cbd5b9a84
--- /dev/null
+++ b/pkg/sentry/fs/gofer/socket.go
@@ -0,0 +1,141 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/host"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// BoundEndpoint returns a gofer-backed transport.BoundEndpoint.
+func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.BoundEndpoint {
+ if !fs.IsSocket(i.fileState.sattr) {
+ return nil
+ }
+
+ if i.session().endpoints != nil {
+ unlock := i.session().endpoints.lock()
+ defer unlock()
+ ep := i.session().endpoints.get(i.fileState.key)
+ if ep != nil {
+ return ep
+ }
+
+ // Not found in endpoints map, it may be a gofer backed unix socket...
+ }
+
+ inode.IncRef()
+ return &endpoint{inode, i.fileState.file.file, path}
+}
+
+// endpoint is a Gofer-backed transport.BoundEndpoint.
+//
+// An endpoint's lifetime is the time between when InodeOperations.BoundEndpoint()
+// is called and either BoundEndpoint.BidirectionalConnect or
+// BoundEndpoint.UnidirectionalConnect is called.
+type endpoint struct {
+ // inode is the filesystem inode which produced this endpoint.
+ inode *fs.Inode
+
+ // file is the p9 file that contains a single unopened fid.
+ file p9.File
+
+ // path is the sentry path where this endpoint is bound.
+ path string
+}
+
+func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) {
+ switch t {
+ case transport.SockStream:
+ return p9.StreamSocket, true
+ case transport.SockSeqpacket:
+ return p9.SeqpacketSocket, true
+ case transport.SockDgram:
+ return p9.DgramSocket, true
+ }
+ return 0, false
+}
+
+// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect.
+func (e *endpoint) BidirectionalConnect(ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
+ cf, ok := unixSockToP9(ce.Type())
+ if !ok {
+ return syserr.ErrConnectionRefused
+ }
+
+ // No lock ordering required as only the ConnectingEndpoint has a mutex.
+ ce.Lock()
+
+ // Check connecting state.
+ if ce.Connected() {
+ ce.Unlock()
+ return syserr.ErrAlreadyConnected
+ }
+ if ce.Listening() {
+ ce.Unlock()
+ return syserr.ErrInvalidEndpointState
+ }
+
+ hostFile, err := e.file.Connect(cf)
+ if err != nil {
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+
+ c, serr := host.NewConnectedEndpoint(hostFile, ce.WaiterQueue(), e.path)
+ if serr != nil {
+ ce.Unlock()
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, cf, serr)
+ return serr
+ }
+
+ returnConnect(c, c)
+ ce.Unlock()
+ c.Init()
+
+ return nil
+}
+
+// UnidirectionalConnect implements
+// transport.BoundEndpoint.UnidirectionalConnect.
+func (e *endpoint) UnidirectionalConnect() (transport.ConnectedEndpoint, *syserr.Error) {
+ hostFile, err := e.file.Connect(p9.DgramSocket)
+ if err != nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ c, serr := host.NewConnectedEndpoint(hostFile, &waiter.Queue{}, e.path)
+ if serr != nil {
+ log.Warningf("Gofer returned invalid host socket for UnidirectionalConnect; file %+v: %v", e.file, serr)
+ return nil, serr
+ }
+ c.Init()
+
+ // We don't need the receiver.
+ c.CloseRecv()
+ c.Release()
+
+ return c, nil
+}
+
+// Release implements transport.BoundEndpoint.Release.
+func (e *endpoint) Release() {
+ e.inode.DecRef()
+}
diff --git a/pkg/sentry/fs/gofer/util.go b/pkg/sentry/fs/gofer/util.go
new file mode 100644
index 000000000..d0e1096ce
--- /dev/null
+++ b/pkg/sentry/fs/gofer/util.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/p9"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+func utimes(ctx context.Context, file contextFile, ts fs.TimeSpec) error {
+ if ts.ATimeOmit && ts.MTimeOmit {
+ return nil
+ }
+ mask := p9.SetAttrMask{
+ ATime: !ts.ATimeOmit,
+ ATimeNotSystemTime: !ts.ATimeSetSystemTime,
+ MTime: !ts.MTimeOmit,
+ MTimeNotSystemTime: !ts.MTimeSetSystemTime,
+ }
+ as, ans := ts.ATime.Unix()
+ ms, mns := ts.MTime.Unix()
+ attr := p9.SetAttr{
+ ATimeSeconds: uint64(as),
+ ATimeNanoSeconds: uint64(ans),
+ MTimeSeconds: uint64(ms),
+ MTimeNanoSeconds: uint64(mns),
+ }
+ // 9p2000.L SetAttr: "If a time bit is set without the corresponding SET bit,
+ // the current system time on the server is used instead of the value sent
+ // in the request."
+ return file.setAttr(ctx, mask, attr)
+}
+
+func openFlagsFromPerms(p fs.PermMask) (p9.OpenFlags, error) {
+ switch {
+ case p.Read && p.Write:
+ return p9.ReadWrite, nil
+ case p.Write:
+ return p9.WriteOnly, nil
+ case p.Read:
+ return p9.ReadOnly, nil
+ default:
+ return 0, syscall.EINVAL
+ }
+}
diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go
new file mode 100644
index 000000000..9ebb9bbb3
--- /dev/null
+++ b/pkg/sentry/fs/host/control.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+type scmRights struct {
+ fds []int
+}
+
+func newSCMRights(fds []int) control.SCMRights {
+ return &scmRights{fds}
+}
+
+// Files implements control.SCMRights.Files.
+func (c *scmRights) Files(ctx context.Context, max int) (control.RightsFiles, bool) {
+ n := max
+ var trunc bool
+ if l := len(c.fds); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+
+ rf := control.RightsFiles(fdsToFiles(ctx, c.fds[:n]))
+
+ // Only consume converted FDs (fdsToFiles may convert fewer than n FDs).
+ c.fds = c.fds[len(rf):]
+ return rf, trunc
+}
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (c *scmRights) Clone() transport.RightsControlMessage {
+ // Host rights never need to be cloned.
+ return nil
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (c *scmRights) Release() {
+ for _, fd := range c.fds {
+ syscall.Close(fd)
+ }
+ c.fds = nil
+}
+
+// If an error is encountered, only files created before the error will be
+// returned. This is what Linux does.
+func fdsToFiles(ctx context.Context, fds []int) []*fs.File {
+ files := make([]*fs.File, 0, len(fds))
+ for _, fd := range fds {
+ // Get flags. We do it here because they may be modified
+ // by subsequent functions.
+ fileFlags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_GETFL, 0)
+ if errno != 0 {
+ ctx.Warningf("Error retrieving host FD flags: %v", error(errno))
+ break
+ }
+
+ // Create the file backed by hostFD.
+ file, err := NewFile(ctx, fd, fs.FileOwnerFromContext(ctx))
+ if err != nil {
+ ctx.Warningf("Error creating file from host FD: %v", err)
+ break
+ }
+
+ // Set known flags.
+ file.SetFlags(fs.SettableFileFlags{
+ NonBlocking: fileFlags&syscall.O_NONBLOCK != 0,
+ })
+
+ files = append(files, file)
+ }
+ return files
+}
diff --git a/pkg/sentry/fs/host/descriptor.go b/pkg/sentry/fs/host/descriptor.go
new file mode 100644
index 000000000..ffcd57a94
--- /dev/null
+++ b/pkg/sentry/fs/host/descriptor.go
@@ -0,0 +1,120 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "path"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fdnotifier"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// descriptor wraps a host fd.
+//
+// +stateify savable
+type descriptor struct {
+ // donated is true if the host fd was donated by another process.
+ donated bool
+
+ // If origFD >= 0, it is the host fd that this file was originally created
+ // from, which must be available at time of restore. The FD can be closed
+ // after descriptor is created. Only set if donated is true.
+ origFD int
+
+ // wouldBlock is true if value (below) points to a file that can
+ // return EWOULDBLOCK for operations that would block.
+ wouldBlock bool
+
+ // value is the wrapped host fd. It is never saved or restored
+ // directly. How it is restored depends on whether it was
+ // donated and the fs.MountSource it was originally
+ // opened/created from.
+ value int `state:"nosave"`
+}
+
+// newDescriptor returns a wrapped host file descriptor. On success,
+// the descriptor is registered for event notifications with queue.
+func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) {
+ ownedFD := fd
+ origFD := -1
+ if saveable {
+ var err error
+ ownedFD, err = syscall.Dup(fd)
+ if err != nil {
+ return nil, err
+ }
+ origFD = fd
+ }
+ if wouldBlock {
+ if err := syscall.SetNonblock(ownedFD, true); err != nil {
+ return nil, err
+ }
+ if err := fdnotifier.AddFD(int32(ownedFD), queue); err != nil {
+ return nil, err
+ }
+ }
+ return &descriptor{
+ donated: donated,
+ origFD: origFD,
+ wouldBlock: wouldBlock,
+ value: ownedFD,
+ }, nil
+}
+
+// initAfterLoad initializes the value of the descriptor after Load.
+func (d *descriptor) initAfterLoad(mo *superOperations, id uint64, queue *waiter.Queue) error {
+ if d.donated {
+ var err error
+ d.value, err = syscall.Dup(d.origFD)
+ if err != nil {
+ return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err)
+ }
+ } else {
+ name, ok := mo.inodeMappings[id]
+ if !ok {
+ return fmt.Errorf("failed to find path for inode number %d", id)
+ }
+ fullpath := path.Join(mo.root, name)
+
+ var err error
+ d.value, err = open(nil, fullpath)
+ if err != nil {
+ return fmt.Errorf("failed to open %q: %v", fullpath, err)
+ }
+ }
+ if d.wouldBlock {
+ if err := syscall.SetNonblock(d.value, true); err != nil {
+ return err
+ }
+ if err := fdnotifier.AddFD(int32(d.value), queue); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Release releases all resources held by descriptor.
+func (d *descriptor) Release() {
+ if d.wouldBlock {
+ fdnotifier.RemoveFD(int32(d.value))
+ }
+ if err := syscall.Close(d.value); err != nil {
+ log.Warningf("error closing fd %d: %v", d.value, err)
+ }
+ d.value = -1
+}
diff --git a/pkg/sentry/fs/host/descriptor_state.go b/pkg/sentry/fs/host/descriptor_state.go
new file mode 100644
index 000000000..8167390a9
--- /dev/null
+++ b/pkg/sentry/fs/host/descriptor_state.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+// beforeSave is invoked by stateify.
+func (d *descriptor) beforeSave() {
+ if d.donated && d.origFD < 0 {
+ panic("donated file descriptor cannot be saved")
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *descriptor) afterLoad() {
+ // value must be manually restored by the descriptor's parent using
+ // initAfterLoad.
+ d.value = -1
+}
diff --git a/pkg/sentry/fs/host/device.go b/pkg/sentry/fs/host/device.go
new file mode 100644
index 000000000..055024c44
--- /dev/null
+++ b/pkg/sentry/fs/host/device.go
@@ -0,0 +1,25 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+)
+
+// hostFileDevice is the host file virtual device.
+var hostFileDevice = device.NewAnonMultiDevice()
+
+// hostPipeDevice is the host pipe virtual device.
+var hostPipeDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go
new file mode 100644
index 000000000..ad0a3ec85
--- /dev/null
+++ b/pkg/sentry/fs/host/file.go
@@ -0,0 +1,286 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/fdnotifier"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/secio"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// fileOperations implements fs.FileOperations for a host file descriptor.
+//
+// +stateify savable
+type fileOperations struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosplice"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // iops are the Inode operations for this file.
+ iops *inodeOperations `state:"wait"`
+
+ // a scratch buffer for reading directory entries.
+ dirinfo *dirInfo `state:"nosave"`
+
+ // dirCursor is the directory cursor.
+ dirCursor string
+}
+
+// fileOperations implements fs.FileOperations.
+var _ fs.FileOperations = (*fileOperations)(nil)
+
+// NewFile creates a new File backed by the provided host file descriptor. If
+// NewFile succeeds, ownership of the FD is transferred to the returned File.
+//
+// The returned File cannot be saved, since there is no guarantee that the same
+// FD will exist or represent the same file at time of restore. If such a
+// guarantee does exist, use ImportFile instead.
+func NewFile(ctx context.Context, fd int, mounter fs.FileOwner) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, mounter, false, false)
+}
+
+// ImportFile creates a new File backed by the provided host file descriptor.
+// Unlike NewFile, the file descriptor used by the File is duped from FD to
+// ensure that later changes to FD are not reflected by the fs.File.
+//
+// If the returned file is saved, it will be restored by re-importing the FD
+// originally passed to ImportFile. It is the restorer's responsibility to
+// ensure that the FD represents the same file.
+func ImportFile(ctx context.Context, fd int, mounter fs.FileOwner, isTTY bool) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, mounter, true, isTTY)
+}
+
+// newFileFromDonatedFD returns an fs.File from a donated FD. If the FD is
+// saveable, then saveable is true.
+func newFileFromDonatedFD(ctx context.Context, donated int, mounter fs.FileOwner, saveable, isTTY bool) (*fs.File, error) {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(donated, &s); err != nil {
+ return nil, err
+ }
+ flags, err := fileFlagsFromDonatedFD(donated)
+ if err != nil {
+ return nil, err
+ }
+ switch s.Mode & syscall.S_IFMT {
+ case syscall.S_IFSOCK:
+ if isTTY {
+ return nil, fmt.Errorf("cannot import host socket as TTY")
+ }
+
+ s, err := newSocket(ctx, donated, saveable)
+ if err != nil {
+ return nil, err
+ }
+ s.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags.NonBlocking,
+ })
+ return s, nil
+ default:
+ msrc := newMountSource(ctx, "/", mounter, &Filesystem{}, fs.MountSourceFlags{}, false /* dontTranslateOwnership */)
+ inode, err := newInode(ctx, msrc, donated, saveable, true /* donated */)
+ if err != nil {
+ return nil, err
+ }
+ iops := inode.InodeOperations.(*inodeOperations)
+
+ name := fmt.Sprintf("host:[%d]", inode.StableAttr.InodeID)
+ dirent := fs.NewDirent(inode, name)
+ defer dirent.DecRef()
+
+ if isTTY {
+ return newTTYFile(ctx, dirent, flags, iops), nil
+ }
+
+ return newFile(ctx, dirent, flags, iops), nil
+ }
+}
+
+func fileFlagsFromDonatedFD(donated int) (fs.FileFlags, error) {
+ flags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(donated), syscall.F_GETFL, 0)
+ if errno != 0 {
+ log.Warningf("Failed to get file flags for donated FD %d (errno=%d)", donated, errno)
+ return fs.FileFlags{}, syscall.EIO
+ }
+ accmode := flags & syscall.O_ACCMODE
+ return fs.FileFlags{
+ Direct: flags&syscall.O_DIRECT != 0,
+ NonBlocking: flags&syscall.O_NONBLOCK != 0,
+ Sync: flags&syscall.O_SYNC != 0,
+ Append: flags&syscall.O_APPEND != 0,
+ Read: accmode == syscall.O_RDONLY || accmode == syscall.O_RDWR,
+ Write: accmode == syscall.O_WRONLY || accmode == syscall.O_RDWR,
+ }, nil
+}
+
+// newFile returns a new fs.File.
+func newFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
+ if !iops.ReturnsWouldBlock() {
+ // Allow reading/writing at an arbitrary offset for files
+ // that support it.
+ flags.Pread = true
+ flags.Pwrite = true
+ }
+ return fs.NewFile(ctx, dirent, flags, &fileOperations{iops: iops})
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (f *fileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ f.iops.fileState.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(f.iops.fileState.FD()))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (f *fileOperations) EventUnregister(e *waiter.Entry) {
+ f.iops.fileState.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(f.iops.fileState.FD()))
+}
+
+// Readiness uses the poll() syscall to check the status of the underlying FD.
+func (f *fileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(f.iops.fileState.FD()), mask)
+}
+
+// Readdir implements fs.FileOperations.Readdir.
+func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &f.dirCursor,
+ }
+ return fs.DirentReaddir(ctx, file.Dirent, f, root, dirCtx, file.Offset())
+}
+
+// IterateDir implements fs.DirIterator.IterateDir.
+func (f *fileOperations) IterateDir(ctx context.Context, dirCtx *fs.DirCtx, offset int) (int, error) {
+ if f.dirinfo == nil {
+ f.dirinfo = new(dirInfo)
+ f.dirinfo.buf = make([]byte, usermem.PageSize)
+ }
+ entries, err := f.iops.readdirAll(f.dirinfo)
+ if err != nil {
+ return offset, err
+ }
+ count, err := fs.GenericReaddir(dirCtx, fs.NewSortedDentryMap(entries))
+ return offset + count, err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ // Would this file block?
+ if f.iops.ReturnsWouldBlock() {
+ // These files can't be memory mapped, assert this. This also
+ // means that writes do not need to synchronize with memory
+ // mappings nor metadata cached by this file's fs.Inode.
+ if canMap(file.Dirent.Inode) {
+ panic("files that can return EWOULDBLOCK cannot be memory mapped")
+ }
+ // Ignore the offset, these files don't support writing at
+ // an arbitrary offset.
+ writer := fd.NewReadWriter(f.iops.fileState.FD())
+ n, err := src.CopyInTo(ctx, safemem.FromIOWriter{writer})
+ if isBlockError(err) {
+ err = syserror.ErrWouldBlock
+ }
+ return n, err
+ }
+ if !file.Dirent.Inode.MountSource.Flags.ForcePageCache {
+ writer := secio.NewOffsetWriter(fd.NewReadWriter(f.iops.fileState.FD()), offset)
+ return src.CopyInTo(ctx, safemem.FromIOWriter{writer})
+ }
+ return f.iops.cachingInodeOps.Write(ctx, src, offset)
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ // Would this file block?
+ if f.iops.ReturnsWouldBlock() {
+ // These files can't be memory mapped, assert this. This also
+ // means that reads do not need to synchronize with memory
+ // mappings nor metadata cached by this file's fs.Inode.
+ if canMap(file.Dirent.Inode) {
+ panic("files that can return EWOULDBLOCK cannot be memory mapped")
+ }
+ // Ignore the offset, these files don't support reading at
+ // an arbitrary offset.
+ reader := fd.NewReadWriter(f.iops.fileState.FD())
+ n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{reader})
+ if isBlockError(err) {
+ // If we got any data at all, return it as a "completed" partial read
+ // rather than retrying until complete.
+ if n != 0 {
+ err = nil
+ } else {
+ err = syserror.ErrWouldBlock
+ }
+ }
+ return n, err
+ }
+ if !file.Dirent.Inode.MountSource.Flags.ForcePageCache {
+ reader := secio.NewOffsetReader(fd.NewReadWriter(f.iops.fileState.FD()), offset)
+ return dst.CopyOutFrom(ctx, safemem.FromIOReader{reader})
+ }
+ return f.iops.cachingInodeOps.Read(ctx, file, dst, offset)
+}
+
+// Fsync implements fs.FileOperations.Fsync.
+func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start int64, end int64, syncType fs.SyncType) error {
+ switch syncType {
+ case fs.SyncAll, fs.SyncData:
+ if err := file.Dirent.Inode.WriteOut(ctx); err != nil {
+ return err
+ }
+ fallthrough
+ case fs.SyncBackingStorage:
+ return syscall.Fsync(f.iops.fileState.FD())
+ }
+ panic("invalid sync type")
+}
+
+// Flush implements fs.FileOperations.Flush.
+func (f *fileOperations) Flush(context.Context, *fs.File) error {
+ // This is a no-op because flushing the resource backing this
+ // file would mean closing it. We can't do that because other
+ // open files may depend on the backing host FD.
+ return nil
+}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (f *fileOperations) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ if !canMap(file.Dirent.Inode) {
+ return syserror.ENODEV
+ }
+ return fsutil.GenericConfigureMMap(file, f.iops.cachingInodeOps, opts)
+}
+
+// Seek implements fs.FileOperations.Seek.
+func (f *fileOperations) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ return fsutil.SeekWithDirCursor(ctx, file, whence, offset, &f.dirCursor)
+}
diff --git a/pkg/sentry/fs/host/fs.go b/pkg/sentry/fs/host/fs.go
new file mode 100644
index 000000000..b1b8dc0b6
--- /dev/null
+++ b/pkg/sentry/fs/host/fs.go
@@ -0,0 +1,339 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package host implements an fs.Filesystem for files backed by host
+// file descriptors.
+package host
+
+import (
+ "fmt"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// FilesystemName is the name under which Filesystem is registered.
+const FilesystemName = "whitelistfs"
+
+const (
+ // whitelistKey is the mount option containing a comma-separated list
+ // of host paths to whitelist.
+ whitelistKey = "whitelist"
+
+ // rootPathKey is the mount option containing the root path of the
+ // mount.
+ rootPathKey = "root"
+
+ // dontTranslateOwnershipKey is the key to superOperations.dontTranslateOwnership.
+ dontTranslateOwnershipKey = "dont_translate_ownership"
+)
+
+// maxTraversals determines link traversals in building the whitelist.
+const maxTraversals = 10
+
+// Filesystem is a pseudo file system that is only available during the setup
+// to lock down the configurations. This filesystem should only be mounted at root.
+//
+// Think twice before exposing this to applications.
+//
+// +stateify savable
+type Filesystem struct {
+ // whitelist is a set of host paths to whitelist.
+ paths []string
+}
+
+var _ fs.Filesystem = (*Filesystem)(nil)
+
+// Name is the identifier of this file system.
+func (*Filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount prohibits users from using mount(2) with this file system.
+func (*Filesystem) AllowUserMount() bool {
+ return false
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*Filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+func (*Filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns an fs.Inode exposing the host file system. It is intended to be locked
+// down in PreExec below.
+func (f *Filesystem) Mount(ctx context.Context, _ string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // Parse generic comma-separated key=value options.
+ options := fs.GenericMountSourceOptions(data)
+
+ // Grab the whitelist if one was specified.
+ // TODO(edahlgren/mpratt/hzy): require another option "testonly" in order to allow
+ // no whitelist.
+ if wl, ok := options[whitelistKey]; ok {
+ f.paths = strings.Split(wl, "|")
+ delete(options, whitelistKey)
+ }
+
+ // If the rootPath was set, use it. Othewise default to the root of the
+ // host fs.
+ rootPath := "/"
+ if rp, ok := options[rootPathKey]; ok {
+ rootPath = rp
+ delete(options, rootPathKey)
+
+ // We must relativize the whitelisted paths to the new root.
+ for i, p := range f.paths {
+ rel, err := filepath.Rel(rootPath, p)
+ if err != nil {
+ return nil, fmt.Errorf("whitelist path %q must be a child of root path %q", p, rootPath)
+ }
+ f.paths[i] = path.Join("/", rel)
+ }
+ }
+ fd, err := open(nil, rootPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to find root: %v", err)
+ }
+
+ var dontTranslateOwnership bool
+ if v, ok := options[dontTranslateOwnershipKey]; ok {
+ b, err := strconv.ParseBool(v)
+ if err != nil {
+ return nil, fmt.Errorf("invalid value for %q: %v", dontTranslateOwnershipKey, err)
+ }
+ dontTranslateOwnership = b
+ delete(options, dontTranslateOwnershipKey)
+ }
+
+ // Fail if the caller passed us more options than we know about.
+ if len(options) > 0 {
+ return nil, fmt.Errorf("unsupported mount options: %v", options)
+ }
+
+ // The mounting EUID/EGID will be cached by this file system. This will
+ // be used to assign ownership to files that we own.
+ owner := fs.FileOwnerFromContext(ctx)
+
+ // Construct the host file system mount and inode.
+ msrc := newMountSource(ctx, rootPath, owner, f, flags, dontTranslateOwnership)
+ return newInode(ctx, msrc, fd, false /* saveable */, false /* donated */)
+}
+
+// InstallWhitelist locks down the MountNamespace to only the currently installed
+// Dirents and the given paths.
+func (f *Filesystem) InstallWhitelist(ctx context.Context, m *fs.MountNamespace) error {
+ return installWhitelist(ctx, m, f.paths)
+}
+
+func installWhitelist(ctx context.Context, m *fs.MountNamespace, paths []string) error {
+ if len(paths) == 0 || (len(paths) == 1 && paths[0] == "") {
+ // Warning will be logged during filter installation if the empty
+ // whitelist matters (allows for host file access).
+ return nil
+ }
+
+ // Done tracks entries already added.
+ done := make(map[string]bool)
+ root := m.Root()
+ defer root.DecRef()
+
+ for i := 0; i < len(paths); i++ {
+ // Make sure the path is absolute. This is a sanity check.
+ if !path.IsAbs(paths[i]) {
+ return fmt.Errorf("path %q is not absolute", paths[i])
+ }
+
+ // We need to add all the intermediate paths, in case one of
+ // them is a symlink that needs to be resolved.
+ for j := 1; j <= len(paths[i]); j++ {
+ if j < len(paths[i]) && paths[i][j] != '/' {
+ continue
+ }
+ current := paths[i][:j]
+
+ // Lookup the given component in the tree.
+ remainingTraversals := uint(maxTraversals)
+ d, err := m.FindLink(ctx, root, nil, current, &remainingTraversals)
+ if err != nil {
+ log.Warningf("populate failed for %q: %v", current, err)
+ continue
+ }
+
+ // It's critical that this DecRef happens after the
+ // freeze below. This ensures that the dentry is in
+ // place to be frozen. Otherwise, we freeze without
+ // these entries.
+ defer d.DecRef()
+
+ // Expand the last component if necessary.
+ if current == paths[i] {
+ // Is it a directory or symlink?
+ sattr := d.Inode.StableAttr
+ if fs.IsDir(sattr) {
+ for name := range childDentAttrs(ctx, d) {
+ paths = append(paths, path.Join(current, name))
+ }
+ }
+ if fs.IsSymlink(sattr) {
+ // Only expand symlinks once. The
+ // folder structure may contain
+ // recursive symlinks and we don't want
+ // to end up infinitely expanding this
+ // symlink. This is safe because this
+ // is the last component. If a later
+ // path wants to symlink something
+ // beneath this symlink that will still
+ // be handled by the FindLink above.
+ if done[current] {
+ continue
+ }
+
+ s, err := d.Inode.Readlink(ctx)
+ if err != nil {
+ log.Warningf("readlink failed for %q: %v", current, err)
+ continue
+ }
+ if path.IsAbs(s) {
+ paths = append(paths, s)
+ } else {
+ target := path.Join(path.Dir(current), s)
+ paths = append(paths, target)
+ }
+ }
+ }
+
+ // Only report this one once even though we may look
+ // it up more than once. If we whitelist /a/b,/a then
+ // /a will be "done" when it is looked up for /a/b,
+ // however we still need to expand all of its contents
+ // when whitelisting /a.
+ if !done[current] {
+ log.Debugf("whitelisted: %s", current)
+ }
+ done[current] = true
+ }
+ }
+
+ // Freeze the mount tree in place. This prevents any new paths from
+ // being opened and any old ones from being removed. If we do provide
+ // tmpfs mounts, we'll want to freeze/thaw those separately.
+ m.Freeze()
+ return nil
+}
+
+func childDentAttrs(ctx context.Context, d *fs.Dirent) map[string]fs.DentAttr {
+ dirname, _ := d.FullName(nil /* root */)
+ dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
+ if err != nil {
+ log.Warningf("failed to open directory %q: %v", dirname, err)
+ return nil
+ }
+ dir.DecRef()
+ var stubSerializer fs.CollectEntriesSerializer
+ if err := dir.Readdir(ctx, &stubSerializer); err != nil {
+ log.Warningf("failed to iterate on host directory %q: %v", dirname, err)
+ return nil
+ }
+ delete(stubSerializer.Entries, ".")
+ delete(stubSerializer.Entries, "..")
+ return stubSerializer.Entries
+}
+
+// newMountSource constructs a new host fs.MountSource
+// relative to a root path. The root should match the mount point.
+func newMountSource(ctx context.Context, root string, mounter fs.FileOwner, filesystem fs.Filesystem, flags fs.MountSourceFlags, dontTranslateOwnership bool) *fs.MountSource {
+ return fs.NewMountSource(&superOperations{
+ root: root,
+ inodeMappings: make(map[uint64]string),
+ mounter: mounter,
+ dontTranslateOwnership: dontTranslateOwnership,
+ }, filesystem, flags)
+}
+
+// superOperations implements fs.MountSourceOperations.
+//
+// +stateify savable
+type superOperations struct {
+ fs.SimpleMountSourceOperations
+
+ // root is the path of the mount point. All inode mappings
+ // are relative to this root.
+ root string
+
+ // inodeMappings contains mappings of fs.Inodes associated
+ // with this MountSource to paths under root.
+ inodeMappings map[uint64]string
+
+ // mounter is the cached EUID/EGID that mounted this file system.
+ mounter fs.FileOwner
+
+ // dontTranslateOwnership indicates whether to not translate file
+ // ownership.
+ //
+ // By default, files/directories owned by the sandbox uses UID/GID
+ // of the mounter. For files/directories that are not owned by the
+ // sandbox, file UID/GID is translated to a UID/GID which cannot
+ // be mapped in the sandboxed application's user namespace. The
+ // UID/GID will look like the nobody UID/GID (65534) but is not
+ // strictly owned by the user "nobody".
+ //
+ // If whitelistfs is a lower filesystem in an overlay, set
+ // dont_translate_ownership=true in mount options.
+ dontTranslateOwnership bool
+}
+
+var _ fs.MountSourceOperations = (*superOperations)(nil)
+
+// ResetInodeMappings implements fs.MountSourceOperations.ResetInodeMappings.
+func (m *superOperations) ResetInodeMappings() {
+ m.inodeMappings = make(map[uint64]string)
+}
+
+// SaveInodeMapping implements fs.MountSourceOperations.SaveInodeMapping.
+func (m *superOperations) SaveInodeMapping(inode *fs.Inode, path string) {
+ // This is very unintuitive. We *CANNOT* trust the inode's StableAttrs,
+ // because overlay copyUp may have changed them out from under us.
+ // So much for "immutable".
+ sattr := inode.InodeOperations.(*inodeOperations).fileState.sattr
+ m.inodeMappings[sattr.InodeID] = path
+}
+
+// Keep implements fs.MountSourceOperations.Keep.
+//
+// TODO(b/72455313,b/77596690): It is possible to change the permissions on a
+// host file while it is in the dirent cache (say from RO to RW), but it is not
+// possible to re-open the file with more relaxed permissions, since the host
+// FD is already open and stored in the inode.
+//
+// Using the dirent LRU cache increases the odds that this bug is encountered.
+// Since host file access is relatively fast anyways, we disable the LRU cache
+// for host fs files. Once we can properly deal with permissions changes and
+// re-opening host files, we should revisit whether or not to make use of the
+// LRU cache.
+func (*superOperations) Keep(*fs.Dirent) bool {
+ return false
+}
+
+func init() {
+ fs.RegisterFilesystem(&Filesystem{})
+}
diff --git a/pkg/sentry/fs/host/host_state_autogen.go b/pkg/sentry/fs/host/host_state_autogen.go
new file mode 100755
index 000000000..22cfa1222
--- /dev/null
+++ b/pkg/sentry/fs/host/host_state_autogen.go
@@ -0,0 +1,142 @@
+// automatically generated by stateify.
+
+package host
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *descriptor) save(m state.Map) {
+ x.beforeSave()
+ m.Save("donated", &x.donated)
+ m.Save("origFD", &x.origFD)
+ m.Save("wouldBlock", &x.wouldBlock)
+}
+
+func (x *descriptor) load(m state.Map) {
+ m.Load("donated", &x.donated)
+ m.Load("origFD", &x.origFD)
+ m.Load("wouldBlock", &x.wouldBlock)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *fileOperations) beforeSave() {}
+func (x *fileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("iops", &x.iops)
+ m.Save("dirCursor", &x.dirCursor)
+}
+
+func (x *fileOperations) afterLoad() {}
+func (x *fileOperations) load(m state.Map) {
+ m.LoadWait("iops", &x.iops)
+ m.Load("dirCursor", &x.dirCursor)
+}
+
+func (x *Filesystem) beforeSave() {}
+func (x *Filesystem) save(m state.Map) {
+ x.beforeSave()
+ m.Save("paths", &x.paths)
+}
+
+func (x *Filesystem) afterLoad() {}
+func (x *Filesystem) load(m state.Map) {
+ m.Load("paths", &x.paths)
+}
+
+func (x *superOperations) beforeSave() {}
+func (x *superOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SimpleMountSourceOperations", &x.SimpleMountSourceOperations)
+ m.Save("root", &x.root)
+ m.Save("inodeMappings", &x.inodeMappings)
+ m.Save("mounter", &x.mounter)
+ m.Save("dontTranslateOwnership", &x.dontTranslateOwnership)
+}
+
+func (x *superOperations) afterLoad() {}
+func (x *superOperations) load(m state.Map) {
+ m.Load("SimpleMountSourceOperations", &x.SimpleMountSourceOperations)
+ m.Load("root", &x.root)
+ m.Load("inodeMappings", &x.inodeMappings)
+ m.Load("mounter", &x.mounter)
+ m.Load("dontTranslateOwnership", &x.dontTranslateOwnership)
+}
+
+func (x *inodeOperations) beforeSave() {}
+func (x *inodeOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("fileState", &x.fileState)
+ m.Save("cachingInodeOps", &x.cachingInodeOps)
+}
+
+func (x *inodeOperations) afterLoad() {}
+func (x *inodeOperations) load(m state.Map) {
+ m.LoadWait("fileState", &x.fileState)
+ m.Load("cachingInodeOps", &x.cachingInodeOps)
+}
+
+func (x *inodeFileState) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.queue) { m.Failf("queue is %v, expected zero", x.queue) }
+ m.Save("mops", &x.mops)
+ m.Save("descriptor", &x.descriptor)
+ m.Save("sattr", &x.sattr)
+ m.Save("savedUAttr", &x.savedUAttr)
+}
+
+func (x *inodeFileState) load(m state.Map) {
+ m.LoadWait("mops", &x.mops)
+ m.LoadWait("descriptor", &x.descriptor)
+ m.LoadWait("sattr", &x.sattr)
+ m.Load("savedUAttr", &x.savedUAttr)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *ConnectedEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("queue", &x.queue)
+ m.Save("path", &x.path)
+ m.Save("ref", &x.ref)
+ m.Save("readClosed", &x.readClosed)
+ m.Save("writeClosed", &x.writeClosed)
+ m.Save("srfd", &x.srfd)
+ m.Save("stype", &x.stype)
+}
+
+func (x *ConnectedEndpoint) load(m state.Map) {
+ m.Load("queue", &x.queue)
+ m.Load("path", &x.path)
+ m.Load("ref", &x.ref)
+ m.Load("readClosed", &x.readClosed)
+ m.Load("writeClosed", &x.writeClosed)
+ m.LoadWait("srfd", &x.srfd)
+ m.Load("stype", &x.stype)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *TTYFileOperations) beforeSave() {}
+func (x *TTYFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("fileOperations", &x.fileOperations)
+ m.Save("session", &x.session)
+ m.Save("fgProcessGroup", &x.fgProcessGroup)
+}
+
+func (x *TTYFileOperations) afterLoad() {}
+func (x *TTYFileOperations) load(m state.Map) {
+ m.Load("fileOperations", &x.fileOperations)
+ m.Load("session", &x.session)
+ m.Load("fgProcessGroup", &x.fgProcessGroup)
+}
+
+func init() {
+ state.Register("host.descriptor", (*descriptor)(nil), state.Fns{Save: (*descriptor).save, Load: (*descriptor).load})
+ state.Register("host.fileOperations", (*fileOperations)(nil), state.Fns{Save: (*fileOperations).save, Load: (*fileOperations).load})
+ state.Register("host.Filesystem", (*Filesystem)(nil), state.Fns{Save: (*Filesystem).save, Load: (*Filesystem).load})
+ state.Register("host.superOperations", (*superOperations)(nil), state.Fns{Save: (*superOperations).save, Load: (*superOperations).load})
+ state.Register("host.inodeOperations", (*inodeOperations)(nil), state.Fns{Save: (*inodeOperations).save, Load: (*inodeOperations).load})
+ state.Register("host.inodeFileState", (*inodeFileState)(nil), state.Fns{Save: (*inodeFileState).save, Load: (*inodeFileState).load})
+ state.Register("host.ConnectedEndpoint", (*ConnectedEndpoint)(nil), state.Fns{Save: (*ConnectedEndpoint).save, Load: (*ConnectedEndpoint).load})
+ state.Register("host.TTYFileOperations", (*TTYFileOperations)(nil), state.Fns{Save: (*TTYFileOperations).save, Load: (*TTYFileOperations).load})
+}
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
new file mode 100644
index 000000000..7a230e426
--- /dev/null
+++ b/pkg/sentry/fs/host/inode.go
@@ -0,0 +1,527 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/secio"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// inodeOperations implements fs.InodeOperations for an fs.Inodes backed
+// by a host file descriptor.
+//
+// +stateify savable
+type inodeOperations struct {
+ fsutil.InodeNotVirtual `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+
+ // fileState implements fs.CachedFileObject. It exists
+ // to break a circular load dependency between inodeOperations
+ // and cachingInodeOps (below).
+ fileState *inodeFileState `state:"wait"`
+
+ // cachedInodeOps implements memmap.Mappable.
+ cachingInodeOps *fsutil.CachingInodeOperations
+
+ // readdirMu protects the file offset on the host FD. This is needed
+ // for readdir because getdents must use the kernel offset, so
+ // concurrent readdirs must be exclusive.
+ //
+ // All read/write functions pass the offset directly to the kernel and
+ // thus don't need a lock.
+ readdirMu sync.Mutex `state:"nosave"`
+}
+
+// inodeFileState implements fs.CachedFileObject and otherwise fully
+// encapsulates state that needs to be manually loaded on restore for
+// this file object.
+//
+// This unfortunate structure exists because fs.CachingInodeOperations
+// defines afterLoad and therefore cannot be lazily loaded (to break a
+// circular load dependency between it and inodeOperations). Even with
+// lazy loading, this approach defines the dependencies between objects
+// and the expected load behavior more concretely.
+//
+// +stateify savable
+type inodeFileState struct {
+ // Common file system state.
+ mops *superOperations `state:"wait"`
+
+ // descriptor is the backing host FD.
+ descriptor *descriptor `state:"wait"`
+
+ // Event queue for blocking operations.
+ queue waiter.Queue `state:"zerovalue"`
+
+ // sattr is used to restore the inodeOperations.
+ sattr fs.StableAttr `state:"wait"`
+
+ // savedUAttr is only allocated during S/R. It points to the save-time
+ // unstable attributes and is used to validate restore-time ones.
+ //
+ // Note that these unstable attributes are only used to detect cross-S/R
+ // external file system metadata changes. They may differ from the
+ // cached unstable attributes in cachingInodeOps, as that might differ
+ // from the external file system attributes if there had been WriteOut
+ // failures. S/R is transparent to Sentry and the latter will continue
+ // using its cached values after restore.
+ savedUAttr *fs.UnstableAttr
+}
+
+// ReadToBlocksAt implements fsutil.CachedFileObject.ReadToBlocksAt.
+func (i *inodeFileState) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
+ // TODO(jamieliu): Using safemem.FromIOReader here is wasteful for two
+ // reasons:
+ //
+ // - Using preadv instead of iterated preads saves on host system calls.
+ //
+ // - Host system calls can handle destination memory that would fault in
+ // gr3 (i.e. they can accept safemem.Blocks with NeedSafecopy() == true),
+ // so the buffering performed by FromIOReader is unnecessary.
+ //
+ // This also applies to the write path below.
+ return safemem.FromIOReader{secio.NewOffsetReader(fd.NewReadWriter(i.FD()), int64(offset))}.ReadToBlocks(dsts)
+}
+
+// WriteFromBlocksAt implements fsutil.CachedFileObject.WriteFromBlocksAt.
+func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
+ return safemem.FromIOWriter{secio.NewOffsetWriter(fd.NewReadWriter(i.FD()), int64(offset))}.WriteFromBlocks(srcs)
+}
+
+// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
+ if mask.Empty() {
+ return nil
+ }
+ if mask.UID || mask.GID {
+ return syserror.EPERM
+ }
+ if mask.Perms {
+ if err := syscall.Fchmod(i.FD(), uint32(attr.Perms.LinuxMode())); err != nil {
+ return err
+ }
+ }
+ if mask.Size {
+ if err := syscall.Ftruncate(i.FD(), attr.Size); err != nil {
+ return err
+ }
+ }
+ if mask.AccessTime || mask.ModificationTime {
+ ts := fs.TimeSpec{
+ ATime: attr.AccessTime,
+ ATimeOmit: !mask.AccessTime,
+ MTime: attr.ModificationTime,
+ MTimeOmit: !mask.ModificationTime,
+ }
+ if err := setTimestamps(i.FD(), ts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Sync implements fsutil.CachedFileObject.Sync.
+func (i *inodeFileState) Sync(ctx context.Context) error {
+ return syscall.Fsync(i.FD())
+}
+
+// FD implements fsutil.CachedFileObject.FD.
+func (i *inodeFileState) FD() int {
+ return i.descriptor.value
+}
+
+func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, error) {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.FD(), &s); err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ return unstableAttr(i.mops, &s), nil
+}
+
+// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error {
+ return syscall.Fallocate(i.FD(), 0, offset, length)
+}
+
+// inodeOperations implements fs.InodeOperations.
+var _ fs.InodeOperations = (*inodeOperations)(nil)
+
+// newInode returns a new fs.Inode backed by the host FD.
+func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool, donated bool) (*fs.Inode, error) {
+ // Retrieve metadata.
+ var s syscall.Stat_t
+ err := syscall.Fstat(fd, &s)
+ if err != nil {
+ return nil, err
+ }
+
+ fileState := &inodeFileState{
+ mops: msrc.MountSourceOperations.(*superOperations),
+ sattr: stableAttr(&s),
+ }
+
+ // Initialize the wrapped host file descriptor.
+ fileState.descriptor, err = newDescriptor(
+ fd,
+ donated,
+ saveable,
+ wouldBlock(&s),
+ &fileState.queue,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Build the fs.InodeOperations.
+ uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s)
+ iops := &inodeOperations{
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, msrc.Flags.ForcePageCache),
+ }
+
+ // Return the fs.Inode.
+ return fs.NewInode(iops, msrc, fileState.sattr), nil
+}
+
+// Mappable implements fs.InodeOperations.Mappable.
+func (i *inodeOperations) Mappable(inode *fs.Inode) memmap.Mappable {
+ if !canMap(inode) {
+ return nil
+ }
+ return i.cachingInodeOps
+}
+
+// ReturnsWouldBlock returns true if this host FD can return EWOULDBLOCK for
+// operations that would block.
+func (i *inodeOperations) ReturnsWouldBlock() bool {
+ return i.fileState.descriptor.wouldBlock
+}
+
+// Release implements fs.InodeOperations.Release.
+func (i *inodeOperations) Release(context.Context) {
+ i.fileState.descriptor.Release()
+ i.cachingInodeOps.Release()
+}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
+ // Get a new FD relative to i at name.
+ fd, err := open(i, name)
+ if err != nil {
+ if err == syserror.ENOENT {
+ return nil, syserror.ENOENT
+ }
+ return nil, err
+ }
+
+ inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */)
+ if err != nil {
+ return nil, err
+ }
+
+ // Return the fs.Dirent.
+ return fs.NewDirent(inode, name), nil
+}
+
+// Create implements fs.InodeOperations.Create.
+func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) {
+ // Create a file relative to i at name.
+ //
+ // N.B. We always open this file O_RDWR regardless of flags because a
+ // future GetFile might want more access. Open allows this regardless
+ // of perm.
+ fd, err := openAt(i, name, syscall.O_RDWR|syscall.O_CREAT|syscall.O_EXCL, perm.LinuxMode())
+ if err != nil {
+ return nil, err
+ }
+
+ inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */)
+ if err != nil {
+ return nil, err
+ }
+
+ d := fs.NewDirent(inode, name)
+ defer d.DecRef()
+ return inode.GetFile(ctx, d, flags)
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ return syscall.Mkdirat(i.fileState.FD(), name, uint32(perm.LinuxMode()))
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+func (i *inodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname string, newname string) error {
+ return createLink(i.fileState.FD(), oldname, newname)
+}
+
+// CreateHardLink implements fs.InodeOperations.CreateHardLink.
+func (*inodeOperations) CreateHardLink(context.Context, *fs.Inode, *fs.Inode, string) error {
+ return syserror.EPERM
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (*inodeOperations) CreateFifo(context.Context, *fs.Inode, string, fs.FilePermissions) error {
+ return syserror.EPERM
+}
+
+// Remove implements fs.InodeOperations.Remove.
+func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error {
+ return unlinkAt(i.fileState.FD(), name, false /* dir */)
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+func (i *inodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error {
+ return unlinkAt(i.fileState.FD(), name, true /* dir */)
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ op, ok := oldParent.InodeOperations.(*inodeOperations)
+ if !ok {
+ return syscall.EXDEV
+ }
+ np, ok := newParent.InodeOperations.(*inodeOperations)
+ if !ok {
+ return syscall.EXDEV
+ }
+ return syscall.Renameat(op.fileState.FD(), oldName, np.fileState.FD(), newName)
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, data transport.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) {
+ return nil, syserror.EOPNOTSUPP
+}
+
+// BoundEndpoint implements fs.InodeOperations.BoundEndpoint.
+func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.BoundEndpoint {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return newFile(ctx, d, flags, i), nil
+}
+
+// canMap returns true if this fs.Inode can be memory mapped.
+func canMap(inode *fs.Inode) bool {
+ // FIXME(b/38213152): Some obscure character devices can be mapped.
+ return fs.IsFile(inode.StableAttr)
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (i *inodeOperations) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ // When the kernel supports mapping host FDs, we do so to take
+ // advantage of the host page cache. We forego updating fs.Inodes
+ // because the host manages consistency of its own inode structures.
+ //
+ // For fs.Inodes that can never be mapped we take advantage of
+ // synchronizing metadata updates through host caches.
+ //
+ // So can we use host kernel metadata caches?
+ if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
+ // Then just obtain the attributes.
+ return i.fileState.unstableAttr(ctx)
+ }
+ // No, we're maintaining consistency of metadata ourselves.
+ return i.cachingInodeOps.UnstableAttr(ctx, inode)
+}
+
+// Check implements fs.InodeOperations.Check.
+func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (i *inodeOperations) SetOwner(context.Context, *fs.Inode, fs.FileOwner) error {
+ return syserror.EPERM
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (i *inodeOperations) SetPermissions(ctx context.Context, inode *fs.Inode, f fs.FilePermissions) bool {
+ // Can we use host kernel metadata caches?
+ if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
+ // Then just change the timestamps on the FD, the host
+ // will synchronize the metadata update with any host
+ // inode and page cache.
+ return syscall.Fchmod(i.fileState.FD(), uint32(f.LinuxMode())) == nil
+ }
+ // Otherwise update our cached metadata.
+ return i.cachingInodeOps.SetPermissions(ctx, inode, f)
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (i *inodeOperations) SetTimestamps(ctx context.Context, inode *fs.Inode, ts fs.TimeSpec) error {
+ // Can we use host kernel metadata caches?
+ if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
+ // Then just change the timestamps on the FD, the host
+ // will synchronize the metadata update with any host
+ // inode and page cache.
+ return setTimestamps(i.fileState.FD(), ts)
+ }
+ // Otherwise update our cached metadata.
+ return i.cachingInodeOps.SetTimestamps(ctx, inode, ts)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, size int64) error {
+ // Is the file not memory-mappable?
+ if !canMap(inode) {
+ // Then just change the file size on the FD, the host
+ // will synchronize the metadata update with any host
+ // inode and page cache.
+ return syscall.Ftruncate(i.fileState.FD(), size)
+ }
+ // Otherwise we need to go through cachingInodeOps, even if the host page
+ // cache is in use, to invalidate private copies of truncated pages.
+ return i.cachingInodeOps.Truncate(ctx, inode, size)
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset, length int64) error {
+ // Is the file not memory-mappable?
+ if !canMap(inode) {
+ // Then just send the call to the FD, the host will synchronize the metadata
+ // update with any host inode and page cache.
+ return i.fileState.Allocate(ctx, offset, length)
+ }
+ // Otherwise we need to go through cachingInodeOps, even if the host page
+ // cache is in use, to invalidate private copies of truncated pages.
+ return i.cachingInodeOps.Allocate(ctx, offset, length)
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ // Have we been using host kernel metadata caches?
+ if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
+ // Then the metadata is already up to date on the host.
+ return nil
+ }
+ // Otherwise we need to write out cached pages and attributes
+ // that are dirty.
+ return i.cachingInodeOps.WriteOut(ctx, inode)
+}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (i *inodeOperations) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ return readLink(i.fileState.FD())
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (i *inodeOperations) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ if !fs.IsSymlink(i.fileState.sattr) {
+ return nil, syserror.ENOLINK
+ }
+ return nil, fs.ErrResolveViaReadlink
+}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (i *inodeOperations) StatFS(context.Context) (fs.Info, error) {
+ return fs.Info{}, syserror.ENOSYS
+}
+
+// AddLink implements fs.InodeOperations.AddLink.
+// FIXME(b/63117438): Remove this from InodeOperations altogether.
+func (i *inodeOperations) AddLink() {}
+
+// DropLink implements fs.InodeOperations.DropLink.
+// FIXME(b/63117438): Remove this from InodeOperations altogether.
+func (i *inodeOperations) DropLink() {}
+
+// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
+// FIXME(b/63117438): Remove this from InodeOperations altogether.
+func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
+
+// readdirAll returns all of the directory entries in i.
+func (i *inodeOperations) readdirAll(d *dirInfo) (map[string]fs.DentAttr, error) {
+ i.readdirMu.Lock()
+ defer i.readdirMu.Unlock()
+
+ fd := i.fileState.FD()
+
+ // syscall.ReadDirent will use getdents, which will seek the file past
+ // the last directory entry. To read the directory entries a second
+ // time, we need to seek back to the beginning.
+ if _, err := syscall.Seek(fd, 0, 0); err != nil {
+ if err == syscall.ESPIPE {
+ // All directories should be seekable. If this file
+ // isn't seekable, it is not a directory and we should
+ // return that more sane error.
+ err = syscall.ENOTDIR
+ }
+ return nil, err
+ }
+
+ names := make([]string, 0, 100)
+ for {
+ // Refill the buffer if necessary
+ if d.bufp >= d.nbuf {
+ d.bufp = 0
+ // ReadDirent will just do a sys_getdents64 to the kernel.
+ n, err := syscall.ReadDirent(fd, d.buf)
+ if err != nil {
+ return nil, err
+ }
+ if n == 0 {
+ break // EOF
+ }
+ d.nbuf = n
+ }
+
+ var nb int
+ // Parse the dirent buffer we just get and return the directory names along
+ // with the number of bytes consumed in the buffer.
+ nb, _, names = syscall.ParseDirent(d.buf[d.bufp:d.nbuf], -1, names)
+ d.bufp += nb
+ }
+
+ entries := make(map[string]fs.DentAttr)
+ for _, filename := range names {
+ // Lookup the type and host device and inode.
+ stat, lerr := fstatat(fd, filename, linux.AT_SYMLINK_NOFOLLOW)
+ if lerr == syscall.ENOENT {
+ // File disappeared between readdir and lstat.
+ // Just treat it as if it didn't exist.
+ continue
+ }
+
+ // There was a serious problem, we should probably report it.
+ if lerr != nil {
+ return nil, lerr
+ }
+
+ entries[filename] = fs.DentAttr{
+ Type: nodeType(&stat),
+ InodeID: hostFileDevice.Map(device.MultiDeviceKey{
+ Device: stat.Dev,
+ Inode: stat.Ino,
+ }),
+ }
+ }
+ return entries, nil
+}
diff --git a/pkg/sentry/fs/host/inode_state.go b/pkg/sentry/fs/host/inode_state.go
new file mode 100644
index 000000000..26cc755bc
--- /dev/null
+++ b/pkg/sentry/fs/host/inode_state.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// beforeSave is invoked by stateify.
+func (i *inodeFileState) beforeSave() {
+ if !i.queue.IsEmpty() {
+ panic("event queue must be empty")
+ }
+ if !i.descriptor.donated && i.sattr.Type == fs.RegularFile {
+ uattr, err := i.unstableAttr(context.Background())
+ if err != nil {
+ panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable atttribute of %s: %v", i.mops.inodeMappings[i.sattr.InodeID], err)})
+ }
+ i.savedUAttr = &uattr
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (i *inodeFileState) afterLoad() {
+ // Initialize the descriptor value.
+ if err := i.descriptor.initAfterLoad(i.mops, i.sattr.InodeID, &i.queue); err != nil {
+ panic(fmt.Sprintf("failed to load value of descriptor: %v", err))
+ }
+
+ // Remap the inode number.
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.FD(), &s); err != nil {
+ panic(fs.ErrCorruption{fmt.Errorf("failed to get metadata for fd %d: %v", i.FD(), err)})
+ }
+ key := device.MultiDeviceKey{
+ Device: s.Dev,
+ Inode: s.Ino,
+ }
+ if !hostFileDevice.Load(key, i.sattr.InodeID) {
+ // This means there was a conflict at s.Dev and s.Ino with
+ // another inode mapping: two files that were unique on the
+ // saved filesystem are no longer unique on this filesystem.
+ // Since this violates the contract that filesystems cannot
+ // change across save and restore, error out.
+ panic(fs.ErrCorruption{fmt.Errorf("host %s conflict in host device mappings: %s", key, hostFileDevice)})
+ }
+
+ if !i.descriptor.donated && i.sattr.Type == fs.RegularFile {
+ env, ok := fs.CurrentRestoreEnvironment()
+ if !ok {
+ panic("missing restore environment")
+ }
+ uattr := unstableAttr(i.mops, &s)
+ if env.ValidateFileSize && uattr.Size != i.savedUAttr.Size {
+ panic(fs.ErrCorruption{fmt.Errorf("file size has changed for %s: previously %d, now %d", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.Size, uattr.Size)})
+ }
+ if env.ValidateFileTimestamp && uattr.ModificationTime != i.savedUAttr.ModificationTime {
+ panic(fs.ErrCorruption{fmt.Errorf("file modification time has changed for %s: previously %v, now %v", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.ModificationTime, uattr.ModificationTime)})
+ }
+ i.savedUAttr = nil
+ }
+}
diff --git a/pkg/sentry/fs/host/ioctl_unsafe.go b/pkg/sentry/fs/host/ioctl_unsafe.go
new file mode 100644
index 000000000..b5a85c4d9
--- /dev/null
+++ b/pkg/sentry/fs/host/ioctl_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+func ioctlGetTermios(fd int) (*linux.Termios, error) {
+ var t linux.Termios
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TCGETS, uintptr(unsafe.Pointer(&t)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &t, nil
+}
+
+func ioctlSetTermios(fd int, req uint64, t *linux.Termios) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(unsafe.Pointer(t)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+func ioctlGetWinsize(fd int) (*linux.Winsize, error) {
+ var w linux.Winsize
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCGWINSZ, uintptr(unsafe.Pointer(&w)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &w, nil
+}
+
+func ioctlSetWinsize(fd int, w *linux.Winsize) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCSWINSZ, uintptr(unsafe.Pointer(w)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
new file mode 100644
index 000000000..3ed137006
--- /dev/null
+++ b/pkg/sentry/fs/host/socket.go
@@ -0,0 +1,390 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/fdnotifier"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control"
+ unixsocket "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
+//
+// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
+const maxSendBufferSize = 8 << 20
+
+// ConnectedEndpoint is a host FD backed implementation of
+// transport.ConnectedEndpoint and transport.Receiver.
+//
+// +stateify savable
+type ConnectedEndpoint struct {
+ queue *waiter.Queue
+ path string
+
+ // ref keeps track of references to a connectedEndpoint.
+ ref refs.AtomicRefCount
+
+ // mu protects fd, readClosed and writeClosed.
+ mu sync.RWMutex `state:"nosave"`
+
+ // file is an *fd.FD containing the FD backing this endpoint. It must be
+ // set to nil if it has been closed.
+ file *fd.FD `state:"nosave"`
+
+ // readClosed is true if the FD has read shutdown or if it has been closed.
+ readClosed bool
+
+ // writeClosed is true if the FD has write shutdown or if it has been
+ // closed.
+ writeClosed bool
+
+ // If srfd >= 0, it is the host FD that file was imported from.
+ srfd int `state:"wait"`
+
+ // stype is the type of Unix socket.
+ stype transport.SockType
+
+ // sndbuf is the size of the send buffer.
+ //
+ // N.B. When this is smaller than the host size, we present it via
+ // GetSockOpt and message splitting/rejection in SendMsg, but do not
+ // prevent lots of small messages from filling the real send buffer
+ // size on the host.
+ sndbuf int `state:"nosave"`
+}
+
+// init performs initialization required for creating new ConnectedEndpoints and
+// for restoring them.
+func (c *ConnectedEndpoint) init() *syserr.Error {
+ family, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_DOMAIN)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if family != syscall.AF_UNIX {
+ // We only allow Unix sockets.
+ return syserr.ErrInvalidEndpointState
+ }
+
+ stype, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if err := syscall.SetNonblock(c.file.FD(), true); err != nil {
+ return syserr.FromError(err)
+ }
+
+ sndbuf, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ if sndbuf > maxSendBufferSize {
+ log.Warningf("Socket send buffer too large: %d", sndbuf)
+ return syserr.ErrInvalidEndpointState
+ }
+
+ c.stype = transport.SockType(stype)
+ c.sndbuf = sndbuf
+
+ return nil
+}
+
+// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host FD
+// that will pretend to be bound at a given sentry path.
+//
+// The caller is responsible for calling Init(). Additionaly, Release needs to
+// be called twice because ConnectedEndpoint is both a transport.Receiver and
+// transport.ConnectedEndpoint.
+func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*ConnectedEndpoint, *syserr.Error) {
+ e := ConnectedEndpoint{
+ path: path,
+ queue: queue,
+ file: file,
+ srfd: -1,
+ }
+
+ if err := e.init(); err != nil {
+ return nil, err
+ }
+
+ // AtomicRefCounters start off with a single reference. We need two.
+ e.ref.IncRef()
+
+ return &e, nil
+}
+
+// Init will do initialization required without holding other locks.
+func (c *ConnectedEndpoint) Init() {
+ if err := fdnotifier.AddFD(int32(c.file.FD()), c.queue); err != nil {
+ panic(err)
+ }
+}
+
+// NewSocketWithDirent allocates a new unix socket with host endpoint.
+//
+// This is currently only used by unsaveable Gofer nodes.
+//
+// NewSocketWithDirent takes ownership of f on success.
+func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.FileFlags) (*fs.File, error) {
+ f2 := fd.New(f.FD())
+ var q waiter.Queue
+ e, err := NewConnectedEndpoint(f2, &q, "" /* path */)
+ if err != nil {
+ f2.Release()
+ return nil, err.ToError()
+ }
+
+ // Take ownship of the FD.
+ f.Release()
+
+ e.Init()
+
+ ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
+
+ return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil
+}
+
+// newSocket allocates a new unix socket with host endpoint.
+func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) {
+ ownedfd := orgfd
+ srfd := -1
+ if saveable {
+ var err error
+ ownedfd, err = syscall.Dup(orgfd)
+ if err != nil {
+ return nil, err
+ }
+ srfd = orgfd
+ }
+ f := fd.New(ownedfd)
+ var q waiter.Queue
+ e, err := NewConnectedEndpoint(f, &q, "" /* path */)
+ if err != nil {
+ if saveable {
+ f.Close()
+ } else {
+ f.Release()
+ }
+ return nil, err.ToError()
+ }
+
+ e.srfd = srfd
+ e.Init()
+
+ ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
+
+ return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil
+}
+
+// Send implements transport.ConnectedEndpoint.Send.
+func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.writeClosed {
+ return 0, false, syserr.ErrClosedForSend
+ }
+
+ if !controlMessages.Empty() {
+ return 0, false, syserr.ErrInvalidEndpointState
+ }
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := c.stype == transport.SockStream
+
+ n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate)
+ if n < totalLen && err == nil {
+ // The host only returns a short write if it would otherwise
+ // block (and only for stream sockets).
+ err = syserror.EAGAIN
+ }
+ if n > 0 && err != syserror.EAGAIN {
+ // The caller may need to block to send more data, but
+ // otherwise there isn't anything that can be done about an
+ // error with a partial write.
+ err = nil
+ }
+
+ // There is no need for the callee to call SendNotify because fdWriteVec
+ // uses the host's sendmsg(2) and the host kernel's queue.
+ return n, false, syserr.FromError(err)
+}
+
+// SendNotify implements transport.ConnectedEndpoint.SendNotify.
+func (c *ConnectedEndpoint) SendNotify() {}
+
+// CloseSend implements transport.ConnectedEndpoint.CloseSend.
+func (c *ConnectedEndpoint) CloseSend() {
+ c.mu.Lock()
+ c.writeClosed = true
+ c.mu.Unlock()
+}
+
+// CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
+func (c *ConnectedEndpoint) CloseNotify() {}
+
+// Writable implements transport.ConnectedEndpoint.Writable.
+func (c *ConnectedEndpoint) Writable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.writeClosed {
+ return true
+ }
+ return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0
+}
+
+// Passcred implements transport.ConnectedEndpoint.Passcred.
+func (c *ConnectedEndpoint) Passcred() bool {
+ // We don't support credential passing for host sockets.
+ return false
+}
+
+// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
+func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil
+}
+
+// EventUpdate implements transport.ConnectedEndpoint.EventUpdate.
+func (c *ConnectedEndpoint) EventUpdate() {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.file.FD() != -1 {
+ fdnotifier.UpdateFD(int32(c.file.FD()))
+ }
+}
+
+// Recv implements transport.Receiver.Recv.
+func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.readClosed {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive
+ }
+
+ var cm unet.ControlMessage
+ if numRights > 0 {
+ cm.EnableFDs(int(numRights))
+ }
+
+ // N.B. Unix sockets don't have a receive buffer, the send buffer
+ // serves both purposes.
+ rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf)
+ if rl > 0 && err != nil {
+ // We got some data, so all we need to do on error is return
+ // the data that we got. Short reads are fine, no need to
+ // block.
+ err = nil
+ }
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+ // There is no need for the callee to call RecvNotify because fdReadVec uses
+ // the host's recvmsg(2) and the host kernel's queue.
+
+ // Trim the control data if we received less than the full amount.
+ if cl < uint64(len(cm)) {
+ cm = cm[:cl]
+ }
+
+ // Avoid extra allocations in the case where there isn't any control data.
+ if len(cm) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
+ }
+
+ fds, err := cm.ExtractFDs()
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+ if len(fds) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
+ }
+ return rl, ml, control.New(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
+}
+
+// close releases all resources related to the endpoint.
+func (c *ConnectedEndpoint) close() {
+ fdnotifier.RemoveFD(int32(c.file.FD()))
+ c.file.Close()
+ c.file = nil
+}
+
+// RecvNotify implements transport.Receiver.RecvNotify.
+func (c *ConnectedEndpoint) RecvNotify() {}
+
+// CloseRecv implements transport.Receiver.CloseRecv.
+func (c *ConnectedEndpoint) CloseRecv() {
+ c.mu.Lock()
+ c.readClosed = true
+ c.mu.Unlock()
+}
+
+// Readable implements transport.Receiver.Readable.
+func (c *ConnectedEndpoint) Readable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.readClosed {
+ return true
+ }
+ return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0
+}
+
+// SendQueuedSize implements transport.Receiver.SendQueuedSize.
+func (c *ConnectedEndpoint) SendQueuedSize() int64 {
+ // SendQueuedSize isn't supported for host sockets because we don't allow the
+ // sentry to call ioctl(2).
+ return -1
+}
+
+// RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
+func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
+ // RecvQueuedSize isn't supported for host sockets because we don't allow the
+ // sentry to call ioctl(2).
+ return -1
+}
+
+// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
+func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
+ return int64(c.sndbuf)
+}
+
+// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
+func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ return int64(c.sndbuf)
+}
+
+// Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release.
+func (c *ConnectedEndpoint) Release() {
+ c.ref.DecRefWithDestructor(c.close)
+}
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
new file mode 100644
index 000000000..5efbb3ae8
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -0,0 +1,113 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// maxIovs is the maximum number of iovecs to pass to the host.
+var maxIovs = linux.UIO_MAXIOV
+
+// copyToMulti copies as many bytes from src to dst as possible.
+func copyToMulti(dst [][]byte, src []byte) {
+ for _, d := range dst {
+ done := copy(d, src)
+ src = src[done:]
+ if len(src) == 0 {
+ break
+ }
+ }
+}
+
+// copyFromMulti copies as many bytes from src to dst as possible.
+func copyFromMulti(dst []byte, src [][]byte) {
+ for _, s := range src {
+ done := copy(dst, s)
+ dst = dst[done:]
+ if len(dst) == 0 {
+ break
+ }
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// If truncate, truncate bufs > maxlen. Otherwise, immediately return an error.
+//
+// If length < the total length of bufs, err indicates why, even when returning
+// a truncated iovec.
+//
+// If intermediate != nil, iovecs references intermediate rather than bufs and
+// the caller must copy to/from bufs as necessary.
+func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovecs []syscall.Iovec, intermediate []byte, err error) {
+ var iovsRequired int
+ for _, b := range bufs {
+ length += uintptr(len(b))
+ if len(b) > 0 {
+ iovsRequired++
+ }
+ }
+
+ stopLen := length
+ if length > uintptr(maxlen) {
+ if truncate {
+ stopLen = uintptr(maxlen)
+ err = syserror.EAGAIN
+ } else {
+ return 0, nil, nil, syserror.EMSGSIZE
+ }
+ }
+
+ if iovsRequired > maxIovs {
+ // The kernel will reject our call if we pass this many iovs.
+ // Use a single intermediate buffer instead.
+ b := make([]byte, stopLen)
+
+ return stopLen, []syscall.Iovec{{
+ Base: &b[0],
+ Len: uint64(stopLen),
+ }}, b, err
+ }
+
+ var total uintptr
+ iovecs = make([]syscall.Iovec, 0, iovsRequired)
+ for i := range bufs {
+ l := len(bufs[i])
+ if l == 0 {
+ continue
+ }
+
+ stop := l
+ if total+uintptr(stop) > stopLen {
+ stop = int(stopLen - total)
+ }
+
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(stop),
+ })
+
+ total += uintptr(stop)
+ if total >= stopLen {
+ break
+ }
+ }
+
+ return total, iovecs, nil, err
+}
diff --git a/pkg/sentry/fs/host/socket_state.go b/pkg/sentry/fs/host/socket_state.go
new file mode 100644
index 000000000..5676c451a
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_state.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+)
+
+// beforeSave is invoked by stateify.
+func (c *ConnectedEndpoint) beforeSave() {
+ if c.srfd < 0 {
+ panic("only host file descriptors provided at sentry startup can be saved")
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (c *ConnectedEndpoint) afterLoad() {
+ f, err := syscall.Dup(c.srfd)
+ if err != nil {
+ panic(fmt.Sprintf("failed to dup restored FD %d: %v", c.srfd, err))
+ }
+ c.file = fd.New(f)
+ if err := c.init(); err != nil {
+ panic(fmt.Sprintf("Could not restore host socket FD %d: %v", c.srfd, err))
+ }
+ c.Init()
+}
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
new file mode 100644
index 000000000..e57be0506
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, controlTrunc bool, err error) {
+ flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
+ if peek {
+ flags |= syscall.MSG_PEEK
+ }
+
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, 0, 0, false, err
+ }
+
+ var msg syscall.Msghdr
+ if len(control) != 0 {
+ msg.Control = &control[0]
+ msg.Controllen = uint64(len(control))
+ }
+
+ if len(iovecs) != 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, 0, 0, false, e
+ }
+
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
+ controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
+
+ if n > length {
+ return length, n, msg.Controllen, controlTrunc, err
+ }
+
+ return n, n, msg.Controllen, controlTrunc, err
+}
+
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
+
+ var msg syscall.Msghdr
+ if len(iovecs) > 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
+ }
+
+ return n, length, err
+}
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
new file mode 100644
index 000000000..e45b339f5
--- /dev/null
+++ b/pkg/sentry/fs/host/tty.go
@@ -0,0 +1,351 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// TTYFileOperations implements fs.FileOperations for a host file descriptor
+// that wraps a TTY FD.
+//
+// +stateify savable
+type TTYFileOperations struct {
+ fileOperations
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // session is the session attached to this TTYFileOperations.
+ session *kernel.Session
+
+ // fgProcessGroup is the foreground process group that is currently
+ // connected to this TTY.
+ fgProcessGroup *kernel.ProcessGroup
+}
+
+// newTTYFile returns a new fs.File that wraps a TTY FD.
+func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
+ return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{
+ fileOperations: fileOperations{iops: iops},
+ })
+}
+
+// InitForegroundProcessGroup sets the foreground process group and session for
+// the TTY. This should only be called once, after the foreground process group
+// has been created, but before it has started running.
+func (t *TTYFileOperations) InitForegroundProcessGroup(pg *kernel.ProcessGroup) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.fgProcessGroup != nil {
+ panic("foreground process group is already set")
+ }
+ t.fgProcessGroup = pg
+ t.session = pg.Session()
+}
+
+// ForegroundProcessGroup returns the foreground process for the TTY.
+func (t *TTYFileOperations) ForegroundProcessGroup() *kernel.ProcessGroup {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.fgProcessGroup
+}
+
+// Read implements fs.FileOperations.Read.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+//
+// See drivers/tty/n_tty.c:n_tty_read()=>job_control().
+func (t *TTYFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileOperations.Read(ctx, file, dst, offset)
+}
+
+// Write implements fs.FileOperations.Write.
+func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the write?
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ return t.fileOperations.Write(ctx, file, src, offset)
+}
+
+// Release implements fs.FileOperations.Release.
+func (t *TTYFileOperations) Release() {
+ t.mu.Lock()
+ t.fgProcessGroup = nil
+ t.mu.Unlock()
+
+ t.fileOperations.Release()
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (t *TTYFileOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Ignore arg[0]. This is the real FD:
+ fd := t.fileOperations.iops.fileState.FD()
+ ioctl := args[1].Uint64()
+ switch ioctl {
+ case linux.TCGETS:
+ termios, err := ioctlGetTermios(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TCSETS, linux.TCSETSW, linux.TCSETSF:
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+
+ var termios linux.Termios
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetTermios(fd, ioctl, &termios)
+ return 0, err
+
+ case linux.TIOCGPGRP:
+ // Args: pid_t *argp
+ // When successful, equivalent to *argp = tcgetpgrp(fd).
+ // Get the process group ID of the foreground process group on
+ // this terminal.
+
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Map the ProcessGroup into a ProcessGroupID in the task's PID
+ // namespace.
+ pgID := pidns.IDOfProcessGroup(t.fgProcessGroup)
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSPGRP:
+ // Args: const pid_t *argp
+ // Equivalent to tcsetpgrp(fd, *argp).
+ // Set the foreground process group ID of this terminal.
+
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check that we are allowed to set the process group.
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from
+ // tty_check_change() to -ENOTTY.
+ if err == syserror.EIO {
+ return 0, syserror.ENOTTY
+ }
+ return 0, err
+ }
+
+ // Check that calling task's process group is in the TTY
+ // session.
+ if task.ThreadGroup().Session() != t.session {
+ return 0, syserror.ENOTTY
+ }
+
+ var pgID kernel.ProcessGroupID
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ // pgID must be non-negative.
+ if pgID < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Process group with pgID must exist in this PID namespace.
+ pidns := task.PIDNamespace()
+ pg := pidns.ProcessGroupWithID(pgID)
+ if pg == nil {
+ return 0, syserror.ESRCH
+ }
+
+ // Check that new process group is in the TTY session.
+ if pg.Session() != t.session {
+ return 0, syserror.EPERM
+ }
+
+ t.fgProcessGroup = pg
+ return 0, nil
+
+ case linux.TIOCGWINSZ:
+ // Args: struct winsize *argp
+ // Get window size.
+ winsize, err := ioctlGetWinsize(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSWINSZ:
+ // Args: const struct winsize *argp
+ // Set window size.
+
+ // Unlike setting the termios, any process group (even
+ // background ones) can set the winsize.
+
+ var winsize linux.Winsize
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetWinsize(fd, &winsize)
+ return 0, err
+
+ // Unimplemented commands.
+ case linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCNOTTY,
+ linux.TIOCSCTTY,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ fallthrough
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// checkChange checks that the process group is allowed to read, write, or
+// change the state of the TTY.
+//
+// This corresponds to Linux drivers/tty/tty_io.c:tty_check_change(). The logic
+// is a bit convoluted, but documented inline.
+//
+// Preconditions: t.mu must be held.
+func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ // No task? Linux does not have an analog for this case, but
+ // tty_check_change is more of a blacklist of cases than a
+ // whitelist, and is surprisingly permissive. Allowing the
+ // change seems most appropriate.
+ return nil
+ }
+
+ tg := task.ThreadGroup()
+ pg := tg.ProcessGroup()
+
+ // If the session for the task is different than the session for the
+ // controlling TTY, then the change is allowed. Seems like a bad idea,
+ // but that's exactly what linux does.
+ if tg.Session() != t.fgProcessGroup.Session() {
+ return nil
+ }
+
+ // If we are the foreground process group, then the change is allowed.
+ if pg == t.fgProcessGroup {
+ return nil
+ }
+
+ // We are not the foreground process group.
+
+ // Is the provided signal blocked or ignored?
+ if (task.SignalMask()&linux.SignalSetOf(sig) != 0) || tg.SignalHandlers().IsIgnored(sig) {
+ // If the signal is SIGTTIN, then we are attempting to read
+ // from the TTY. Don't send the signal and return EIO.
+ if sig == linux.SIGTTIN {
+ return syserror.EIO
+ }
+
+ // Otherwise, we are writing or changing terminal state. This is allowed.
+ return nil
+ }
+
+ // If the process group is an orphan, return EIO.
+ if pg.IsOrphan() {
+ return syserror.EIO
+ }
+
+ // Otherwise, send the signal to the process group and return ERESTARTSYS.
+ //
+ // Note that Linux also unconditionally sets TIF_SIGPENDING on current,
+ // but this isn't necessary in gVisor because the rationale given in
+ // 040b6362d58f "tty: fix leakage of -ERESTARTSYS to userland" doesn't
+ // apply: the sentry will handle -ERESTARTSYS in
+ // kernel.runApp.execute() even if the kernel.Task isn't interrupted.
+ //
+ // Linux ignores the result of kill_pgrp().
+ _ = pg.SendSignal(kernel.SignalInfoPriv(sig))
+ return kernel.ERESTARTSYS
+}
diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go
new file mode 100644
index 000000000..94ff7708e
--- /dev/null
+++ b/pkg/sentry/fs/host/util.go
@@ -0,0 +1,197 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "os"
+ "path"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+func open(parent *inodeOperations, name string) (int, error) {
+ if parent == nil && !path.IsAbs(name) {
+ return -1, syserror.EINVAL
+ }
+ name = path.Clean(name)
+
+ // Don't follow through symlinks.
+ flags := syscall.O_NOFOLLOW
+
+ if fd, err := openAt(parent, name, flags|syscall.O_RDWR, 0); err == nil {
+ return fd, nil
+ }
+ // Retry as read-only.
+ if fd, err := openAt(parent, name, flags|syscall.O_RDONLY, 0); err == nil {
+ return fd, nil
+ }
+
+ // Retry as write-only.
+ if fd, err := openAt(parent, name, flags|syscall.O_WRONLY, 0); err == nil {
+ return fd, nil
+ }
+
+ // Retry as a symlink, by including O_PATH as an option.
+ fd, err := openAt(parent, name, linux.O_PATH|flags, 0)
+ if err == nil {
+ return fd, nil
+ }
+
+ // Everything failed.
+ return -1, err
+}
+
+func openAt(parent *inodeOperations, name string, flags int, perm linux.FileMode) (int, error) {
+ if parent == nil {
+ return syscall.Open(name, flags, uint32(perm))
+ }
+ return syscall.Openat(parent.fileState.FD(), name, flags, uint32(perm))
+}
+
+func nodeType(s *syscall.Stat_t) fs.InodeType {
+ switch x := (s.Mode & syscall.S_IFMT); x {
+ case syscall.S_IFLNK:
+ return fs.Symlink
+ case syscall.S_IFIFO:
+ return fs.Pipe
+ case syscall.S_IFCHR:
+ return fs.CharacterDevice
+ case syscall.S_IFBLK:
+ return fs.BlockDevice
+ case syscall.S_IFSOCK:
+ return fs.Socket
+ case syscall.S_IFDIR:
+ return fs.Directory
+ case syscall.S_IFREG:
+ return fs.RegularFile
+ default:
+ // This shouldn't happen, but just in case...
+ log.Warningf("unknown host file type %d: assuming regular", x)
+ return fs.RegularFile
+ }
+}
+
+func wouldBlock(s *syscall.Stat_t) bool {
+ typ := nodeType(s)
+ return typ == fs.Pipe || typ == fs.Socket || typ == fs.CharacterDevice
+}
+
+func stableAttr(s *syscall.Stat_t) fs.StableAttr {
+ return fs.StableAttr{
+ Type: nodeType(s),
+ DeviceID: hostFileDevice.DeviceID(),
+ InodeID: hostFileDevice.Map(device.MultiDeviceKey{
+ Device: s.Dev,
+ Inode: s.Ino,
+ }),
+ BlockSize: int64(s.Blksize),
+ }
+}
+
+func owner(mo *superOperations, s *syscall.Stat_t) fs.FileOwner {
+ // User requested no translation, just return actual owner.
+ if mo.dontTranslateOwnership {
+ return fs.FileOwner{auth.KUID(s.Uid), auth.KGID(s.Gid)}
+ }
+
+ // Show only IDs relevant to the sandboxed task. I.e. if we not own the
+ // file, no sandboxed task can own the file. In that case, we
+ // use OverflowID for UID, implying that the IDs are not mapped in the
+ // "root" user namespace.
+ //
+ // E.g.
+ // sandbox's host EUID/EGID is 1/1.
+ // some_dir's host UID/GID is 2/1.
+ // Task that mounted this fs has virtualized EUID/EGID 5/5.
+ //
+ // If you executed `ls -n` in the sandboxed task, it would show:
+ // drwxwrxwrx [...] 65534 5 [...] some_dir
+
+ // Files are owned by OverflowID by default.
+ owner := fs.FileOwner{auth.KUID(auth.OverflowUID), auth.KGID(auth.OverflowGID)}
+
+ // If we own file on host, let mounting task's initial EUID own
+ // the file.
+ if s.Uid == hostUID {
+ owner.UID = mo.mounter.UID
+ }
+
+ // If our group matches file's group, make file's group match
+ // the mounting task's initial EGID.
+ for _, gid := range hostGIDs {
+ if s.Gid == gid {
+ owner.GID = mo.mounter.GID
+ break
+ }
+ }
+ return owner
+}
+
+func unstableAttr(mo *superOperations, s *syscall.Stat_t) fs.UnstableAttr {
+ return fs.UnstableAttr{
+ Size: s.Size,
+ Usage: s.Blocks * 512,
+ Perms: fs.FilePermsFromMode(linux.FileMode(s.Mode)),
+ Owner: owner(mo, s),
+ AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec),
+ ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec),
+ StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec),
+ Links: s.Nlink,
+ }
+}
+
+type dirInfo struct {
+ buf []byte // buffer for directory I/O.
+ nbuf int // length of buf; return value from ReadDirent.
+ bufp int // location of next record in buf.
+}
+
+// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
+// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
+func isBlockError(err error) bool {
+ if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK {
+ return true
+ }
+ if pe, ok := err.(*os.PathError); ok {
+ return isBlockError(pe.Err)
+ }
+ return false
+}
+
+func hostEffectiveKIDs() (uint32, []uint32, error) {
+ gids, err := os.Getgroups()
+ if err != nil {
+ return 0, nil, err
+ }
+ egids := make([]uint32, len(gids))
+ for i, gid := range gids {
+ egids[i] = uint32(gid)
+ }
+ return uint32(os.Geteuid()), append(egids, uint32(os.Getegid())), nil
+}
+
+var hostUID uint32
+var hostGIDs []uint32
+
+func init() {
+ hostUID, hostGIDs, _ = hostEffectiveKIDs()
+}
diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go
new file mode 100644
index 000000000..b95a57c3f
--- /dev/null
+++ b/pkg/sentry/fs/host/util_unsafe.go
@@ -0,0 +1,137 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+)
+
+// NulByte is a single NUL byte. It is passed to readlinkat as an empty string.
+var NulByte byte = '\x00'
+
+func createLink(fd int, name string, linkName string) error {
+ namePtr, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return err
+ }
+ linkNamePtr, err := syscall.BytePtrFromString(linkName)
+ if err != nil {
+ return err
+ }
+ _, _, errno := syscall.Syscall(
+ syscall.SYS_SYMLINKAT,
+ uintptr(unsafe.Pointer(namePtr)),
+ uintptr(fd),
+ uintptr(unsafe.Pointer(linkNamePtr)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+func readLink(fd int) (string, error) {
+ // Buffer sizing copied from os.Readlink.
+ for l := 128; ; l *= 2 {
+ b := make([]byte, l)
+ n, _, errno := syscall.Syscall6(
+ syscall.SYS_READLINKAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(&NulByte)), // ""
+ uintptr(unsafe.Pointer(&b[0])),
+ uintptr(l),
+ 0, 0)
+ if errno != 0 {
+ return "", errno
+ }
+ if n < uintptr(l) {
+ return string(b[:n]), nil
+ }
+ }
+}
+
+func unlinkAt(fd int, name string, dir bool) error {
+ namePtr, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return err
+ }
+ var flags uintptr
+ if dir {
+ flags = linux.AT_REMOVEDIR
+ }
+ _, _, errno := syscall.Syscall(
+ syscall.SYS_UNLINKAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(namePtr)),
+ flags,
+ )
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+func timespecFromTimestamp(t ktime.Time, omit, setSysTime bool) syscall.Timespec {
+ if omit {
+ return syscall.Timespec{0, linux.UTIME_OMIT}
+ }
+ if setSysTime {
+ return syscall.Timespec{0, linux.UTIME_NOW}
+ }
+ return syscall.NsecToTimespec(t.Nanoseconds())
+}
+
+func setTimestamps(fd int, ts fs.TimeSpec) error {
+ if ts.ATimeOmit && ts.MTimeOmit {
+ return nil
+ }
+ var sts [2]syscall.Timespec
+ sts[0] = timespecFromTimestamp(ts.ATime, ts.ATimeOmit, ts.ATimeSetSystemTime)
+ sts[1] = timespecFromTimestamp(ts.MTime, ts.MTimeOmit, ts.MTimeSetSystemTime)
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_UTIMENSAT,
+ uintptr(fd),
+ 0, /* path */
+ uintptr(unsafe.Pointer(&sts)),
+ 0, /* flags */
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ namePtr, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return stat, err
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_NEWFSTATAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(namePtr)),
+ uintptr(unsafe.Pointer(&stat)),
+ uintptr(flags),
+ 0, 0)
+ if errno != 0 {
+ return stat, errno
+ }
+ return stat, nil
+}
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
new file mode 100644
index 000000000..aef1a1cb9
--- /dev/null
+++ b/pkg/sentry/fs/inode.go
@@ -0,0 +1,440 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/metric"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/lock"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+var opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.")
+
+// Inode is a file system object that can be simultaneously referenced by different
+// components of the VFS (Dirent, fs.File, etc).
+//
+// +stateify savable
+type Inode struct {
+ // AtomicRefCount is our reference count.
+ refs.AtomicRefCount
+
+ // InodeOperations is the file system specific behavior of the Inode.
+ InodeOperations InodeOperations
+
+ // StableAttr are stable cached attributes of the Inode.
+ StableAttr StableAttr
+
+ // LockCtx is the file lock context. It manages its own sychronization and tracks
+ // regions of the Inode that have locks held.
+ LockCtx LockCtx
+
+ // Watches is the set of inotify watches for this inode.
+ Watches *Watches
+
+ // MountSource is the mount source this Inode is a part of.
+ MountSource *MountSource
+
+ // overlay is the overlay entry for this Inode.
+ overlay *overlayEntry
+}
+
+// LockCtx is an Inode's lock context and contains different personalities of locks; both
+// Posix and BSD style locks are supported.
+//
+// Note that in Linux fcntl(2) and flock(2) locks are _not_ cooperative, because race and
+// deadlock conditions make merging them prohibitive. We do the same and keep them oblivious
+// to each other but provide a "context" as a convenient container.
+//
+// +stateify savable
+type LockCtx struct {
+ // Posix is a set of POSIX-style regional advisory locks, see fcntl(2).
+ Posix lock.Locks
+
+ // BSD is a set of BSD-style advisory file wide locks, see flock(2).
+ BSD lock.Locks
+}
+
+// NewInode constructs an Inode from InodeOperations, a MountSource, and stable attributes.
+//
+// NewInode takes a reference on msrc.
+func NewInode(iops InodeOperations, msrc *MountSource, sattr StableAttr) *Inode {
+ msrc.IncRef()
+ return &Inode{
+ InodeOperations: iops,
+ StableAttr: sattr,
+ Watches: newWatches(),
+ MountSource: msrc,
+ }
+}
+
+// DecRef drops a reference on the Inode.
+func (i *Inode) DecRef() {
+ i.DecRefWithDestructor(i.destroy)
+}
+
+// destroy releases the Inode and releases the msrc reference taken.
+func (i *Inode) destroy() {
+ // FIXME(b/38173783): Context is not plumbed here.
+ ctx := context.Background()
+ if err := i.WriteOut(ctx); err != nil {
+ // FIXME(b/65209558): Mark as warning again once noatime is
+ // properly supported.
+ log.Debugf("Inode %+v, failed to sync all metadata: %v", i.StableAttr, err)
+ }
+
+ // If this inode is being destroyed because it was unlinked, queue a
+ // deletion event. This may not be the case for inodes being revalidated.
+ if i.Watches.unlinked {
+ i.Watches.Notify("", linux.IN_DELETE_SELF, 0)
+ }
+
+ // Remove references from the watch owners to the watches on this inode,
+ // since the watches are about to be GCed. Note that we don't need to worry
+ // about the watch pins since if there were any active pins, this inode
+ // wouldn't be in the destructor.
+ i.Watches.targetDestroyed()
+
+ if i.overlay != nil {
+ i.overlay.release()
+ } else {
+ i.InodeOperations.Release(ctx)
+ }
+
+ i.MountSource.DecRef()
+}
+
+// Mappable calls i.InodeOperations.Mappable.
+func (i *Inode) Mappable() memmap.Mappable {
+ if i.overlay != nil {
+ // In an overlay, Mappable is always implemented by
+ // the overlayEntry metadata to synchronize memory
+ // access of files with copy up. But first check if
+ // the Inodes involved would be mappable in the first
+ // place.
+ i.overlay.copyMu.RLock()
+ ok := i.overlay.isMappableLocked()
+ i.overlay.copyMu.RUnlock()
+ if !ok {
+ return nil
+ }
+ return i.overlay
+ }
+ return i.InodeOperations.Mappable(i)
+}
+
+// WriteOut calls i.InodeOperations.WriteOut with i as the Inode.
+func (i *Inode) WriteOut(ctx context.Context) error {
+ if i.overlay != nil {
+ return overlayWriteOut(ctx, i.overlay)
+ }
+ return i.InodeOperations.WriteOut(ctx, i)
+}
+
+// Lookup calls i.InodeOperations.Lookup with i as the directory.
+func (i *Inode) Lookup(ctx context.Context, name string) (*Dirent, error) {
+ if i.overlay != nil {
+ d, _, err := overlayLookup(ctx, i.overlay, i, name)
+ return d, err
+ }
+ return i.InodeOperations.Lookup(ctx, i, name)
+}
+
+// Create calls i.InodeOperations.Create with i as the directory.
+func (i *Inode) Create(ctx context.Context, d *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) {
+ if i.overlay != nil {
+ return overlayCreate(ctx, i.overlay, d, name, flags, perm)
+ }
+ return i.InodeOperations.Create(ctx, i, name, flags, perm)
+}
+
+// CreateDirectory calls i.InodeOperations.CreateDirectory with i as the directory.
+func (i *Inode) CreateDirectory(ctx context.Context, d *Dirent, name string, perm FilePermissions) error {
+ if i.overlay != nil {
+ return overlayCreateDirectory(ctx, i.overlay, d, name, perm)
+ }
+ return i.InodeOperations.CreateDirectory(ctx, i, name, perm)
+}
+
+// CreateLink calls i.InodeOperations.CreateLink with i as the directory.
+func (i *Inode) CreateLink(ctx context.Context, d *Dirent, oldname string, newname string) error {
+ if i.overlay != nil {
+ return overlayCreateLink(ctx, i.overlay, d, oldname, newname)
+ }
+ return i.InodeOperations.CreateLink(ctx, i, oldname, newname)
+}
+
+// CreateHardLink calls i.InodeOperations.CreateHardLink with i as the directory.
+func (i *Inode) CreateHardLink(ctx context.Context, d *Dirent, target *Dirent, name string) error {
+ if i.overlay != nil {
+ return overlayCreateHardLink(ctx, i.overlay, d, target, name)
+ }
+ return i.InodeOperations.CreateHardLink(ctx, i, target.Inode, name)
+}
+
+// CreateFifo calls i.InodeOperations.CreateFifo with i as the directory.
+func (i *Inode) CreateFifo(ctx context.Context, d *Dirent, name string, perm FilePermissions) error {
+ if i.overlay != nil {
+ return overlayCreateFifo(ctx, i.overlay, d, name, perm)
+ }
+ return i.InodeOperations.CreateFifo(ctx, i, name, perm)
+}
+
+// Remove calls i.InodeOperations.Remove/RemoveDirectory with i as the directory.
+func (i *Inode) Remove(ctx context.Context, d *Dirent, remove *Dirent) error {
+ if i.overlay != nil {
+ return overlayRemove(ctx, i.overlay, d, remove)
+ }
+ switch remove.Inode.StableAttr.Type {
+ case Directory, SpecialDirectory:
+ return i.InodeOperations.RemoveDirectory(ctx, i, remove.name)
+ default:
+ return i.InodeOperations.Remove(ctx, i, remove.name)
+ }
+}
+
+// Rename calls i.InodeOperations.Rename with the given arguments.
+func (i *Inode) Rename(ctx context.Context, oldParent *Dirent, renamed *Dirent, newParent *Dirent, newName string, replacement bool) error {
+ if i.overlay != nil {
+ return overlayRename(ctx, i.overlay, oldParent, renamed, newParent, newName, replacement)
+ }
+ return i.InodeOperations.Rename(ctx, renamed.Inode, oldParent.Inode, renamed.name, newParent.Inode, newName, replacement)
+}
+
+// Bind calls i.InodeOperations.Bind with i as the directory.
+func (i *Inode) Bind(ctx context.Context, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
+ if i.overlay != nil {
+ return overlayBind(ctx, i.overlay, name, data, perm)
+ }
+ return i.InodeOperations.Bind(ctx, i, name, data, perm)
+}
+
+// BoundEndpoint calls i.InodeOperations.BoundEndpoint with i as the Inode.
+func (i *Inode) BoundEndpoint(path string) transport.BoundEndpoint {
+ if i.overlay != nil {
+ return overlayBoundEndpoint(i.overlay, path)
+ }
+ return i.InodeOperations.BoundEndpoint(i, path)
+}
+
+// GetFile calls i.InodeOperations.GetFile with the given arguments.
+func (i *Inode) GetFile(ctx context.Context, d *Dirent, flags FileFlags) (*File, error) {
+ if i.overlay != nil {
+ return overlayGetFile(ctx, i.overlay, d, flags)
+ }
+ opens.Increment()
+ return i.InodeOperations.GetFile(ctx, d, flags)
+}
+
+// UnstableAttr calls i.InodeOperations.UnstableAttr with i as the Inode.
+func (i *Inode) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
+ if i.overlay != nil {
+ return overlayUnstableAttr(ctx, i.overlay)
+ }
+ return i.InodeOperations.UnstableAttr(ctx, i)
+}
+
+// Getxattr calls i.InodeOperations.Getxattr with i as the Inode.
+func (i *Inode) Getxattr(name string) (string, error) {
+ if i.overlay != nil {
+ return overlayGetxattr(i.overlay, name)
+ }
+ return i.InodeOperations.Getxattr(i, name)
+}
+
+// Listxattr calls i.InodeOperations.Listxattr with i as the Inode.
+func (i *Inode) Listxattr() (map[string]struct{}, error) {
+ if i.overlay != nil {
+ return overlayListxattr(i.overlay)
+ }
+ return i.InodeOperations.Listxattr(i)
+}
+
+// CheckPermission will check if the caller may access this file in the
+// requested way for reading, writing, or executing.
+//
+// CheckPermission is like Linux's fs/namei.c:inode_permission. It
+// - checks file system mount flags,
+// - and utilizes InodeOperations.Check to check capabilities and modes.
+func (i *Inode) CheckPermission(ctx context.Context, p PermMask) error {
+ // First check the outer-most mounted filesystem.
+ if p.Write && i.MountSource.Flags.ReadOnly {
+ return syserror.EROFS
+ }
+
+ if i.overlay != nil {
+ // CheckPermission requires some special handling for
+ // an overlay.
+ //
+ // Writes will always be redirected to an upper filesystem,
+ // so ignore all lower layers being read-only.
+ //
+ // But still honor the upper-most filesystem's mount flags;
+ // we should not attempt to modify the writable layer if it
+ // is mounted read-only.
+ if p.Write && overlayUpperMountSource(i.MountSource).Flags.ReadOnly {
+ return syserror.EROFS
+ }
+ }
+
+ return i.check(ctx, p)
+}
+
+func (i *Inode) check(ctx context.Context, p PermMask) error {
+ if i.overlay != nil {
+ return overlayCheck(ctx, i.overlay, p)
+ }
+ if !i.InodeOperations.Check(ctx, i, p) {
+ return syserror.EACCES
+ }
+ return nil
+}
+
+// SetPermissions calls i.InodeOperations.SetPermissions with i as the Inode.
+func (i *Inode) SetPermissions(ctx context.Context, d *Dirent, f FilePermissions) bool {
+ if i.overlay != nil {
+ return overlaySetPermissions(ctx, i.overlay, d, f)
+ }
+ return i.InodeOperations.SetPermissions(ctx, i, f)
+}
+
+// SetOwner calls i.InodeOperations.SetOwner with i as the Inode.
+func (i *Inode) SetOwner(ctx context.Context, d *Dirent, o FileOwner) error {
+ if i.overlay != nil {
+ return overlaySetOwner(ctx, i.overlay, d, o)
+ }
+ return i.InodeOperations.SetOwner(ctx, i, o)
+}
+
+// SetTimestamps calls i.InodeOperations.SetTimestamps with i as the Inode.
+func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error {
+ if i.overlay != nil {
+ return overlaySetTimestamps(ctx, i.overlay, d, ts)
+ }
+ return i.InodeOperations.SetTimestamps(ctx, i, ts)
+}
+
+// Truncate calls i.InodeOperations.Truncate with i as the Inode.
+func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error {
+ if i.overlay != nil {
+ return overlayTruncate(ctx, i.overlay, d, size)
+ }
+ return i.InodeOperations.Truncate(ctx, i, size)
+}
+
+func (i *Inode) Allocate(ctx context.Context, d *Dirent, offset int64, length int64) error {
+ if i.overlay != nil {
+ return overlayAllocate(ctx, i.overlay, d, offset, length)
+ }
+ return i.InodeOperations.Allocate(ctx, i, offset, length)
+}
+
+// Readlink calls i.InodeOperations.Readlnk with i as the Inode.
+func (i *Inode) Readlink(ctx context.Context) (string, error) {
+ if i.overlay != nil {
+ return overlayReadlink(ctx, i.overlay)
+ }
+ return i.InodeOperations.Readlink(ctx, i)
+}
+
+// Getlink calls i.InodeOperations.Getlink.
+func (i *Inode) Getlink(ctx context.Context) (*Dirent, error) {
+ if i.overlay != nil {
+ return overlayGetlink(ctx, i.overlay)
+ }
+ return i.InodeOperations.Getlink(ctx, i)
+}
+
+// AddLink calls i.InodeOperations.AddLink.
+func (i *Inode) AddLink() {
+ if i.overlay != nil {
+ // FIXME(b/63117438): Remove this from InodeOperations altogether.
+ //
+ // This interface is only used by ramfs to update metadata of
+ // children. These filesystems should _never_ have overlay
+ // Inodes cached as children. So explicitly disallow this
+ // scenario and avoid plumbing Dirents through to do copy up.
+ panic("overlay Inodes cached in ramfs directories are not supported")
+ }
+ i.InodeOperations.AddLink()
+}
+
+// DropLink calls i.InodeOperations.DropLink.
+func (i *Inode) DropLink() {
+ if i.overlay != nil {
+ // Same as AddLink.
+ panic("overlay Inodes cached in ramfs directories are not supported")
+ }
+ i.InodeOperations.DropLink()
+}
+
+// IsVirtual calls i.InodeOperations.IsVirtual.
+func (i *Inode) IsVirtual() bool {
+ if i.overlay != nil {
+ // An overlay configuration does not support virtual files.
+ return false
+ }
+ return i.InodeOperations.IsVirtual()
+}
+
+// StatFS calls i.InodeOperations.StatFS.
+func (i *Inode) StatFS(ctx context.Context) (Info, error) {
+ if i.overlay != nil {
+ return overlayStatFS(ctx, i.overlay)
+ }
+ return i.InodeOperations.StatFS(ctx)
+}
+
+// CheckOwnership checks whether `ctx` owns this Inode or may act as its owner.
+// Compare Linux's fs/inode.c:inode_owner_or_capable().
+func (i *Inode) CheckOwnership(ctx context.Context) bool {
+ uattr, err := i.UnstableAttr(ctx)
+ if err != nil {
+ return false
+ }
+ creds := auth.CredentialsFromContext(ctx)
+ if uattr.Owner.UID == creds.EffectiveKUID {
+ return true
+ }
+ if creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(uattr.Owner.UID).Ok() {
+ return true
+ }
+ return false
+}
+
+// CheckCapability checks whether `ctx` has capability `cp` with respect to
+// operations on this Inode.
+//
+// Compare Linux's kernel/capability.c:capable_wrt_inode_uidgid().
+func (i *Inode) CheckCapability(ctx context.Context, cp linux.Capability) bool {
+ uattr, err := i.UnstableAttr(ctx)
+ if err != nil {
+ return false
+ }
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.UserNamespace.MapFromKUID(uattr.Owner.UID).Ok() {
+ return false
+ }
+ if !creds.UserNamespace.MapFromKGID(uattr.Owner.GID).Ok() {
+ return false
+ }
+ return creds.HasCapability(cp)
+}
diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go
new file mode 100644
index 000000000..0f2a66a79
--- /dev/null
+++ b/pkg/sentry/fs/inode_inotify.go
@@ -0,0 +1,169 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "sync"
+)
+
+// Watches is the collection of inotify watches on an inode.
+//
+// +stateify savable
+type Watches struct {
+ // mu protects the fields below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // ws is the map of active watches in this collection, keyed by the inotify
+ // instance id of the owner.
+ ws map[uint64]*Watch
+
+ // unlinked indicates whether the target inode was ever unlinked. This is a
+ // hack to figure out if we should queue a IN_DELETE_SELF event when this
+ // watches collection is being destroyed, since otherwise we have no way of
+ // knowing if the target inode is going down due to a deletion or
+ // revalidation.
+ unlinked bool
+}
+
+func newWatches() *Watches {
+ return &Watches{}
+}
+
+// MarkUnlinked indicates the target for this set of watches to be unlinked.
+// This has implications for the IN_EXCL_UNLINK flag.
+func (w *Watches) MarkUnlinked() {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ w.unlinked = true
+}
+
+// Lookup returns a matching watch with the given id. Returns nil if no such
+// watch exists. Note that the result returned by this method only remains valid
+// if the inotify instance owning the watch is locked, preventing modification
+// of the returned watch and preventing the replacement of the watch by another
+// one from the same instance (since there may be at most one watch per
+// instance, per target).
+func (w *Watches) Lookup(id uint64) *Watch {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.ws[id]
+}
+
+// Add adds watch into this set of watches. The watch being added must be unique
+// - its ID() should not collide with any existing watches.
+func (w *Watches) Add(watch *Watch) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Sanity check, the new watch shouldn't collide with an existing
+ // watch. Silently replacing an existing watch would result in a ref leak on
+ // this inode. We could handle this collision by calling Unpin() on the
+ // existing watch, but then we end up leaking watch descriptor ids at the
+ // inotify level.
+ if _, exists := w.ws[watch.ID()]; exists {
+ panic(fmt.Sprintf("Watch collision with ID %+v", watch.ID()))
+ }
+ if w.ws == nil {
+ w.ws = make(map[uint64]*Watch)
+ }
+ w.ws[watch.ID()] = watch
+}
+
+// Remove removes a watch with the given id from this set of watches. The caller
+// is responsible for generating any watch removal event, as appropriate. The
+// provided id must match an existing watch in this collection.
+func (w *Watches) Remove(id uint64) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ if w.ws == nil {
+ // This watch set is being destroyed. The thread executing the
+ // destructor is already in the process of deleting all our watches. We
+ // got here with no refs on the inode because we raced with the
+ // destructor notifying all the watch owners of the inode's destruction.
+ // See the comment in Watches.TargetDestroyed for why this race exists.
+ return
+ }
+
+ watch, ok := w.ws[id]
+ if !ok {
+ // While there's technically no problem with silently ignoring a missing
+ // watch, this is almost certainly a bug.
+ panic(fmt.Sprintf("Attempt to remove a watch, but no watch found with provided id %+v.", id))
+ }
+ delete(w.ws, watch.ID())
+}
+
+// Notify queues a new event with all watches in this set.
+func (w *Watches) Notify(name string, events, cookie uint32) {
+ // N.B. We don't defer the unlocks because Notify is in the hot path of
+ // all IO operations, and the defer costs too much for small IO
+ // operations.
+ w.mu.RLock()
+ for _, watch := range w.ws {
+ if name != "" && w.unlinked && !watch.NotifyParentAfterUnlink() {
+ // IN_EXCL_UNLINK - By default, when watching events on the children
+ // of a directory, events are generated for children even after they
+ // have been unlinked from the directory. This can result in large
+ // numbers of uninteresting events for some applications (e.g., if
+ // watching /tmp, in which many applications create temporary files
+ // whose names are immediately unlinked). Specifying IN_EXCL_UNLINK
+ // changes the default behavior, so that events are not generated
+ // for children after they have been unlinked from the watched
+ // directory. -- inotify(7)
+ //
+ // We know we're dealing with events for a parent when the name
+ // isn't empty.
+ continue
+ }
+ watch.Notify(name, events, cookie)
+ }
+ w.mu.RUnlock()
+}
+
+// Unpin unpins dirent from all watches in this set.
+func (w *Watches) Unpin(d *Dirent) {
+ w.mu.RLock()
+ defer w.mu.RUnlock()
+ for _, watch := range w.ws {
+ watch.Unpin(d)
+ }
+}
+
+// targetDestroyed is called by the inode destructor to notify the watch owners
+// of the impending destruction of the watch target.
+func (w *Watches) targetDestroyed() {
+ var ws map[uint64]*Watch
+
+ // We can't hold w.mu while calling watch.TargetDestroyed to preserve lock
+ // ordering w.r.t to the owner inotify instances. Instead, atomically move
+ // the watches map into a local variable so we can iterate over it safely.
+ //
+ // Because of this however, it is possible for the watches' owners to reach
+ // this inode while the inode has no refs. This is still safe because the
+ // owners can only reach the inode until this function finishes calling
+ // watch.TargetDestroyed() below and the inode is guaranteed to exist in the
+ // meanwhile. But we still have to be very careful not to rely on inode
+ // state that may have been already destroyed.
+ w.mu.Lock()
+ ws = w.ws
+ w.ws = nil
+ w.mu.Unlock()
+
+ for _, watch := range ws {
+ watch.TargetDestroyed()
+ }
+}
diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go
new file mode 100644
index 000000000..ea089dfae
--- /dev/null
+++ b/pkg/sentry/fs/inode_operations.go
@@ -0,0 +1,308 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "errors"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+var (
+ // ErrResolveViaReadlink is a special error value returned by
+ // InodeOperations.Getlink() to indicate that a link should be
+ // resolved automatically by walking to the path returned by
+ // InodeOperations.Readlink().
+ ErrResolveViaReadlink = errors.New("link should be resolved via Readlink()")
+)
+
+// TimeSpec contains access and modification timestamps. If either ATimeOmit or
+// MTimeOmit is true, then the corresponding timestamp should not be updated.
+// If either ATimeSetSystemTime or MTimeSetSystemTime are set then the
+// corresponding timestamp should be ignored and the time will be set to the
+// current system time.
+type TimeSpec struct {
+ ATime ktime.Time
+ ATimeOmit bool
+ ATimeSetSystemTime bool
+ MTime ktime.Time
+ MTimeOmit bool
+ MTimeSetSystemTime bool
+}
+
+// InodeOperations are operations on an Inode that diverge per file system.
+//
+// Objects that implement InodeOperations may cache file system "private"
+// data that is useful for implementing these methods. In contrast, Inode
+// contains state that is common to all Inodes; this state may be optionally
+// used by InodeOperations. An object that implements InodeOperations may
+// not take a reference on an Inode.
+type InodeOperations interface {
+ // Release releases all private file system data held by this object.
+ // Once Release is called, this object is dead (no other methods will
+ // ever be called).
+ Release(context.Context)
+
+ // Lookup loads an Inode at name under dir into a Dirent. The name
+ // is a valid component path: it contains no "/"s nor is the empty
+ // string.
+ //
+ // Lookup may return one of:
+ //
+ // * A nil Dirent and a non-nil error. If the reason that Lookup failed
+ // was because the name does not exist under Inode, then must return
+ // syserror.ENOENT.
+ //
+ // * If name does not exist under dir and the file system wishes this
+ // fact to be cached, a non-nil Dirent containing a nil Inode and a
+ // nil error. This is a negative Dirent and must have exactly one
+ // reference (at-construction reference).
+ //
+ // * If name does exist under this dir, a non-nil Dirent containing a
+ // non-nil Inode, and a nil error. File systems that take extra
+ // references on this Dirent should implement DirentOperations.
+ Lookup(ctx context.Context, dir *Inode, name string) (*Dirent, error)
+
+ // Create creates an Inode at name under dir and returns a new File
+ // whose Dirent backs the new Inode. Implementations must ensure that
+ // name does not already exist. Create may return one of:
+ //
+ // * A nil File and a non-nil error.
+ //
+ // * A non-nil File and a nil error. File.Dirent will be a new Dirent,
+ // with a single reference held by File. File systems that take extra
+ // references on this Dirent should implement DirentOperations.
+ //
+ // The caller must ensure that this operation is permitted.
+ Create(ctx context.Context, dir *Inode, name string, flags FileFlags, perm FilePermissions) (*File, error)
+
+ // CreateDirectory creates a new directory under this dir.
+ // CreateDirectory should otherwise do the same as Create.
+ //
+ // The caller must ensure that this operation is permitted.
+ CreateDirectory(ctx context.Context, dir *Inode, name string, perm FilePermissions) error
+
+ // CreateLink creates a symbolic link under dir between newname
+ // and oldname. CreateLink should otherwise do the same as Create.
+ //
+ // The caller must ensure that this operation is permitted.
+ CreateLink(ctx context.Context, dir *Inode, oldname string, newname string) error
+
+ // CreateHardLink creates a hard link under dir between the target
+ // Inode and name.
+ //
+ // The caller must ensure this operation is permitted.
+ CreateHardLink(ctx context.Context, dir *Inode, target *Inode, name string) error
+
+ // CreateFifo creates a new named pipe under dir at name.
+ //
+ // The caller must ensure that this operation is permitted.
+ CreateFifo(ctx context.Context, dir *Inode, name string, perm FilePermissions) error
+
+ // Remove removes the given named non-directory under dir.
+ //
+ // The caller must ensure that this operation is permitted.
+ Remove(ctx context.Context, dir *Inode, name string) error
+
+ // RemoveDirectory removes the given named directory under dir.
+ //
+ // The caller must ensure that this operation is permitted.
+ //
+ // RemoveDirectory should check that the directory to be
+ // removed is empty.
+ RemoveDirectory(ctx context.Context, dir *Inode, name string) error
+
+ // Rename atomically renames oldName under oldParent to newName under
+ // newParent where oldParent and newParent are directories. inode is
+ // the Inode of this InodeOperations.
+ //
+ // If replacement is true, then newName already exists and this call
+ // will replace it with oldName.
+ //
+ // Implementations are responsible for rejecting renames that replace
+ // non-empty directories.
+ Rename(ctx context.Context, inode *Inode, oldParent *Inode, oldName string, newParent *Inode, newName string, replacement bool) error
+
+ // Bind binds a new socket under dir at the given name.
+ //
+ // The caller must ensure that this operation is permitted.
+ Bind(ctx context.Context, dir *Inode, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error)
+
+ // BoundEndpoint returns the socket endpoint at path stored in
+ // or generated by an Inode.
+ //
+ // The path is only relevant for generated endpoint because stored
+ // endpoints already know their path. It is ok for the endpoint to
+ // hold onto their path because the only way to change a bind
+ // address is to rebind the socket.
+ //
+ // This is valid iff the type of the Inode is a Socket, which
+ // generally implies that this Inode was created via CreateSocket.
+ //
+ // If there is no socket endpoint available, nil will be returned.
+ BoundEndpoint(inode *Inode, path string) transport.BoundEndpoint
+
+ // GetFile returns a new open File backed by a Dirent and FileFlags.
+ //
+ // Special Inode types may block using ctx.Sleeper. RegularFiles,
+ // Directories, and Symlinks must not block (see doCopyUp).
+ //
+ // The returned File will uniquely back an application fd.
+ GetFile(ctx context.Context, d *Dirent, flags FileFlags) (*File, error)
+
+ // UnstableAttr returns the most up-to-date "unstable" attributes of
+ // an Inode, where "unstable" means that they change in response to
+ // file system events.
+ UnstableAttr(ctx context.Context, inode *Inode) (UnstableAttr, error)
+
+ // Getxattr retrieves the value of extended attribute name. Inodes that
+ // do not support extended attributes return EOPNOTSUPP. Inodes that
+ // support extended attributes but don't have a value at name return
+ // ENODATA.
+ Getxattr(inode *Inode, name string) (string, error)
+
+ // Setxattr sets the value of extended attribute name. Inodes that
+ // do not support extended attributes return EOPNOTSUPP.
+ Setxattr(inode *Inode, name, value string) error
+
+ // Listxattr returns the set of all extended attributes names that
+ // have values. Inodes that do not support extended attributes return
+ // EOPNOTSUPP.
+ Listxattr(inode *Inode) (map[string]struct{}, error)
+
+ // Check determines whether an Inode can be accessed with the
+ // requested permission mask using the context (which gives access
+ // to Credentials and UserNamespace).
+ Check(ctx context.Context, inode *Inode, p PermMask) bool
+
+ // SetPermissions sets new permissions for an Inode. Returns false
+ // if it was not possible to set the new permissions.
+ //
+ // The caller must ensure that this operation is permitted.
+ SetPermissions(ctx context.Context, inode *Inode, f FilePermissions) bool
+
+ // SetOwner sets the ownership for this file.
+ //
+ // If either UID or GID are set to auth.NoID, its value will not be
+ // changed.
+ //
+ // The caller must ensure that this operation is permitted.
+ SetOwner(ctx context.Context, inode *Inode, owner FileOwner) error
+
+ // SetTimestamps sets the access and modification timestamps of an
+ // Inode according to the access and modification times in the TimeSpec.
+ //
+ // If either ATimeOmit or MTimeOmit is set, then the corresponding
+ // timestamp is not updated.
+ //
+ // If either ATimeSetSystemTime or MTimeSetSystemTime is true, that
+ // timestamp is set to the current time instead.
+ //
+ // The caller must ensure that this operation is permitted.
+ SetTimestamps(ctx context.Context, inode *Inode, ts TimeSpec) error
+
+ // Truncate changes the size of an Inode. Truncate should not check
+ // permissions internally, as it is used for both sys_truncate and
+ // sys_ftruncate.
+ //
+ // Implementations need not check that length >= 0.
+ Truncate(ctx context.Context, inode *Inode, size int64) error
+
+ // Allocate allows the caller to reserve disk space for the inode.
+ // It's equivalent to fallocate(2) with 'mode=0'.
+ Allocate(ctx context.Context, inode *Inode, offset int64, length int64) error
+
+ // WriteOut writes cached Inode state to a backing filesystem in a
+ // synchronous manner.
+ //
+ // File systems that do not cache metadata or data via an Inode
+ // implement WriteOut as a no-op. File systems that are entirely in
+ // memory also implement WriteOut as a no-op. Otherwise file systems
+ // call Inode.Sync to write back page cached data and cached metadata
+ // followed by syncing writeback handles.
+ //
+ // It derives from include/linux/fs.h:super_operations->write_inode.
+ WriteOut(ctx context.Context, inode *Inode) error
+
+ // Readlink reads the symlink path of an Inode.
+ //
+ // Readlink is permitted to return a different path depending on ctx,
+ // the request originator.
+ //
+ // The caller must ensure that this operation is permitted.
+ //
+ // Readlink should check that Inode is a symlink and its content is
+ // at least readable.
+ Readlink(ctx context.Context, inode *Inode) (string, error)
+
+ // Getlink resolves a symlink to a target *Dirent.
+ //
+ // Filesystems that can resolve the link by walking to the path returned
+ // by Readlink should return (nil, ErrResolveViaReadlink), which
+ // triggers link resolution via Realink and Lookup.
+ //
+ // Some links cannot be followed by Lookup. In this case, Getlink can
+ // return the Dirent of the link target. The caller holds a reference
+ // to the Dirent. Filesystems that return a non-nil *Dirent from Getlink
+ // cannot participate in an overlay because it is impossible for the
+ // overlay to ascertain whether or not the *Dirent should contain an
+ // overlayEntry.
+ //
+ // Any error returned from Getlink other than ErrResolveViaReadlink
+ // indicates the caller's inability to traverse this Inode as a link
+ // (e.g. syserror.ENOLINK indicates that the Inode is not a link,
+ // syscall.EPERM indicates that traversing the link is not allowed, etc).
+ Getlink(context.Context, *Inode) (*Dirent, error)
+
+ // Mappable returns a memmap.Mappable that provides memory mappings of the
+ // Inode's data. Mappable may return nil if this is not supported. The
+ // returned Mappable must remain valid until InodeOperations.Release is
+ // called.
+ Mappable(*Inode) memmap.Mappable
+
+ // The below methods require cleanup.
+
+ // AddLink increments the hard link count of an Inode.
+ //
+ // Remove in favor of Inode.IncLink.
+ AddLink()
+
+ // DropLink decrements the hard link count of an Inode.
+ //
+ // Remove in favor of Inode.DecLink.
+ DropLink()
+
+ // NotifyStatusChange sets the status change time to the current time.
+ //
+ // Remove in favor of updating the Inode's cached status change time.
+ NotifyStatusChange(ctx context.Context)
+
+ // IsVirtual indicates whether or not this corresponds to a virtual
+ // resource.
+ //
+ // If IsVirtual returns true, then caching will be disabled for this
+ // node, and fs.Dirent.Freeze() will not stop operations on the node.
+ //
+ // Remove in favor of freezing specific mounts.
+ IsVirtual() bool
+
+ // StatFS returns a filesystem Info implementation or an error. If
+ // the filesystem does not support this operation (maybe in the future
+ // it will), then ENOSYS should be returned.
+ StatFS(context.Context) (Info, error)
+}
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
new file mode 100644
index 000000000..cdffe173b
--- /dev/null
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -0,0 +1,676 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+func overlayHasWhiteout(parent *Inode, name string) bool {
+ s, err := parent.Getxattr(XattrOverlayWhiteout(name))
+ return err == nil && s == "y"
+}
+
+func overlayCreateWhiteout(parent *Inode, name string) error {
+ return parent.InodeOperations.Setxattr(parent, XattrOverlayWhiteout(name), "y")
+}
+
+func overlayWriteOut(ctx context.Context, o *overlayEntry) error {
+ // Hot path. Avoid defers.
+ var err error
+ o.copyMu.RLock()
+ if o.upper != nil {
+ err = o.upper.InodeOperations.WriteOut(ctx, o.upper)
+ }
+ o.copyMu.RUnlock()
+ return err
+}
+
+// overlayLookup performs a lookup in parent.
+//
+// If name exists, it returns true if the Dirent is in the upper, false if the
+// Dirent is in the lower.
+func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name string) (*Dirent, bool, error) {
+ // Hot path. Avoid defers.
+ parent.copyMu.RLock()
+
+ // Assert that there is at least one upper or lower entry.
+ if parent.upper == nil && parent.lower == nil {
+ parent.copyMu.RUnlock()
+ panic("invalid overlayEntry, needs at least one Inode")
+ }
+
+ var upperInode *Inode
+ var lowerInode *Inode
+
+ // We must remember whether the upper fs returned a negative dirent,
+ // because it is only safe to return one if the upper did.
+ var negativeUpperChild bool
+
+ // Does the parent directory exist in the upper file system?
+ if parent.upper != nil {
+ // First check if a file object exists in the upper file system.
+ // A file could have been created over a whiteout, so we need to
+ // check if something exists in the upper file system first.
+ child, err := parent.upper.Lookup(ctx, name)
+ if err != nil && err != syserror.ENOENT {
+ // We encountered an error that an overlay cannot handle,
+ // we must propagate it to the caller.
+ parent.copyMu.RUnlock()
+ return nil, false, err
+ }
+ if child != nil {
+ if child.IsNegative() {
+ negativeUpperChild = true
+ } else {
+ upperInode = child.Inode
+ upperInode.IncRef()
+ }
+ child.DecRef()
+ }
+
+ // Are we done?
+ if overlayHasWhiteout(parent.upper, name) {
+ if upperInode == nil {
+ parent.copyMu.RUnlock()
+ if negativeUpperChild {
+ // If the upper fs returnd a negative
+ // Dirent, then the upper is OK with
+ // that negative Dirent being cached in
+ // the Dirent tree, so we can return
+ // one from the overlay.
+ return NewNegativeDirent(name), false, nil
+ }
+ // Upper fs is not OK with a negative Dirent
+ // being cached in the Dirent tree, so don't
+ // return one.
+ return nil, false, syserror.ENOENT
+ }
+ entry, err := newOverlayEntry(ctx, upperInode, nil, false)
+ if err != nil {
+ // Don't leak resources.
+ upperInode.DecRef()
+ parent.copyMu.RUnlock()
+ return nil, false, err
+ }
+ d, err := NewDirent(newOverlayInode(ctx, entry, inode.MountSource), name), nil
+ parent.copyMu.RUnlock()
+ return d, true, err
+ }
+ }
+
+ // Check the lower file system. We do this unconditionally (even for
+ // non-directories) because we may need to use stable attributes from
+ // the lower filesystem (e.g. device number, inode number) that were
+ // visible before a copy up.
+ if parent.lower != nil {
+ // Check the lower file system.
+ child, err := parent.lower.Lookup(ctx, name)
+ // Same song and dance as above.
+ if err != nil && err != syserror.ENOENT {
+ // Don't leak resources.
+ if upperInode != nil {
+ upperInode.DecRef()
+ }
+ parent.copyMu.RUnlock()
+ return nil, false, err
+ }
+ if child != nil {
+ if !child.IsNegative() {
+ if upperInode == nil {
+ // If nothing was in the upper, use what we found in the lower.
+ lowerInode = child.Inode
+ lowerInode.IncRef()
+ } else {
+ // If we have something from the upper, we can only use it if the types
+ // match.
+ // NOTE(b/112312863): Allow SpecialDirectories and Directories to merge.
+ // This is needed to allow submounts in /proc and /sys.
+ if upperInode.StableAttr.Type == child.Inode.StableAttr.Type ||
+ (IsDir(upperInode.StableAttr) && IsDir(child.Inode.StableAttr)) {
+ lowerInode = child.Inode
+ lowerInode.IncRef()
+ }
+ }
+ }
+ child.DecRef()
+ }
+ }
+
+ // Was all of this for naught?
+ if upperInode == nil && lowerInode == nil {
+ parent.copyMu.RUnlock()
+ // We can only return a negative dirent if the upper returned
+ // one as well. See comments above regarding negativeUpperChild
+ // for more info.
+ if negativeUpperChild {
+ return NewNegativeDirent(name), false, nil
+ }
+ return nil, false, syserror.ENOENT
+ }
+
+ // Did we find a lower Inode? Remember this because we may decide we don't
+ // actually need the lower Inode (see below).
+ lowerExists := lowerInode != nil
+
+ // If we found something in the upper filesystem and the lower filesystem,
+ // use the stable attributes from the lower filesystem. If we don't do this,
+ // then it may appear that the file was magically recreated across copy up.
+ if upperInode != nil && lowerInode != nil {
+ // Steal attributes.
+ upperInode.StableAttr = lowerInode.StableAttr
+
+ // For non-directories, the lower filesystem resource is strictly
+ // unnecessary because we don't need to copy-up and we will always
+ // operate (e.g. read/write) on the upper Inode.
+ if !IsDir(upperInode.StableAttr) {
+ lowerInode.DecRef()
+ lowerInode = nil
+ }
+ }
+
+ // Phew, finally done.
+ entry, err := newOverlayEntry(ctx, upperInode, lowerInode, lowerExists)
+ if err != nil {
+ // Well, not quite, we failed at the last moment, how depressing.
+ // Be sure not to leak resources.
+ if upperInode != nil {
+ upperInode.DecRef()
+ }
+ if lowerInode != nil {
+ lowerInode.DecRef()
+ }
+ parent.copyMu.RUnlock()
+ return nil, false, err
+ }
+ d, err := NewDirent(newOverlayInode(ctx, entry, inode.MountSource), name), nil
+ parent.copyMu.RUnlock()
+ return d, upperInode != nil, err
+}
+
+func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) {
+ // Dirent.Create takes renameMu if the Inode is an overlay Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return nil, err
+ }
+
+ upperFile, err := o.upper.InodeOperations.Create(ctx, o.upper, name, flags, perm)
+ if err != nil {
+ return nil, err
+ }
+
+ // Take another reference on the upper file's inode, which will be
+ // owned by the overlay entry.
+ upperFile.Dirent.Inode.IncRef()
+ entry, err := newOverlayEntry(ctx, upperFile.Dirent.Inode, nil, false)
+ if err != nil {
+ cleanupUpper(ctx, o.upper, name)
+ return nil, err
+ }
+
+ // NOTE(b/71766861): Replace the Dirent with a transient Dirent, since
+ // we are about to create the real Dirent: an overlay Dirent.
+ //
+ // This ensures the *fs.File returned from overlayCreate is in the same
+ // state as the *fs.File returned by overlayGetFile, where the upper
+ // file has a transient Dirent.
+ //
+ // This is necessary for Save/Restore, as otherwise the upper Dirent
+ // (which has no path as it is unparented and never reachable by the
+ // user) will clobber the real path for the underlying Inode.
+ upperFile.Dirent.Inode.IncRef()
+ upperDirent := NewTransientDirent(upperFile.Dirent.Inode)
+ upperFile.Dirent.DecRef()
+ upperFile.Dirent = upperDirent
+
+ // Create the overlay inode and dirent. We need this to construct the
+ // overlay file.
+ overlayInode := newOverlayInode(ctx, entry, parent.Inode.MountSource)
+ // d will own the inode reference.
+ overlayDirent := NewDirent(overlayInode, name)
+ // The overlay file created below with NewFile will take a reference on
+ // the overlayDirent, and it should be the only thing holding a
+ // reference at the time of creation, so we must drop this reference.
+ defer overlayDirent.DecRef()
+
+ // Create a new overlay file that wraps the upper file.
+ flags.Pread = upperFile.Flags().Pread
+ flags.Pwrite = upperFile.Flags().Pwrite
+ overlayFile := NewFile(ctx, overlayDirent, flags, &overlayFileOperations{upper: upperFile})
+
+ return overlayFile, nil
+}
+
+func overlayCreateDirectory(ctx context.Context, o *overlayEntry, parent *Dirent, name string, perm FilePermissions) error {
+ // Dirent.CreateDirectory takes renameMu if the Inode is an overlay
+ // Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.CreateDirectory(ctx, o.upper, name, perm)
+}
+
+func overlayCreateLink(ctx context.Context, o *overlayEntry, parent *Dirent, oldname string, newname string) error {
+ // Dirent.CreateLink takes renameMu if the Inode is an overlay Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.CreateLink(ctx, o.upper, oldname, newname)
+}
+
+func overlayCreateHardLink(ctx context.Context, o *overlayEntry, parent *Dirent, target *Dirent, name string) error {
+ // Dirent.CreateHardLink takes renameMu if the Inode is an overlay
+ // Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ if err := copyUpLockedForRename(ctx, target); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.CreateHardLink(ctx, o.upper, target.Inode.overlay.upper, name)
+}
+
+func overlayCreateFifo(ctx context.Context, o *overlayEntry, parent *Dirent, name string, perm FilePermissions) error {
+ // Dirent.CreateFifo takes renameMu if the Inode is an overlay Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.CreateFifo(ctx, o.upper, name, perm)
+}
+
+func overlayRemove(ctx context.Context, o *overlayEntry, parent *Dirent, child *Dirent) error {
+ // Dirent.Remove and Dirent.RemoveDirectory take renameMu if the Inode
+ // is an overlay Inode.
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
+ return err
+ }
+ child.Inode.overlay.copyMu.RLock()
+ defer child.Inode.overlay.copyMu.RUnlock()
+ if child.Inode.overlay.upper != nil {
+ if child.Inode.StableAttr.Type == Directory {
+ if err := o.upper.InodeOperations.RemoveDirectory(ctx, o.upper, child.name); err != nil {
+ return err
+ }
+ } else {
+ if err := o.upper.InodeOperations.Remove(ctx, o.upper, child.name); err != nil {
+ return err
+ }
+ }
+ }
+ if child.Inode.overlay.lowerExists {
+ return overlayCreateWhiteout(o.upper, child.name)
+ }
+ return nil
+}
+
+func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, renamed *Dirent, newParent *Dirent, newName string, replacement bool) error {
+ // To be able to copy these up below, they have to be part of an
+ // overlay file system.
+ //
+ // Maybe some day we can allow the more complicated case of
+ // non-overlay X overlay renames, but that's not necessary right now.
+ if renamed.Inode.overlay == nil || newParent.Inode.overlay == nil || oldParent.Inode.overlay == nil {
+ return syserror.EXDEV
+ }
+
+ if replacement {
+ // Check here if the file to be replaced exists and is a
+ // non-empty directory. If we copy up first, we may end up
+ // copying the directory but none of its children, so the
+ // directory will appear empty in the upper fs, which will then
+ // allow the rename to proceed when it should return ENOTEMPTY.
+ //
+ // NOTE(b/111808347): Ideally, we'd just pass in the replaced
+ // Dirent from Rename, but we must drop the reference on
+ // replaced before we make the rename call, so Rename can't
+ // pass the Dirent to the Inode without significantly
+ // complicating the API. Thus we look it up again here.
+ //
+ // For the same reason we can't use defer here.
+ replaced, inUpper, err := overlayLookup(ctx, newParent.Inode.overlay, newParent.Inode, newName)
+ // If err == ENOENT or a negative Dirent is returned, then
+ // newName has been removed out from under us. That's fine;
+ // filesystems where that can happen must handle stale
+ // 'replaced'.
+ if err != nil && err != syserror.ENOENT {
+ return err
+ }
+ if err == nil {
+ if !inUpper {
+ // newName doesn't exist in
+ // newParent.Inode.overlay.upper, thus from
+ // that Inode's perspective this won't be a
+ // replacing rename.
+ replacement = false
+ }
+
+ if !replaced.IsNegative() && IsDir(replaced.Inode.StableAttr) {
+ children, err := readdirOne(ctx, replaced)
+ if err != nil {
+ replaced.DecRef()
+ return err
+ }
+
+ // readdirOne ensures that "." and ".." are not
+ // included among the returned children, so we don't
+ // need to bother checking for them.
+ if len(children) > 0 {
+ replaced.DecRef()
+ return syserror.ENOTEMPTY
+ }
+ }
+
+ replaced.DecRef()
+ }
+ }
+
+ if err := copyUpLockedForRename(ctx, renamed); err != nil {
+ return err
+ }
+ if err := copyUpLockedForRename(ctx, newParent); err != nil {
+ return err
+ }
+ oldName := renamed.name
+ if err := o.upper.InodeOperations.Rename(ctx, renamed.Inode.overlay.upper, oldParent.Inode.overlay.upper, oldName, newParent.Inode.overlay.upper, newName, replacement); err != nil {
+ return err
+ }
+ if renamed.Inode.overlay.lowerExists {
+ return overlayCreateWhiteout(oldParent.Inode.overlay.upper, oldName)
+ }
+ return nil
+}
+
+func overlayBind(ctx context.Context, o *overlayEntry, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+ // We do not support doing anything exciting with sockets unless there
+ // is already a directory in the upper filesystem.
+ if o.upper == nil {
+ return nil, syserror.EOPNOTSUPP
+ }
+ d, err := o.upper.InodeOperations.Bind(ctx, o.upper, name, data, perm)
+ if err != nil {
+ return nil, err
+ }
+
+ // Grab the inode and drop the dirent, we don't need it.
+ inode := d.Inode
+ inode.IncRef()
+ d.DecRef()
+
+ // Create a new overlay entry and dirent for the socket.
+ entry, err := newOverlayEntry(ctx, inode, nil, false)
+ if err != nil {
+ inode.DecRef()
+ return nil, err
+ }
+ return NewDirent(newOverlayInode(ctx, entry, inode.MountSource), name), nil
+}
+
+func overlayBoundEndpoint(o *overlayEntry, path string) transport.BoundEndpoint {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ if o.upper != nil {
+ return o.upper.InodeOperations.BoundEndpoint(o.upper, path)
+ }
+
+ return o.lower.BoundEndpoint(path)
+}
+
+func overlayGetFile(ctx context.Context, o *overlayEntry, d *Dirent, flags FileFlags) (*File, error) {
+ // Hot path. Avoid defers.
+ if flags.Write {
+ if err := copyUp(ctx, d); err != nil {
+ return nil, err
+ }
+ }
+
+ o.copyMu.RLock()
+
+ if o.upper != nil {
+ upper, err := overlayFile(ctx, o.upper, flags)
+ if err != nil {
+ o.copyMu.RUnlock()
+ return nil, err
+ }
+ flags.Pread = upper.Flags().Pread
+ flags.Pwrite = upper.Flags().Pwrite
+ f, err := NewFile(ctx, d, flags, &overlayFileOperations{upper: upper}), nil
+ o.copyMu.RUnlock()
+ return f, err
+ }
+
+ lower, err := overlayFile(ctx, o.lower, flags)
+ if err != nil {
+ o.copyMu.RUnlock()
+ return nil, err
+ }
+ flags.Pread = lower.Flags().Pread
+ flags.Pwrite = lower.Flags().Pwrite
+ o.copyMu.RUnlock()
+ return NewFile(ctx, d, flags, &overlayFileOperations{lower: lower}), nil
+}
+
+func overlayUnstableAttr(ctx context.Context, o *overlayEntry) (UnstableAttr, error) {
+ // Hot path. Avoid defers.
+ var (
+ attr UnstableAttr
+ err error
+ )
+ o.copyMu.RLock()
+ if o.upper != nil {
+ attr, err = o.upper.UnstableAttr(ctx)
+ } else {
+ attr, err = o.lower.UnstableAttr(ctx)
+ }
+ o.copyMu.RUnlock()
+ return attr, err
+}
+
+func overlayGetxattr(o *overlayEntry, name string) (string, error) {
+ // Hot path. This is how the overlay checks for whiteout files.
+ // Avoid defers.
+ var (
+ s string
+ err error
+ )
+
+ // Don't forward the value of the extended attribute if it would
+ // unexpectedly change the behavior of a wrapping overlay layer.
+ if strings.HasPrefix(XattrOverlayPrefix, name) {
+ return "", syserror.ENODATA
+ }
+
+ o.copyMu.RLock()
+ if o.upper != nil {
+ s, err = o.upper.Getxattr(name)
+ } else {
+ s, err = o.lower.Getxattr(name)
+ }
+ o.copyMu.RUnlock()
+ return s, err
+}
+
+func overlayListxattr(o *overlayEntry) (map[string]struct{}, error) {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+ var names map[string]struct{}
+ var err error
+ if o.upper != nil {
+ names, err = o.upper.Listxattr()
+ } else {
+ names, err = o.lower.Listxattr()
+ }
+ for name := range names {
+ // Same as overlayGetxattr, we shouldn't forward along
+ // overlay attributes.
+ if strings.HasPrefix(XattrOverlayPrefix, name) {
+ delete(names, name)
+ }
+ }
+ return names, err
+}
+
+func overlayCheck(ctx context.Context, o *overlayEntry, p PermMask) error {
+ o.copyMu.RLock()
+ // Hot path. Avoid defers.
+ var err error
+ if o.upper != nil {
+ err = o.upper.check(ctx, p)
+ } else {
+ if p.Write {
+ // Since writes will be redirected to the upper filesystem, the lower
+ // filesystem need not be writable, but must be readable for copy-up.
+ p.Write = false
+ p.Read = true
+ }
+ err = o.lower.check(ctx, p)
+ }
+ o.copyMu.RUnlock()
+ return err
+}
+
+func overlaySetPermissions(ctx context.Context, o *overlayEntry, d *Dirent, f FilePermissions) bool {
+ if err := copyUp(ctx, d); err != nil {
+ return false
+ }
+ return o.upper.InodeOperations.SetPermissions(ctx, o.upper, f)
+}
+
+func overlaySetOwner(ctx context.Context, o *overlayEntry, d *Dirent, owner FileOwner) error {
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.SetOwner(ctx, o.upper, owner)
+}
+
+func overlaySetTimestamps(ctx context.Context, o *overlayEntry, d *Dirent, ts TimeSpec) error {
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.SetTimestamps(ctx, o.upper, ts)
+}
+
+func overlayTruncate(ctx context.Context, o *overlayEntry, d *Dirent, size int64) error {
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.Truncate(ctx, o.upper, size)
+}
+
+func overlayAllocate(ctx context.Context, o *overlayEntry, d *Dirent, offset, length int64) error {
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.InodeOperations.Allocate(ctx, o.upper, offset, length)
+}
+
+func overlayReadlink(ctx context.Context, o *overlayEntry) (string, error) {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+ if o.upper != nil {
+ return o.upper.Readlink(ctx)
+ }
+ return o.lower.Readlink(ctx)
+}
+
+func overlayGetlink(ctx context.Context, o *overlayEntry) (*Dirent, error) {
+ var dirent *Dirent
+ var err error
+
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ if o.upper != nil {
+ dirent, err = o.upper.Getlink(ctx)
+ } else {
+ dirent, err = o.lower.Getlink(ctx)
+ }
+ if dirent != nil {
+ // This dirent is likely bogus (its Inode likely doesn't contain
+ // the right overlayEntry). So we're forced to drop it on the
+ // ground and claim that jumping around the filesystem like this
+ // is not supported.
+ name, _ := dirent.FullName(nil)
+ dirent.DecRef()
+
+ // Claim that the path is not accessible.
+ err = syserror.EACCES
+ log.Warningf("Getlink not supported in overlay for %q", name)
+ }
+ return nil, err
+}
+
+func overlayStatFS(ctx context.Context, o *overlayEntry) (Info, error) {
+ o.copyMu.RLock()
+ defer o.copyMu.RUnlock()
+
+ var i Info
+ var err error
+ if o.upper != nil {
+ i, err = o.upper.StatFS(ctx)
+ } else {
+ i, err = o.lower.StatFS(ctx)
+ }
+ if err != nil {
+ return Info{}, err
+ }
+
+ i.Type = linux.OVERLAYFS_SUPER_MAGIC
+
+ return i, nil
+}
+
+// NewTestOverlayDir returns an overlay Inode for tests.
+//
+// If `revalidate` is true, then the upper filesystem will require
+// revalidation.
+func NewTestOverlayDir(ctx context.Context, upper, lower *Inode, revalidate bool) *Inode {
+ fs := &overlayFilesystem{}
+ var upperMsrc *MountSource
+ if revalidate {
+ upperMsrc = NewRevalidatingMountSource(fs, MountSourceFlags{})
+ } else {
+ upperMsrc = NewNonCachingMountSource(fs, MountSourceFlags{})
+ }
+ msrc := NewMountSource(&overlayMountSourceOperations{
+ upper: upperMsrc,
+ lower: NewNonCachingMountSource(fs, MountSourceFlags{}),
+ }, fs, MountSourceFlags{})
+ overlay := &overlayEntry{
+ upper: upper,
+ lower: lower,
+ }
+ return newOverlayInode(ctx, overlay, msrc)
+}
+
+// TestHasUpperFS returns true if i is an overlay Inode and it has a pointer
+// to an Inode on an upper filesystem.
+func (i *Inode) TestHasUpperFS() bool {
+ return i.overlay != nil && i.overlay.upper != nil
+}
+
+// TestHasLowerFS returns true if i is an overlay Inode and it has a pointer
+// to an Inode on a lower filesystem.
+func (i *Inode) TestHasLowerFS() bool {
+ return i.overlay != nil && i.overlay.lower != nil
+}
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
new file mode 100644
index 000000000..7dfd31020
--- /dev/null
+++ b/pkg/sentry/fs/inotify.go
@@ -0,0 +1,348 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Inotify represents an inotify instance created by inotify_init(2) or
+// inotify_init1(2). Inotify implements the FileOperations interface.
+//
+// Lock ordering:
+// Inotify.mu -> Inode.Watches.mu -> Watch.mu -> Inotify.evMu
+//
+// +stateify savable
+type Inotify struct {
+ // Unique identifier for this inotify instance. We don't just reuse the
+ // inotify fd because fds can be duped. These should not be exposed to the
+ // user, since we may aggressively reuse an id on S/R.
+ id uint64
+
+ waiter.Queue `state:"nosave"`
+
+ // evMu *only* protects the events list. We need a separate lock because
+ // while queuing events, a watch needs to lock the event queue, and using mu
+ // for that would violate lock ordering since at that point the calling
+ // goroutine already holds Watch.target.Watches.mu.
+ evMu sync.Mutex `state:"nosave"`
+
+ // A list of pending events for this inotify instance. Protected by evMu.
+ events eventList
+
+ // A scratch buffer, use to serialize inotify events. Use allocate this
+ // ahead of time and reuse performance. Protected by evMu.
+ scratch []byte
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // The next watch descriptor number to use for this inotify instance. Note
+ // that Linux starts numbering watch descriptors from 1.
+ nextWatch int32
+
+ // Map from watch descriptors to watch objects.
+ watches map[int32]*Watch
+}
+
+// NewInotify constructs a new Inotify instance.
+func NewInotify(ctx context.Context) *Inotify {
+ return &Inotify{
+ id: uniqueid.GlobalFromContext(ctx),
+ scratch: make([]byte, inotifyEventBaseSize),
+ nextWatch: 1, // Linux starts numbering watch descriptors from 1.
+ watches: make(map[int32]*Watch),
+ }
+}
+
+// Release implements FileOperations.Release. Release removes all watches and
+// frees all resources for an inotify instance.
+func (i *Inotify) Release() {
+ // We need to hold i.mu to avoid a race with concurrent calls to
+ // Inotify.targetDestroyed from Watches. There's no risk of Watches
+ // accessing this Inotify after the destructor ends, because we remove all
+ // references to it below.
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ for _, w := range i.watches {
+ // Remove references to the watch from the watch target. We don't need
+ // to worry about the references from the owner instance, since we're in
+ // the owner's destructor.
+ w.target.Watches.Remove(w.ID())
+ // Don't leak any references to the target, held by pins in the watch.
+ w.destroy()
+ }
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// Readiness indicates whether there are pending events for an inotify instance.
+func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask {
+ ready := waiter.EventMask(0)
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if !i.events.Empty() {
+ ready |= waiter.EventIn
+ }
+
+ return mask & ready
+}
+
+// Seek implements FileOperations.Seek.
+func (*Inotify) Seek(context.Context, *File, SeekWhence, int64) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Readdir implements FileOperatons.Readdir.
+func (*Inotify) Readdir(context.Context, *File, DentrySerializer) (int64, error) {
+ return 0, syserror.ENOTDIR
+}
+
+// Write implements FileOperations.Write.
+func (*Inotify) Write(context.Context, *File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Read implements FileOperations.Read.
+func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() < inotifyEventBaseSize {
+ return 0, syserror.EINVAL
+ }
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if i.events.Empty() {
+ // Nothing to read yet, tell caller to block.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var writeLen int64
+ for event := i.events.Front(); event != nil; event = event.Next() {
+ // Does the buffer have enough remaining space to hold the event we're
+ // about to write out?
+ if dst.NumBytes() < int64(event.sizeOf()) {
+ if writeLen > 0 {
+ // Buffer wasn't big enough for all pending events, but we did
+ // write some events out.
+ return writeLen, nil
+ }
+ return 0, syserror.EINVAL
+ }
+
+ // Linux always dequeues an available event as long as there's enough
+ // buffer space to copy it out, even if the copy below fails. Emulate
+ // this behaviour.
+ i.events.Remove(event)
+
+ // Buffer has enough space, copy event to the read buffer.
+ n, err := event.CopyTo(ctx, i.scratch, dst)
+ if err != nil {
+ return 0, err
+ }
+
+ writeLen += n
+ dst = dst.DropFirst64(n)
+ }
+ return writeLen, nil
+}
+
+// WriteTo implements FileOperations.WriteTo.
+func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// Fsync implements FileOperations.Fsync.
+func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error {
+ return syserror.EINVAL
+}
+
+// ReadFrom implements FileOperations.ReadFrom.
+func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) {
+ return 0, syserror.ENOSYS
+}
+
+// Flush implements FileOperations.Flush.
+func (*Inotify) Flush(context.Context, *File) error {
+ return nil
+}
+
+// ConfigureMMap implements FileOperations.ConfigureMMap.
+func (*Inotify) ConfigureMMap(context.Context, *File, *memmap.MMapOpts) error {
+ return syserror.ENODEV
+}
+
+// UnstableAttr implements FileOperations.UnstableAttr.
+func (i *Inotify) UnstableAttr(ctx context.Context, file *File) (UnstableAttr, error) {
+ return file.Dirent.Inode.UnstableAttr(ctx)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (i *Inotify) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch args[1].Int() {
+ case linux.FIONREAD:
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+ var n uint32
+ for e := i.events.Front(); e != nil; e = e.Next() {
+ n += uint32(e.sizeOf())
+ }
+ var buf [4]byte
+ usermem.ByteOrder.PutUint32(buf[:], n)
+ _, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{})
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+func (i *Inotify) queueEvent(ev *Event) {
+ i.evMu.Lock()
+
+ // Check if we should coalesce the event we're about to queue with the last
+ // one currently in the queue. Events are coalesced if they are identical.
+ if last := i.events.Back(); last != nil {
+ if ev.equals(last) {
+ // "Coalesce" the two events by simply not queuing the new one. We
+ // don't need to raise a waiter.EventIn notification because no new
+ // data is available for reading.
+ i.evMu.Unlock()
+ return
+ }
+ }
+
+ i.events.PushBack(ev)
+
+ // Release mutex before notifying waiters because we don't control what they
+ // can do.
+ i.evMu.Unlock()
+
+ i.Queue.Notify(waiter.EventIn)
+}
+
+// newWatchLocked creates and adds a new watch to target.
+func (i *Inotify) newWatchLocked(target *Dirent, mask uint32) *Watch {
+ wd := i.nextWatch
+ i.nextWatch++
+
+ watch := &Watch{
+ owner: i,
+ wd: wd,
+ mask: mask,
+ target: target.Inode,
+ pins: make(map[*Dirent]bool),
+ }
+
+ i.watches[wd] = watch
+
+ // Grab an extra reference to target to prevent it from being evicted from
+ // memory. This ref is dropped during either watch removal, target
+ // destruction, or inotify instance destruction. See callers of Watch.Unpin.
+ watch.Pin(target)
+ target.Inode.Watches.Add(watch)
+
+ return watch
+}
+
+// targetDestroyed is called by w to notify i that w's target is gone. This
+// automatically generates a watch removal event.
+func (i *Inotify) targetDestroyed(w *Watch) {
+ i.mu.Lock()
+ _, found := i.watches[w.wd]
+ delete(i.watches, w.wd)
+ i.mu.Unlock()
+
+ if found {
+ i.queueEvent(newEvent(w.wd, "", linux.IN_IGNORED, 0))
+ }
+}
+
+// AddWatch constructs a new inotify watch and adds it to the target dirent. It
+// returns the watch descriptor returned by inotify_add_watch(2).
+func (i *Inotify) AddWatch(target *Dirent, mask uint32) int32 {
+ // Note: Locking this inotify instance protects the result returned by
+ // Lookup() below. With the lock held, we know for sure the lookup result
+ // won't become stale because it's impossible for *this* instance to
+ // add/remove watches on target.
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ // Does the target already have a watch from this inotify instance?
+ if existing := target.Inode.Watches.Lookup(i.id); existing != nil {
+ // This may be a watch on a different dirent pointing to the
+ // same inode. Obtain an extra reference if necessary.
+ existing.Pin(target)
+
+ newmask := mask
+ if mergeMask := mask&linux.IN_MASK_ADD != 0; mergeMask {
+ // "Add (OR) events to watch mask for this pathname if it already
+ // exists (instead of replacing mask)." -- inotify(7)
+ newmask |= atomic.LoadUint32(&existing.mask)
+ }
+ atomic.StoreUint32(&existing.mask, newmask)
+ return existing.wd
+ }
+
+ // No existing watch, create a new watch.
+ watch := i.newWatchLocked(target, mask)
+ return watch.wd
+}
+
+// RmWatch implements watcher.Watchable.RmWatch.
+//
+// RmWatch looks up an inotify watch for the given 'wd' and configures the
+// target dirent to stop sending events to this inotify instance.
+func (i *Inotify) RmWatch(wd int32) error {
+ i.mu.Lock()
+
+ // Find the watch we were asked to removed.
+ watch, ok := i.watches[wd]
+ if !ok {
+ i.mu.Unlock()
+ return syserror.EINVAL
+ }
+
+ // Remove the watch from this instance.
+ delete(i.watches, wd)
+
+ // Remove the watch from the watch target.
+ watch.target.Watches.Remove(watch.ID())
+
+ // The watch is now isolated and we can safely drop the instance lock. We
+ // need to do so because watch.destroy() acquires Watch.mu, which cannot be
+ // acquired with Inotify.mu held.
+ i.mu.Unlock()
+
+ // Generate the event for the removal.
+ i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0))
+
+ // Remove all pins.
+ watch.destroy()
+
+ return nil
+}
diff --git a/pkg/sentry/fs/inotify_event.go b/pkg/sentry/fs/inotify_event.go
new file mode 100644
index 000000000..d52f956e4
--- /dev/null
+++ b/pkg/sentry/fs/inotify_event.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// inotifyEventBaseSize is the base size of linux's struct inotify_event. This
+// must be a power 2 for rounding below.
+const inotifyEventBaseSize = 16
+
+// Event represents a struct inotify_event from linux.
+//
+// +stateify savable
+type Event struct {
+ eventEntry
+
+ wd int32
+ mask uint32
+ cookie uint32
+
+ // len is computed based on the name field is set automatically by
+ // Event.setName. It should be 0 when no name is set; otherwise it is the
+ // length of the name slice.
+ len uint32
+
+ // The name field has special padding requirements and should only be set by
+ // calling Event.setName.
+ name []byte
+}
+
+func newEvent(wd int32, name string, events, cookie uint32) *Event {
+ e := &Event{
+ wd: wd,
+ mask: events,
+ cookie: cookie,
+ }
+ if name != "" {
+ e.setName(name)
+ }
+ return e
+}
+
+// paddedBytes converts a go string to a null-terminated c-string, padded with
+// null bytes to a total size of 'l'. 'l' must be large enough for all the bytes
+// in the 's' plus at least one null byte.
+func paddedBytes(s string, l uint32) []byte {
+ if l < uint32(len(s)+1) {
+ panic("Converting string to byte array results in truncation, this can lead to buffer-overflow due to the missing null-byte!")
+ }
+ b := make([]byte, l)
+ copy(b, s)
+
+ // b was zero-value initialized during make(), so the rest of the slice is
+ // already filled with null bytes.
+
+ return b
+}
+
+// setName sets the optional name for this event.
+func (e *Event) setName(name string) {
+ // We need to pad the name such that the entire event length ends up a
+ // multiple of inotifyEventBaseSize.
+ unpaddedLen := len(name) + 1
+ // Round up to nearest multiple of inotifyEventBaseSize.
+ e.len = uint32((unpaddedLen + inotifyEventBaseSize - 1) & ^(inotifyEventBaseSize - 1))
+ // Make sure we haven't overflowed and wrapped around when rounding.
+ if unpaddedLen > int(e.len) {
+ panic("Overflow when rounding inotify event size, the 'name' field was too big.")
+ }
+ e.name = paddedBytes(name, e.len)
+}
+
+func (e *Event) sizeOf() int {
+ s := inotifyEventBaseSize + int(e.len)
+ if s < inotifyEventBaseSize {
+ panic("overflow")
+ }
+ return s
+}
+
+// CopyTo serializes this event to dst. buf is used as a scratch buffer to
+// construct the output. We use a buffer allocated ahead of time for
+// performance. buf must be at least inotifyEventBaseSize bytes.
+func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) {
+ usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd))
+ usermem.ByteOrder.PutUint32(buf[4:], e.mask)
+ usermem.ByteOrder.PutUint32(buf[8:], e.cookie)
+ usermem.ByteOrder.PutUint32(buf[12:], e.len)
+
+ writeLen := 0
+
+ n, err := dst.CopyOut(ctx, buf)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ dst = dst.DropFirst(n)
+
+ if e.len > 0 {
+ n, err = dst.CopyOut(ctx, e.name)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ }
+
+ // Santiy check.
+ if writeLen != e.sizeOf() {
+ panic(fmt.Sprintf("Serialized unexpected amount of data for an event, expected %v, wrote %v.", e.sizeOf(), writeLen))
+ }
+
+ return int64(writeLen), nil
+}
+
+func (e *Event) equals(other *Event) bool {
+ return e.wd == other.wd &&
+ e.mask == other.mask &&
+ e.cookie == other.cookie &&
+ e.len == other.len &&
+ bytes.Equal(e.name, other.name)
+}
diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go
new file mode 100644
index 000000000..a0b488467
--- /dev/null
+++ b/pkg/sentry/fs/inotify_watch.go
@@ -0,0 +1,135 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// Watch represent a particular inotify watch created by inotify_add_watch.
+//
+// While a watch is active, it ensures the target inode is pinned in memory by
+// holding an extra ref on each dirent known (by inotify) to point to the
+// inode. These are known as pins. For a full discussion, see
+// fs/g3doc/inotify.md.
+//
+// +stateify savable
+type Watch struct {
+ // Inotify instance which owns this watch.
+ owner *Inotify
+
+ // Descriptor for this watch. This is unique across an inotify instance.
+ wd int32
+
+ // The inode being watched. Note that we don't directly hold a reference on
+ // this inode. Instead we hold a reference on the dirent(s) containing the
+ // inode, which we record in pins.
+ target *Inode
+
+ // unpinned indicates whether we have a hard reference on target. This field
+ // may only be modified through atomic ops.
+ unpinned uint32
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // Events being monitored via this watch. Must be accessed atomically,
+ // writes are protected by mu.
+ mask uint32
+
+ // pins is the set of dirents this watch is currently pinning in memory by
+ // holding a reference to them. See Pin()/Unpin().
+ pins map[*Dirent]bool
+}
+
+// ID returns the id of the inotify instance that owns this watch.
+func (w *Watch) ID() uint64 {
+ return w.owner.id
+}
+
+// NotifyParentAfterUnlink indicates whether the parent of the watched object
+// should continue to be be notified of events after the target has been
+// unlinked.
+func (w *Watch) NotifyParentAfterUnlink() bool {
+ return atomic.LoadUint32(&w.mask)&linux.IN_EXCL_UNLINK == 0
+}
+
+// isRenameEvent returns true if eventMask describes a rename event.
+func isRenameEvent(eventMask uint32) bool {
+ return eventMask&(linux.IN_MOVED_FROM|linux.IN_MOVED_TO|linux.IN_MOVE_SELF) != 0
+}
+
+// Notify queues a new event on this watch.
+func (w *Watch) Notify(name string, events uint32, cookie uint32) {
+ mask := atomic.LoadUint32(&w.mask)
+ if mask&events == 0 {
+ // We weren't watching for this event.
+ return
+ }
+
+ // Event mask should include bits matched from the watch plus all control
+ // event bits.
+ unmaskableBits := ^uint32(0) &^ linux.IN_ALL_EVENTS
+ effectiveMask := unmaskableBits | mask
+ matchedEvents := effectiveMask & events
+ w.owner.queueEvent(newEvent(w.wd, name, matchedEvents, cookie))
+}
+
+// Pin acquires a new ref on dirent, which pins the dirent in memory while
+// the watch is active. Calling Pin for a second time on the same dirent for
+// the same watch is a no-op.
+func (w *Watch) Pin(d *Dirent) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if !w.pins[d] {
+ w.pins[d] = true
+ d.IncRef()
+ }
+}
+
+// Unpin drops any extra refs held on dirent due to a previous Pin
+// call. Calling Unpin multiple times for the same dirent, or on a dirent
+// without a corresponding Pin call is a no-op.
+func (w *Watch) Unpin(d *Dirent) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.pins[d] {
+ delete(w.pins, d)
+ d.DecRef()
+ }
+}
+
+// TargetDestroyed notifies the owner of the watch that the watch target is
+// gone. The owner should release its own references to the watcher upon
+// receiving this notification.
+func (w *Watch) TargetDestroyed() {
+ w.owner.targetDestroyed(w)
+}
+
+// destroy prepares the watch for destruction. It unpins all dirents pinned by
+// this watch. Destroy does not cause any new events to be generated. The caller
+// is responsible for ensuring there are no outstanding references to this
+// watch.
+func (w *Watch) destroy() {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ for d := range w.pins {
+ d.DecRef()
+ }
+ w.pins = nil
+}
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
new file mode 100644
index 000000000..f2aee4512
--- /dev/null
+++ b/pkg/sentry/fs/lock/lock.go
@@ -0,0 +1,461 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package lock is the API for POSIX-style advisory regional file locks and
+// BSD-style full file locks.
+//
+// Callers needing to enforce these types of locks, like sys_fcntl, can call
+// LockRegion and UnlockRegion on a thread-safe set of Locks. Locks are
+// specific to a unique file (unique device/inode pair) and for this reason
+// should not be shared between files.
+//
+// A Lock has a set of holders identified by UniqueID. Normally this is the
+// pid of the thread attempting to acquire the lock.
+//
+// Since these are advisory locks, they do not need to be integrated into
+// Reads/Writes and for this reason there is no way to *check* if a lock is
+// held. One can only attempt to take a lock or unlock an existing lock.
+//
+// A Lock in a set of Locks is typed: it is either a read lock with any number
+// of readers and no writer, or a write lock with no readers.
+//
+// As expected from POSIX, any attempt to acquire a write lock on a file region
+// when there already exits a write lock held by a different uid will fail. Any
+// attempt to acquire a write lock on a file region when there is more than one
+// reader will fail. Any attempt to acquire a read lock on a file region when
+// there is already a writer will fail.
+//
+// In special cases, a read lock may be upgraded to a write lock and a write lock
+// can be downgraded to a read lock. This can only happen if:
+//
+// * read lock upgrade to write lock: There can be only one reader and the reader
+// must be the same as the requested write lock holder.
+//
+// * write lock downgrade to read lock: The writer must be the same as the requested
+// read lock holder.
+//
+// UnlockRegion always succeeds. If LockRegion fails the caller should normally
+// interpret this as "try again later".
+package lock
+
+import (
+ "fmt"
+ "math"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// LockType is a type of regional file lock.
+type LockType int
+
+// UniqueID is a unique identifier of the holder of a regional file lock.
+type UniqueID uint64
+
+const (
+ // ReadLock describes a POSIX regional file lock to be taken
+ // read only. There may be multiple of these locks on a single
+ // file region as long as there is no writer lock on the same
+ // region.
+ ReadLock LockType = iota
+
+ // WriteLock describes a POSIX regional file lock to be taken
+ // write only. There may be only a single holder of this lock
+ // and no read locks.
+ WriteLock
+)
+
+// LockEOF is the maximal possible end of a regional file lock.
+const LockEOF = math.MaxUint64
+
+// Lock is a regional file lock. It consists of either a single writer
+// or a set of readers.
+//
+// A Lock may be upgraded from a read lock to a write lock only if there
+// is a single reader and that reader has the same uid as the write lock.
+//
+// A Lock may be downgraded from a write lock to a read lock only if
+// the write lock's uid is the same as the read lock.
+//
+// +stateify savable
+type Lock struct {
+ // Readers are the set of read lock holders identified by UniqueID.
+ // If len(Readers) > 0 then HasWriter must be false.
+ Readers map[UniqueID]bool
+
+ // HasWriter indicates that this is a write lock held by a single
+ // UniqueID.
+ HasWriter bool
+
+ // Writer is only valid if HasWriter is true. It identifies a
+ // single write lock holder.
+ Writer UniqueID
+}
+
+// Locks is a thread-safe wrapper around a LockSet.
+//
+// +stateify savable
+type Locks struct {
+ // mu protects locks below.
+ mu sync.Mutex `state:"nosave"`
+
+ // locks is the set of region locks currently held on an Inode.
+ locks LockSet
+
+ // blockedQueue is the queue of waiters that are waiting on a lock.
+ blockedQueue waiter.Queue `state:"zerovalue"`
+}
+
+// Blocker is the interface used for blocking locks. Passing a nil Blocker
+// will be treated as non-blocking.
+type Blocker interface {
+ Block(C <-chan struct{}) error
+}
+
+const (
+ // EventMaskAll is the mask we will always use for locks, by using the
+ // same mask all the time we can wake up everyone anytime the lock
+ // changes state.
+ EventMaskAll waiter.EventMask = 0xFFFF
+)
+
+// LockRegion attempts to acquire a typed lock for the uid on a region
+// of a file. Returns true if successful in locking the region. If false
+// is returned, the caller should normally interpret this as "try again later" if
+// accquiring the lock in a non-blocking mode or "interrupted" if in a blocking mode.
+// Blocker is the interface used to provide blocking behavior, passing a nil Blocker
+// will result in non-blocking behavior.
+func (l *Locks) LockRegion(uid UniqueID, t LockType, r LockRange, block Blocker) bool {
+ for {
+ l.mu.Lock()
+
+ // Blocking locks must run in a loop because we'll be woken up whenever an unlock event
+ // happens for this lock. We will then attempt to take the lock again and if it fails
+ // continue blocking.
+ res := l.locks.lock(uid, t, r)
+ if !res && block != nil {
+ e, ch := waiter.NewChannelEntry(nil)
+ l.blockedQueue.EventRegister(&e, EventMaskAll)
+ l.mu.Unlock()
+ if err := block.Block(ch); err != nil {
+ // We were interrupted, the caller can translate this to EINTR if applicable.
+ l.blockedQueue.EventUnregister(&e)
+ return false
+ }
+ l.blockedQueue.EventUnregister(&e)
+ continue // Try again now that someone has unlocked.
+ }
+
+ l.mu.Unlock()
+ return res
+ }
+}
+
+// UnlockRegion attempts to release a lock for the uid on a region of a file.
+// This operation is always successful, even if there did not exist a lock on
+// the requested region held by uid in the first place.
+func (l *Locks) UnlockRegion(uid UniqueID, r LockRange) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ l.locks.unlock(uid, r)
+
+ // Now that we've released the lock, we need to wake up any waiters.
+ l.blockedQueue.Notify(EventMaskAll)
+}
+
+// makeLock returns a new typed Lock that has either uid as its only reader
+// or uid as its only writer.
+func makeLock(uid UniqueID, t LockType) Lock {
+ value := Lock{Readers: make(map[UniqueID]bool)}
+ switch t {
+ case ReadLock:
+ value.Readers[uid] = true
+ case WriteLock:
+ value.HasWriter = true
+ value.Writer = uid
+ default:
+ panic(fmt.Sprintf("makeLock: invalid lock type %d", t))
+ }
+ return value
+}
+
+// isHeld returns true if uid is a holder of Lock.
+func (l Lock) isHeld(uid UniqueID) bool {
+ if l.HasWriter && l.Writer == uid {
+ return true
+ }
+ return l.Readers[uid]
+}
+
+// lock sets uid as a holder of a typed lock on Lock.
+//
+// Preconditions: canLock is true for the range containing this Lock.
+func (l *Lock) lock(uid UniqueID, t LockType) {
+ switch t {
+ case ReadLock:
+ // If we are already a reader, then this is a no-op.
+ if l.Readers[uid] {
+ return
+ }
+ // We cannot downgrade a write lock to a read lock unless the
+ // uid is the same.
+ if l.HasWriter {
+ if l.Writer != uid {
+ panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer))
+ }
+ // Ensure that there is only one reader if upgrading.
+ l.Readers = make(map[UniqueID]bool)
+ // Ensure that there is no longer a writer.
+ l.HasWriter = false
+ }
+ l.Readers[uid] = true
+ return
+ case WriteLock:
+ // If we are already the writer, then this is a no-op.
+ if l.HasWriter && l.Writer == uid {
+ return
+ }
+ // We can only upgrade a read lock to a write lock if there
+ // is only one reader and that reader has the same uid as
+ // the write lock.
+ if readers := len(l.Readers); readers > 0 {
+ if readers != 1 {
+ panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, too many readers %v", uid, l.Readers))
+ }
+ if !l.Readers[uid] {
+ panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, conflicting reader %v", uid, l.Readers))
+ }
+ }
+ // Ensure that there is only a writer.
+ l.Readers = make(map[UniqueID]bool)
+ l.HasWriter = true
+ l.Writer = uid
+ default:
+ panic(fmt.Sprintf("lock: invalid lock type %d", t))
+ }
+}
+
+// lockable returns true if check returns true for every Lock in LockRange.
+// Further, check should return true if Lock meets the callers requirements
+// for locking Lock.
+func (l LockSet) lockable(r LockRange, check func(value Lock) bool) bool {
+ // Get our starting point.
+ seg := l.LowerBoundSegment(r.Start)
+ for seg.Ok() && seg.Start() < r.End {
+ // Note that we don't care about overruning the end of the
+ // last segment because if everything checks out we'll just
+ // split the last segment.
+ if !check(seg.Value()) {
+ return false
+ }
+ // Jump to the next segment, ignoring gaps, for the same
+ // reason we ignored the first gap.
+ seg = seg.NextSegment()
+ }
+ // No conflict, we can get a lock for uid over the entire range.
+ return true
+}
+
+// canLock returns true if uid will be able to take a Lock of type t on the
+// entire range specified by LockRange.
+func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
+ switch t {
+ case ReadLock:
+ return l.lockable(r, func(value Lock) bool {
+ // If there is no writer, there's no problem adding
+ // another reader.
+ if !value.HasWriter {
+ return true
+ }
+ // If there is a writer, then it must be the same uid
+ // in order to downgrade the lock to a read lock.
+ return value.Writer == uid
+ })
+ case WriteLock:
+ return l.lockable(r, func(value Lock) bool {
+ // If there are only readers.
+ if !value.HasWriter {
+ // Then this uid can only take a write lock if
+ // this is a private upgrade, meaning that the
+ // only reader is uid.
+ return len(value.Readers) == 1 && value.Readers[uid]
+ }
+ // If the uid is already a writer on this region, then
+ // adding a write lock would be a no-op.
+ return value.Writer == uid
+ })
+ default:
+ panic(fmt.Sprintf("canLock: invalid lock type %d", t))
+ }
+}
+
+// lock returns true if uid took a lock of type t on the entire range of LockRange.
+//
+// Preconditions: r.Start <= r.End (will panic otherwise).
+func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
+ if r.Start > r.End {
+ panic(fmt.Sprintf("lock: r.Start %d > r.End %d", r.Start, r.End))
+ }
+
+ // Don't attempt to insert anything with a range of 0 and treat this
+ // as a successful no-op.
+ if r.Length() == 0 {
+ return true
+ }
+
+ // Do a first-pass check. We *could* hold onto the segments we
+ // checked if canLock would return true, but traversing the segment
+ // set should be fast and this keeps things simple.
+ if !l.canLock(uid, t, r) {
+ return false
+ }
+ // Get our starting point.
+ seg, gap := l.Find(r.Start)
+ if gap.Ok() {
+ // Fill in the gap and get the next segment to modify.
+ seg = l.Insert(gap, gap.Range().Intersect(r), makeLock(uid, t)).NextSegment()
+ } else if seg.Start() < r.Start {
+ // Get our first segment to modify.
+ _, seg = l.Split(seg, r.Start)
+ }
+ for seg.Ok() && seg.Start() < r.End {
+ // Split the last one if necessary.
+ if seg.End() > r.End {
+ seg, _ = l.SplitUnchecked(seg, r.End)
+ }
+
+ // Set the lock on the segment. This is guaranteed to
+ // always be safe, given canLock above.
+ value := seg.ValuePtr()
+ value.lock(uid, t)
+
+ // Fill subsequent gaps.
+ gap = seg.NextGap()
+ if gr := gap.Range().Intersect(r); gr.Length() > 0 {
+ seg = l.Insert(gap, gr, makeLock(uid, t)).NextSegment()
+ } else {
+ seg = gap.NextSegment()
+ }
+ }
+ return true
+}
+
+// unlock is always successful. If uid has no locks held for the range LockRange,
+// unlock is a no-op.
+//
+// Preconditions: same as lock.
+func (l *LockSet) unlock(uid UniqueID, r LockRange) {
+ if r.Start > r.End {
+ panic(fmt.Sprintf("unlock: r.Start %d > r.End %d", r.Start, r.End))
+ }
+
+ // Same as setlock.
+ if r.Length() == 0 {
+ return
+ }
+
+ // Get our starting point.
+ seg := l.LowerBoundSegment(r.Start)
+ for seg.Ok() && seg.Start() < r.End {
+ // If this segment doesn't have a lock from uid then
+ // there is no need to fragment the set with Isolate (below).
+ // In this case just move on to the next segment.
+ if !seg.Value().isHeld(uid) {
+ seg = seg.NextSegment()
+ continue
+ }
+
+ // Ensure that if we need to unlock a sub-segment that
+ // we don't unlock/remove that entire segment.
+ seg = l.Isolate(seg, r)
+
+ value := seg.Value()
+ var remove bool
+ if value.HasWriter && value.Writer == uid {
+ // If we are unlocking a writer, then since there can
+ // only ever be one writer and no readers, then this
+ // lock should always be removed from the set.
+ remove = true
+ } else if value.Readers[uid] {
+ // If uid is the last reader, then just remove the entire
+ // segment.
+ if len(value.Readers) == 1 {
+ remove = true
+ } else {
+ // Otherwise we need to remove this reader without
+ // affecting any other segment's readers. To do
+ // this, we need to make a copy of the Readers map
+ // and not add this uid.
+ newValue := Lock{Readers: make(map[UniqueID]bool)}
+ for k, v := range value.Readers {
+ if k != uid {
+ newValue.Readers[k] = v
+ }
+ }
+ seg.SetValue(newValue)
+ }
+ }
+ if remove {
+ seg = l.Remove(seg).NextSegment()
+ } else {
+ seg = seg.NextSegment()
+ }
+ }
+}
+
+// ComputeRange takes a positive file offset and computes the start of a LockRange
+// using start (relative to offset) and the end of the LockRange using length. The
+// values of start and length may be negative but the resulting LockRange must
+// preserve that LockRange.Start < LockRange.End and LockRange.Start > 0.
+func ComputeRange(start, length, offset int64) (LockRange, error) {
+ offset += start
+ // fcntl(2): "l_start can be a negative number provided the offset
+ // does not lie before the start of the file"
+ if offset < 0 {
+ return LockRange{}, syscall.EINVAL
+ }
+
+ // fcntl(2): Specifying 0 for l_len has the special meaning: lock all
+ // bytes starting at the location specified by l_whence and l_start
+ // through to the end of file, no matter how large the file grows.
+ end := uint64(LockEOF)
+ if length > 0 {
+ // fcntl(2): If l_len is positive, then the range to be locked
+ // covers bytes l_start up to and including l_start+l_len-1.
+ //
+ // Since LockRange.End is exclusive we need not -1 from length..
+ end = uint64(offset + length)
+ } else if length < 0 {
+ // fcntl(2): If l_len is negative, the interval described by
+ // lock covers bytes l_start+l_len up to and including l_start-1.
+ //
+ // Since LockRange.End is exclusive we need not -1 from offset.
+ signedEnd := offset
+ // Add to offset using a negative length (subtract).
+ offset += length
+ if offset < 0 {
+ return LockRange{}, syscall.EINVAL
+ }
+ if signedEnd < offset {
+ return LockRange{}, syscall.EOVERFLOW
+ }
+ // At this point signedEnd cannot be negative,
+ // since we asserted that offset is not negative
+ // and it is not less than offset.
+ end = uint64(signedEnd)
+ }
+ // Offset is guaranteed to be positive at this point.
+ return LockRange{Start: uint64(offset), End: end}, nil
+}
diff --git a/pkg/sentry/fs/lock/lock_range.go b/pkg/sentry/fs/lock/lock_range.go
new file mode 100755
index 000000000..7a6f77640
--- /dev/null
+++ b/pkg/sentry/fs/lock/lock_range.go
@@ -0,0 +1,62 @@
+package lock
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type LockRange struct {
+ // Start is the inclusive start of the range.
+ Start uint64
+
+ // End is the exclusive end of the range.
+ End uint64
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r LockRange) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r LockRange) Length() uint64 {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r LockRange) Contains(x uint64) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r LockRange) Overlaps(r2 LockRange) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r LockRange) IsSupersetOf(r2 LockRange) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+func (r LockRange) Intersect(r2 LockRange) LockRange {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+func (r LockRange) CanSplitAt(x uint64) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/sentry/fs/lock/lock_set.go b/pkg/sentry/fs/lock/lock_set.go
new file mode 100755
index 000000000..127ca5012
--- /dev/null
+++ b/pkg/sentry/fs/lock/lock_set.go
@@ -0,0 +1,1270 @@
+package lock
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ LockminDegree = 3
+
+ LockmaxDegree = 2 * LockminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type LockSet struct {
+ root Locknode `state:".(*LockSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *LockSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *LockSet) IsEmptyRange(r LockRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *LockSet) Span() uint64 {
+ var sz uint64
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *LockSet) SpanRange(r LockRange) uint64 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint64
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *LockSet) FirstSegment() LockIterator {
+ if s.root.nrSegments == 0 {
+ return LockIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *LockSet) LastSegment() LockIterator {
+ if s.root.nrSegments == 0 {
+ return LockIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *LockSet) FirstGap() LockGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return LockGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *LockSet) LastGap() LockGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return LockGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *LockSet) Find(key uint64) (LockIterator, LockGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return LockIterator{n, i}, LockGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return LockIterator{}, LockGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *LockSet) FindSegment(key uint64) LockIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *LockSet) LowerBoundSegment(min uint64) LockIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *LockSet) UpperBoundSegment(max uint64) LockIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *LockSet) FindGap(key uint64) LockGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *LockSet) LowerBoundGap(min uint64) LockGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *LockSet) UpperBoundGap(max uint64) LockGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *LockSet) Add(r LockRange, val Lock) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *LockSet) AddWithoutMerging(r LockRange, val Lock) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *LockSet) Insert(gap LockGapIterator, r LockRange, val Lock) LockIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (lockSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (lockSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (lockSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *LockSet) InsertWithoutMerging(gap LockGapIterator, r LockRange, val Lock) LockIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *LockSet) InsertWithoutMergingUnchecked(gap LockGapIterator, r LockRange, val Lock) LockIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return LockIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *LockSet) Remove(seg LockIterator) LockGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ lockSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(LockGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *LockSet) RemoveAll() {
+ s.root = Locknode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *LockSet) RemoveRange(r LockRange) LockGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *LockSet) Merge(first, second LockIterator) LockIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *LockSet) MergeUnchecked(first, second LockIterator) LockIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (lockSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return LockIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *LockSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *LockSet) MergeRange(r LockRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *LockSet) MergeAdjacent(r LockRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *LockSet) Split(seg LockIterator, split uint64) (LockIterator, LockIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *LockSet) SplitUnchecked(seg LockIterator, split uint64) (LockIterator, LockIterator) {
+ val1, val2 := (lockSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), LockRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *LockSet) SplitAt(split uint64) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *LockSet) Isolate(seg LockIterator, r LockRange) LockIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *LockSet) ApplyContiguous(r LockRange, fn func(seg LockIterator)) LockGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return LockGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return LockGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type Locknode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *Locknode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [LockmaxDegree - 1]LockRange
+ values [LockmaxDegree - 1]Lock
+ children [LockmaxDegree]*Locknode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *Locknode) firstSegment() LockIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return LockIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *Locknode) lastSegment() LockIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return LockIterator{n, n.nrSegments - 1}
+}
+
+func (n *Locknode) prevSibling() *Locknode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *Locknode) nextSibling() *Locknode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *Locknode) rebalanceBeforeInsert(gap LockGapIterator) LockGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < LockmaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &Locknode{
+ nrSegments: LockminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &Locknode{
+ nrSegments: LockminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:LockminDegree-1], n.keys[:LockminDegree-1])
+ copy(left.values[:LockminDegree-1], n.values[:LockminDegree-1])
+ copy(right.keys[:LockminDegree-1], n.keys[LockminDegree:])
+ copy(right.values[:LockminDegree-1], n.values[LockminDegree:])
+ n.keys[0], n.values[0] = n.keys[LockminDegree-1], n.values[LockminDegree-1]
+ LockzeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:LockminDegree], n.children[:LockminDegree])
+ copy(right.children[:LockminDegree], n.children[LockminDegree:])
+ LockzeroNodeSlice(n.children[2:])
+ for i := 0; i < LockminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < LockminDegree {
+ return LockGapIterator{left, gap.index}
+ }
+ return LockGapIterator{right, gap.index - LockminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[LockminDegree-1], n.values[LockminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &Locknode{
+ nrSegments: LockminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:LockminDegree-1], n.keys[LockminDegree:])
+ copy(sibling.values[:LockminDegree-1], n.values[LockminDegree:])
+ LockzeroValueSlice(n.values[LockminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:LockminDegree], n.children[LockminDegree:])
+ LockzeroNodeSlice(n.children[LockminDegree:])
+ for i := 0; i < LockminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = LockminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < LockminDegree {
+ return gap
+ }
+ return LockGapIterator{sibling, gap.index - LockminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *Locknode) rebalanceAfterRemove(gap LockGapIterator) LockGapIterator {
+ for {
+ if n.nrSegments >= LockminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= LockminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ lockSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return LockGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return LockGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= LockminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ lockSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return LockGapIterator{n, n.nrSegments}
+ }
+ return LockGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return LockGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return LockGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *Locknode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = LockGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ lockSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type LockIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *Locknode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg LockIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg LockIterator) Range() LockRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg LockIterator) Start() uint64 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg LockIterator) End() uint64 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg LockIterator) SetRangeUnchecked(r LockRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg LockIterator) SetRange(r LockRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg LockIterator) SetStartUnchecked(start uint64) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg LockIterator) SetStart(start uint64) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg LockIterator) SetEndUnchecked(end uint64) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg LockIterator) SetEnd(end uint64) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg LockIterator) Value() Lock {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg LockIterator) ValuePtr() *Lock {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg LockIterator) SetValue(val Lock) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg LockIterator) PrevSegment() LockIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return LockIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return LockIterator{}
+ }
+ return LocksegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg LockIterator) NextSegment() LockIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return LockIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return LockIterator{}
+ }
+ return LocksegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg LockIterator) PrevGap() LockGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return LockGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg LockIterator) NextGap() LockGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return LockGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg LockIterator) PrevNonEmpty() (LockIterator, LockGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return LockIterator{}, gap
+ }
+ return gap.PrevSegment(), LockGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg LockIterator) NextNonEmpty() (LockIterator, LockGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return LockIterator{}, gap
+ }
+ return gap.NextSegment(), LockGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type LockGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *Locknode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap LockGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap LockGapIterator) Range() LockRange {
+ return LockRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap LockGapIterator) Start() uint64 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return lockSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap LockGapIterator) End() uint64 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return lockSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap LockGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap LockGapIterator) PrevSegment() LockIterator {
+ return LocksegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap LockGapIterator) NextSegment() LockIterator {
+ return LocksegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap LockGapIterator) PrevGap() LockGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return LockGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap LockGapIterator) NextGap() LockGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return LockGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func LocksegmentBeforePosition(n *Locknode, i int) LockIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return LockIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return LockIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func LocksegmentAfterPosition(n *Locknode, i int) LockIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return LockIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return LockIterator{n, i}
+}
+
+func LockzeroValueSlice(slice []Lock) {
+
+ for i := range slice {
+ lockSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func LockzeroNodeSlice(slice []*Locknode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *LockSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *Locknode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *Locknode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type LockSegmentDataSlices struct {
+ Start []uint64
+ End []uint64
+ Values []Lock
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *LockSet) ExportSortedSlices() *LockSegmentDataSlices {
+ var sds LockSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *LockSet) ImportSortedSlices(sds *LockSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := LockRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *LockSet) saveRoot() *LockSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *LockSet) loadRoot(sds *LockSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go
new file mode 100644
index 000000000..8a3ace0c1
--- /dev/null
+++ b/pkg/sentry/fs/lock/lock_set_functions.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lock
+
+import (
+ "math"
+)
+
+// LockSet maps a set of Locks into a file. The key is the file offset.
+
+type lockSetFunctions struct{}
+
+func (lockSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (lockSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (lockSetFunctions) ClearValue(l *Lock) {
+ *l = Lock{}
+}
+
+func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock) (Lock, bool) {
+ // Merge only if the Readers/Writers are identical.
+ if len(val1.Readers) != len(val2.Readers) {
+ return Lock{}, false
+ }
+ for k := range val1.Readers {
+ if !val2.Readers[k] {
+ return Lock{}, false
+ }
+ }
+ if val1.HasWriter != val2.HasWriter {
+ return Lock{}, false
+ }
+ if val1.HasWriter {
+ if val1.Writer != val2.Writer {
+ return Lock{}, false
+ }
+ }
+ return val1, true
+}
+
+func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock) {
+ // Copy the segment so that split segments don't contain map references
+ // to other segments.
+ val0 := Lock{Readers: make(map[UniqueID]bool)}
+ for k, v := range val.Readers {
+ val0.Readers[k] = v
+ }
+ val0.HasWriter = val.HasWriter
+ val0.Writer = val.Writer
+
+ return val, val0
+}
diff --git a/pkg/sentry/fs/lock/lock_state_autogen.go b/pkg/sentry/fs/lock/lock_state_autogen.go
new file mode 100755
index 000000000..abfeea2b6
--- /dev/null
+++ b/pkg/sentry/fs/lock/lock_state_autogen.go
@@ -0,0 +1,106 @@
+// automatically generated by stateify.
+
+package lock
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Lock) beforeSave() {}
+func (x *Lock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Readers", &x.Readers)
+ m.Save("HasWriter", &x.HasWriter)
+ m.Save("Writer", &x.Writer)
+}
+
+func (x *Lock) afterLoad() {}
+func (x *Lock) load(m state.Map) {
+ m.Load("Readers", &x.Readers)
+ m.Load("HasWriter", &x.HasWriter)
+ m.Load("Writer", &x.Writer)
+}
+
+func (x *Locks) beforeSave() {}
+func (x *Locks) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.blockedQueue) { m.Failf("blockedQueue is %v, expected zero", x.blockedQueue) }
+ m.Save("locks", &x.locks)
+}
+
+func (x *Locks) afterLoad() {}
+func (x *Locks) load(m state.Map) {
+ m.Load("locks", &x.locks)
+}
+
+func (x *LockRange) beforeSave() {}
+func (x *LockRange) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *LockRange) afterLoad() {}
+func (x *LockRange) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func (x *LockSet) beforeSave() {}
+func (x *LockSet) save(m state.Map) {
+ x.beforeSave()
+ var root *LockSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *LockSet) afterLoad() {}
+func (x *LockSet) load(m state.Map) {
+ m.LoadValue("root", new(*LockSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*LockSegmentDataSlices)) })
+}
+
+func (x *Locknode) beforeSave() {}
+func (x *Locknode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *Locknode) afterLoad() {}
+func (x *Locknode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *LockSegmentDataSlices) beforeSave() {}
+func (x *LockSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *LockSegmentDataSlices) afterLoad() {}
+func (x *LockSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func init() {
+ state.Register("lock.Lock", (*Lock)(nil), state.Fns{Save: (*Lock).save, Load: (*Lock).load})
+ state.Register("lock.Locks", (*Locks)(nil), state.Fns{Save: (*Locks).save, Load: (*Locks).load})
+ state.Register("lock.LockRange", (*LockRange)(nil), state.Fns{Save: (*LockRange).save, Load: (*LockRange).load})
+ state.Register("lock.LockSet", (*LockSet)(nil), state.Fns{Save: (*LockSet).save, Load: (*LockSet).load})
+ state.Register("lock.Locknode", (*Locknode)(nil), state.Fns{Save: (*Locknode).save, Load: (*Locknode).load})
+ state.Register("lock.LockSegmentDataSlices", (*LockSegmentDataSlices)(nil), state.Fns{Save: (*LockSegmentDataSlices).save, Load: (*LockSegmentDataSlices).load})
+}
diff --git a/pkg/sentry/fs/mock.go b/pkg/sentry/fs/mock.go
new file mode 100644
index 000000000..ff04e9b22
--- /dev/null
+++ b/pkg/sentry/fs/mock.go
@@ -0,0 +1,170 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// MockInodeOperations implements InodeOperations for testing Inodes.
+type MockInodeOperations struct {
+ InodeOperations
+
+ UAttr UnstableAttr
+
+ createCalled bool
+ createDirectoryCalled bool
+ createLinkCalled bool
+ renameCalled bool
+ walkCalled bool
+}
+
+// NewMockInode returns a mock *Inode using MockInodeOperations.
+func NewMockInode(ctx context.Context, msrc *MountSource, sattr StableAttr) *Inode {
+ return NewInode(NewMockInodeOperations(ctx), msrc, sattr)
+}
+
+// NewMockInodeOperations returns a *MockInodeOperations.
+func NewMockInodeOperations(ctx context.Context) *MockInodeOperations {
+ return &MockInodeOperations{
+ UAttr: WithCurrentTime(ctx, UnstableAttr{
+ Perms: FilePermsFromMode(0777),
+ }),
+ }
+}
+
+// MockMountSourceOps implements fs.MountSourceOperations.
+type MockMountSourceOps struct {
+ MountSourceOperations
+ keep bool
+ revalidate bool
+}
+
+// NewMockMountSource returns a new *MountSource using MockMountSourceOps.
+func NewMockMountSource(cache *DirentCache) *MountSource {
+ var keep bool
+ if cache != nil {
+ keep = cache.maxSize > 0
+ }
+ return &MountSource{
+ MountSourceOperations: &MockMountSourceOps{keep: keep},
+ fscache: cache,
+ }
+}
+
+// Revalidate implements fs.MountSourceOperations.Revalidate.
+func (n *MockMountSourceOps) Revalidate(context.Context, string, *Inode, *Inode) bool {
+ return n.revalidate
+}
+
+// Keep implements fs.MountSourceOperations.Keep.
+func (n *MockMountSourceOps) Keep(dirent *Dirent) bool {
+ return n.keep
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (n *MockInodeOperations) WriteOut(context.Context, *Inode) error {
+ return nil
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (n *MockInodeOperations) UnstableAttr(context.Context, *Inode) (UnstableAttr, error) {
+ return n.UAttr, nil
+}
+
+// IsVirtual implements fs.InodeOperations.IsVirtual.
+func (n *MockInodeOperations) IsVirtual() bool {
+ return false
+}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (n *MockInodeOperations) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) {
+ n.walkCalled = true
+ return NewDirent(NewInode(&MockInodeOperations{}, dir.MountSource, StableAttr{}), p), nil
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (n *MockInodeOperations) SetPermissions(context.Context, *Inode, FilePermissions) bool {
+ return false
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (*MockInodeOperations) SetOwner(context.Context, *Inode, FileOwner) error {
+ return syserror.EINVAL
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (n *MockInodeOperations) SetTimestamps(context.Context, *Inode, TimeSpec) error {
+ return nil
+}
+
+// Create implements fs.InodeOperations.Create.
+func (n *MockInodeOperations) Create(ctx context.Context, dir *Inode, p string, flags FileFlags, perms FilePermissions) (*File, error) {
+ n.createCalled = true
+ d := NewDirent(NewInode(&MockInodeOperations{}, dir.MountSource, StableAttr{}), p)
+ return &File{Dirent: d}, nil
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+func (n *MockInodeOperations) CreateLink(_ context.Context, dir *Inode, oldname string, newname string) error {
+ n.createLinkCalled = true
+ return nil
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+func (n *MockInodeOperations) CreateDirectory(context.Context, *Inode, string, FilePermissions) error {
+ n.createDirectoryCalled = true
+ return nil
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (n *MockInodeOperations) Rename(ctx context.Context, inode *Inode, oldParent *Inode, oldName string, newParent *Inode, newName string, replacement bool) error {
+ n.renameCalled = true
+ return nil
+}
+
+// Check implements fs.InodeOperations.Check.
+func (n *MockInodeOperations) Check(ctx context.Context, inode *Inode, p PermMask) bool {
+ return ContextCanAccessFile(ctx, inode, p)
+}
+
+// Release implements fs.InodeOperations.Release.
+func (n *MockInodeOperations) Release(context.Context) {}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (n *MockInodeOperations) Truncate(ctx context.Context, inode *Inode, size int64) error {
+ return nil
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (n *MockInodeOperations) Allocate(ctx context.Context, inode *Inode, offset, length int64) error {
+ return nil
+}
+
+// Remove implements fs.InodeOperations.Remove.
+func (n *MockInodeOperations) Remove(context.Context, *Inode, string) error {
+ return nil
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+func (n *MockInodeOperations) RemoveDirectory(context.Context, *Inode, string) error {
+ return nil
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (n *MockInodeOperations) Getlink(context.Context, *Inode) (*Dirent, error) {
+ return nil, syserror.ENOLINK
+}
diff --git a/pkg/sentry/fs/mount.go b/pkg/sentry/fs/mount.go
new file mode 100644
index 000000000..41e0d285b
--- /dev/null
+++ b/pkg/sentry/fs/mount.go
@@ -0,0 +1,267 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "bytes"
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// DirentOperations provide file systems greater control over how long a Dirent stays pinned
+// in core. Implementations must not take Dirent.mu.
+type DirentOperations interface {
+ // Revalidate is called during lookup each time we encounter a Dirent
+ // in the cache. Implementations may update stale properties of the
+ // child Inode. If Revalidate returns true, then the entire Inode will
+ // be reloaded.
+ //
+ // Revalidate will never be called on a Inode that is mounted.
+ Revalidate(ctx context.Context, name string, parent, child *Inode) bool
+
+ // Keep returns true if the Dirent should be kept in memory for as long
+ // as possible beyond any active references.
+ Keep(dirent *Dirent) bool
+}
+
+// MountSourceOperations contains filesystem specific operations.
+type MountSourceOperations interface {
+ // DirentOperations provide optional extra management of Dirents.
+ DirentOperations
+
+ // Destroy destroys the MountSource.
+ Destroy()
+
+ // Below are MountSourceOperations that do not conform to Linux.
+
+ // ResetInodeMappings clears all mappings of Inodes before SaveInodeMapping
+ // is called.
+ ResetInodeMappings()
+
+ // SaveInodeMappings is called during saving to store, for each reachable
+ // Inode in the mounted filesystem, a mapping of Inode.StableAttr.InodeID
+ // to the Inode's path relative to its mount point. If an Inode is
+ // reachable at more than one path due to hard links, it is unspecified
+ // which path is mapped. Filesystems that do not use this information to
+ // restore inodes can make SaveInodeMappings a no-op.
+ SaveInodeMapping(inode *Inode, path string)
+}
+
+// InodeMappings defines a fmt.Stringer MountSource Inode mappings.
+type InodeMappings map[uint64]string
+
+// String implements fmt.Stringer.String.
+func (i InodeMappings) String() string {
+ var mappingsBuf bytes.Buffer
+ mappingsBuf.WriteString("\n")
+ for ino, name := range i {
+ mappingsBuf.WriteString(fmt.Sprintf("\t%q\t\tinode number %d\n", name, ino))
+ }
+ return mappingsBuf.String()
+}
+
+// MountSource represents a source of file objects.
+//
+// MountSource corresponds to struct super_block in Linux.
+//
+// A mount source may represent a physical device (or a partition of a physical
+// device) or a virtual source of files such as procfs for a specific PID
+// namespace. There should be only one mount source per logical device. E.g.
+// there should be only procfs mount source for a given PID namespace.
+//
+// A mount source represents files as inodes. Every inode belongs to exactly
+// one mount source. Each file object may only be represented using one inode
+// object in a sentry instance.
+//
+// TODO(b/63601033): Move Flags out of MountSource to Mount.
+//
+// +stateify savable
+type MountSource struct {
+ refs.AtomicRefCount
+
+ // MountSourceOperations defines filesystem specific behavior.
+ MountSourceOperations
+
+ // FilesystemType is the type of the filesystem backing this mount.
+ FilesystemType string
+
+ // Flags are the flags that this filesystem was mounted with.
+ Flags MountSourceFlags
+
+ // fscache keeps Dirents pinned beyond application references to them.
+ // It must be flushed before kernel.SaveTo.
+ fscache *DirentCache
+
+ // direntRefs is the sum of references on all Dirents in this MountSource.
+ //
+ // direntRefs is increased when a Dirent in MountSource is IncRef'd, and
+ // decreased when a Dirent in MountSource is DecRef'd.
+ //
+ // To cleanly unmount a MountSource, one must check that no direntRefs are
+ // held anymore. To check, one must hold root.parent.dirMu of the
+ // MountSource's root Dirent before reading direntRefs to prevent further
+ // walks to Dirents in this MountSource.
+ //
+ // direntRefs must be atomically changed.
+ direntRefs uint64
+}
+
+// DefaultDirentCacheSize is the number of Dirents that the VFS can hold an
+// extra reference on.
+const DefaultDirentCacheSize uint64 = 1000
+
+// NewMountSource returns a new MountSource. Filesystem may be nil if there is no
+// filesystem backing the mount.
+func NewMountSource(mops MountSourceOperations, filesystem Filesystem, flags MountSourceFlags) *MountSource {
+ fsType := "none"
+ if filesystem != nil {
+ fsType = filesystem.Name()
+ }
+ return &MountSource{
+ MountSourceOperations: mops,
+ Flags: flags,
+ FilesystemType: fsType,
+ fscache: NewDirentCache(DefaultDirentCacheSize),
+ }
+}
+
+// DirentRefs returns the current mount direntRefs.
+func (msrc *MountSource) DirentRefs() uint64 {
+ return atomic.LoadUint64(&msrc.direntRefs)
+}
+
+// IncDirentRefs increases direntRefs.
+func (msrc *MountSource) IncDirentRefs() {
+ atomic.AddUint64(&msrc.direntRefs, 1)
+}
+
+// DecDirentRefs decrements direntRefs.
+func (msrc *MountSource) DecDirentRefs() {
+ if atomic.AddUint64(&msrc.direntRefs, ^uint64(0)) == ^uint64(0) {
+ panic("Decremented zero mount reference direntRefs")
+ }
+}
+
+func (msrc *MountSource) destroy() {
+ if c := msrc.DirentRefs(); c != 0 {
+ panic(fmt.Sprintf("MountSource with non-zero direntRefs is being destroyed: %d", c))
+ }
+ msrc.MountSourceOperations.Destroy()
+}
+
+// DecRef drops a reference on the MountSource.
+func (msrc *MountSource) DecRef() {
+ msrc.DecRefWithDestructor(msrc.destroy)
+}
+
+// FlushDirentRefs drops all references held by the MountSource on Dirents.
+func (msrc *MountSource) FlushDirentRefs() {
+ msrc.fscache.Invalidate()
+}
+
+// SetDirentCacheMaxSize sets the max size to the dirent cache associated with
+// this mount source.
+func (msrc *MountSource) SetDirentCacheMaxSize(max uint64) {
+ msrc.fscache.setMaxSize(max)
+}
+
+// SetDirentCacheLimiter sets the limiter objcet to the dirent cache associated
+// with this mount source.
+func (msrc *MountSource) SetDirentCacheLimiter(l *DirentCacheLimiter) {
+ msrc.fscache.limit = l
+}
+
+// NewCachingMountSource returns a generic mount that will cache dirents
+// aggressively.
+func NewCachingMountSource(filesystem Filesystem, flags MountSourceFlags) *MountSource {
+ return NewMountSource(&SimpleMountSourceOperations{
+ keep: true,
+ revalidate: false,
+ }, filesystem, flags)
+}
+
+// NewNonCachingMountSource returns a generic mount that will never cache dirents.
+func NewNonCachingMountSource(filesystem Filesystem, flags MountSourceFlags) *MountSource {
+ return NewMountSource(&SimpleMountSourceOperations{
+ keep: false,
+ revalidate: false,
+ }, filesystem, flags)
+}
+
+// NewRevalidatingMountSource returns a generic mount that will cache dirents,
+// but will revalidate them on each lookup.
+func NewRevalidatingMountSource(filesystem Filesystem, flags MountSourceFlags) *MountSource {
+ return NewMountSource(&SimpleMountSourceOperations{
+ keep: true,
+ revalidate: true,
+ }, filesystem, flags)
+}
+
+// NewPseudoMountSource returns a "pseudo" mount source that is not backed by
+// an actual filesystem. It is always non-caching.
+func NewPseudoMountSource() *MountSource {
+ return NewMountSource(&SimpleMountSourceOperations{
+ keep: false,
+ revalidate: false,
+ }, nil, MountSourceFlags{})
+}
+
+// SimpleMountSourceOperations implements MountSourceOperations.
+//
+// +stateify savable
+type SimpleMountSourceOperations struct {
+ keep bool
+ revalidate bool
+}
+
+// Revalidate implements MountSourceOperations.Revalidate.
+func (smo *SimpleMountSourceOperations) Revalidate(context.Context, string, *Inode, *Inode) bool {
+ return smo.revalidate
+}
+
+// Keep implements MountSourceOperations.Keep.
+func (smo *SimpleMountSourceOperations) Keep(*Dirent) bool {
+ return smo.keep
+}
+
+// ResetInodeMappings implements MountSourceOperations.ResetInodeMappings.
+func (*SimpleMountSourceOperations) ResetInodeMappings() {}
+
+// SaveInodeMapping implements MountSourceOperations.SaveInodeMapping.
+func (*SimpleMountSourceOperations) SaveInodeMapping(*Inode, string) {}
+
+// Destroy implements MountSourceOperations.Destroy.
+func (*SimpleMountSourceOperations) Destroy() {}
+
+// Info defines attributes of a filesystem.
+type Info struct {
+ // Type is the filesystem type magic value.
+ Type uint64
+
+ // TotalBlocks is the total data blocks in the filesystem.
+ TotalBlocks uint64
+
+ // FreeBlocks is the number of free blocks available.
+ FreeBlocks uint64
+
+ // TotalFiles is the total file nodes in the filesystem.
+ TotalFiles uint64
+
+ // FreeFiles is the number of free file nodes.
+ FreeFiles uint64
+}
diff --git a/pkg/sentry/fs/mount_overlay.go b/pkg/sentry/fs/mount_overlay.go
new file mode 100644
index 000000000..535f812c8
--- /dev/null
+++ b/pkg/sentry/fs/mount_overlay.go
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// overlayMountSourceOperations implements MountSourceOperations for an overlay
+// mount point. The upper filesystem determines the caching behavior of the
+// overlay.
+//
+// +stateify savable
+type overlayMountSourceOperations struct {
+ upper *MountSource
+ lower *MountSource
+}
+
+func newOverlayMountSource(upper, lower *MountSource, flags MountSourceFlags) *MountSource {
+ upper.IncRef()
+ lower.IncRef()
+ msrc := NewMountSource(&overlayMountSourceOperations{
+ upper: upper,
+ lower: lower,
+ }, &overlayFilesystem{}, flags)
+
+ // Use the minimum number to keep resource usage under limits.
+ size := lower.fscache.maxSize
+ if size > upper.fscache.maxSize {
+ size = upper.fscache.maxSize
+ }
+ msrc.fscache.setMaxSize(size)
+
+ return msrc
+}
+
+// Revalidate implements MountSourceOperations.Revalidate for an overlay by
+// delegating to the upper filesystem's Revalidate method. We cannot reload
+// files from the lower filesystem, so we panic if the lower filesystem's
+// Revalidate method returns true.
+func (o *overlayMountSourceOperations) Revalidate(ctx context.Context, name string, parent, child *Inode) bool {
+ if child.overlay == nil {
+ panic("overlay cannot revalidate inode that is not an overlay")
+ }
+
+ // Revalidate is never called on a mount point, so parent and child
+ // must be from the same mount, and thus must both be overlay inodes.
+ if parent.overlay == nil {
+ panic("trying to revalidate an overlay inode but the parent is not an overlay")
+ }
+
+ // We can't revalidate from the lower filesystem.
+ if child.overlay.lower != nil && o.lower.Revalidate(ctx, name, parent.overlay.lower, child.overlay.lower) {
+ panic("an overlay cannot revalidate file objects from the lower fs")
+ }
+
+ // Do we have anything to revalidate?
+ if child.overlay.upper == nil {
+ return false
+ }
+
+ // Does the upper require revalidation?
+ return o.upper.Revalidate(ctx, name, parent.overlay.upper, child.overlay.upper)
+}
+
+// Keep implements MountSourceOperations by delegating to the upper
+// filesystem's Keep method.
+func (o *overlayMountSourceOperations) Keep(dirent *Dirent) bool {
+ return o.upper.Keep(dirent)
+}
+
+// ResetInodeMappings propagates the call to both upper and lower MountSource.
+func (o *overlayMountSourceOperations) ResetInodeMappings() {
+ o.upper.ResetInodeMappings()
+ o.lower.ResetInodeMappings()
+}
+
+// SaveInodeMapping propagates the call to both upper and lower MountSource.
+func (o *overlayMountSourceOperations) SaveInodeMapping(inode *Inode, path string) {
+ inode.overlay.copyMu.RLock()
+ defer inode.overlay.copyMu.RUnlock()
+ if inode.overlay.upper != nil {
+ o.upper.SaveInodeMapping(inode.overlay.upper, path)
+ }
+ if inode.overlay.lower != nil {
+ o.lower.SaveInodeMapping(inode.overlay.lower, path)
+ }
+}
+
+// Destroy drops references on the upper and lower MountSource.
+func (o *overlayMountSourceOperations) Destroy() {
+ o.upper.DecRef()
+ o.lower.DecRef()
+}
+
+// type overlayFilesystem is the filesystem for overlay mounts.
+//
+// +stateify savable
+type overlayFilesystem struct{}
+
+// Name implements Filesystem.Name.
+func (ofs *overlayFilesystem) Name() string {
+ return "overlayfs"
+}
+
+// Flags implements Filesystem.Flags.
+func (ofs *overlayFilesystem) Flags() FilesystemFlags {
+ return 0
+}
+
+// AllowUserMount implements Filesystem.AllowUserMount.
+func (ofs *overlayFilesystem) AllowUserMount() bool {
+ return false
+}
+
+// AllowUserList implements Filesystem.AllowUserList.
+func (*overlayFilesystem) AllowUserList() bool {
+ return true
+}
+
+// Mount implements Filesystem.Mount.
+func (ofs *overlayFilesystem) Mount(ctx context.Context, device string, flags MountSourceFlags, data string, _ interface{}) (*Inode, error) {
+ panic("overlayFilesystem.Mount should not be called!")
+}
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
new file mode 100644
index 000000000..a5c52d7ba
--- /dev/null
+++ b/pkg/sentry/fs/mounts.go
@@ -0,0 +1,675 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "math"
+ "path"
+ "strings"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// DefaultTraversalLimit provides a sensible default traversal limit that may
+// be passed to FindInode and FindLink. You may want to provide other options in
+// individual syscall implementations, but for internal functions this will be
+// sane.
+const DefaultTraversalLimit = 10
+
+const invalidMountID = math.MaxUint64
+
+// Mount represents a mount in the file system. It holds the root dirent for the
+// mount. It also points back to the dirent or mount where it was mounted over,
+// so that it can be restored when unmounted. The chained mount can be either:
+// - Mount: when it's mounted on top of another mount point.
+// - Dirent: when it's mounted on top of a dirent. In this case the mount is
+// called an "undo" mount and only 'root' is set. All other fields are
+// either invalid or nil.
+//
+// +stateify savable
+type Mount struct {
+ // ID is a unique id for this mount. It may be invalidMountID if this is
+ // used to cache a dirent that was mounted over.
+ ID uint64
+
+ // ParentID is the parent's mount unique id. It may be invalidMountID if this
+ // is the root mount or if this is used to cache a dirent that was mounted
+ // over.
+ ParentID uint64
+
+ // root is the root Dirent of this mount. A reference on this Dirent must be
+ // held through the lifetime of the Mount which contains it.
+ root *Dirent
+
+ // previous is the existing dirent or mount that this object was mounted over.
+ // It's nil for the root mount and for the last entry in the chain (always an
+ // "undo" mount).
+ previous *Mount
+}
+
+// newMount creates a new mount, taking a reference on 'root'. Caller must
+// release the reference when it's done with the mount.
+func newMount(id, pid uint64, root *Dirent) *Mount {
+ root.IncRef()
+ return &Mount{
+ ID: id,
+ ParentID: pid,
+ root: root,
+ }
+}
+
+// newRootMount creates a new root mount (no parent), taking a reference on
+// 'root'. Caller must release the reference when it's done with the mount.
+func newRootMount(id uint64, root *Dirent) *Mount {
+ root.IncRef()
+ return &Mount{
+ ID: id,
+ ParentID: invalidMountID,
+ root: root,
+ }
+}
+
+// newUndoMount creates a new undo mount, taking a reference on 'd'. Caller must
+// release the reference when it's done with the mount.
+func newUndoMount(d *Dirent) *Mount {
+ d.IncRef()
+ return &Mount{
+ ID: invalidMountID,
+ ParentID: invalidMountID,
+ root: d,
+ }
+}
+
+// Root returns the root dirent of this mount. Callers must call DecRef on the
+// returned dirent.
+func (m *Mount) Root() *Dirent {
+ m.root.IncRef()
+ return m.root
+}
+
+// IsRoot returns true if the mount has no parent.
+func (m *Mount) IsRoot() bool {
+ return !m.IsUndo() && m.ParentID == invalidMountID
+}
+
+// IsUndo returns true if 'm' is an undo mount that should be used to restore
+// the original dirent during unmount only and it's not a valid mount.
+func (m *Mount) IsUndo() bool {
+ if m.ID == invalidMountID {
+ if m.ParentID != invalidMountID {
+ panic(fmt.Sprintf("Undo mount with valid parentID: %+v", m))
+ }
+ return true
+ }
+ return false
+}
+
+// MountNamespace defines a collection of mounts.
+//
+// +stateify savable
+type MountNamespace struct {
+ refs.AtomicRefCount
+
+ // userns is the user namespace associated with this mount namespace.
+ //
+ // All privileged operations on this mount namespace must have
+ // appropriate capabilities in this userns.
+ //
+ // userns is immutable.
+ userns *auth.UserNamespace
+
+ // root is the root directory.
+ root *Dirent
+
+ // mu protects mounts and mountID counter.
+ mu sync.Mutex `state:"nosave"`
+
+ // mounts is a map of mounted Dirent -> Mount object. There are three
+ // possible cases:
+ // - Dirent is mounted over a mount point: the stored Mount object will be
+ // the Mount for that mount point.
+ // - Dirent is mounted over a regular (non-mount point) Dirent: the stored
+ // Mount object will be an "undo" mount containing the mounted-over
+ // Dirent.
+ // - Dirent is the root mount: the stored Mount object will be a root mount
+ // containing the Dirent itself.
+ mounts map[*Dirent]*Mount
+
+ // mountID is the next mount id to assign.
+ mountID uint64
+}
+
+// NewMountNamespace returns a new MountNamespace, with the provided node at the
+// root, and the given cache size. A root must always be provided.
+func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error) {
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Set the root dirent and id on the root mount. The reference returned from
+ // NewDirent will be donated to the MountNamespace constructed below.
+ d := NewDirent(root, "/")
+
+ mnts := map[*Dirent]*Mount{
+ d: newRootMount(1, d),
+ }
+
+ return &MountNamespace{
+ userns: creds.UserNamespace,
+ root: d,
+ mounts: mnts,
+ mountID: 2,
+ }, nil
+}
+
+// UserNamespace returns the user namespace associated with this mount manager.
+func (mns *MountNamespace) UserNamespace() *auth.UserNamespace {
+ return mns.userns
+}
+
+// Root returns the MountNamespace's root Dirent and increments its reference
+// count. The caller must call DecRef when finished.
+func (mns *MountNamespace) Root() *Dirent {
+ mns.root.IncRef()
+ return mns.root
+}
+
+// FlushMountSourceRefs flushes extra references held by MountSources for all active mount points;
+// see fs/mount.go:MountSource.FlushDirentRefs.
+func (mns *MountNamespace) FlushMountSourceRefs() {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+ mns.flushMountSourceRefsLocked()
+}
+
+func (mns *MountNamespace) flushMountSourceRefsLocked() {
+ // Flush mounts' MountSource references.
+ for _, mp := range mns.mounts {
+ for ; mp != nil; mp = mp.previous {
+ mp.root.Inode.MountSource.FlushDirentRefs()
+ }
+ }
+
+ // Flush root's MountSource references.
+ mns.root.Inode.MountSource.FlushDirentRefs()
+}
+
+// destroy drops root and mounts dirent references and closes any original nodes.
+//
+// After destroy is called, the MountNamespace may continue to be referenced (for
+// example via /proc/mounts), but should free all resources and shouldn't have
+// Find* methods called.
+func (mns *MountNamespace) destroy() {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+
+ // Flush all mounts' MountSource references to Dirents. This allows for mount
+ // points to be torn down since there should be no remaining references after
+ // this and DecRef below.
+ mns.flushMountSourceRefsLocked()
+
+ // Teardown mounts.
+ for _, mp := range mns.mounts {
+ // Drop the mount reference on all mounted dirents.
+ for ; mp != nil; mp = mp.previous {
+ mp.root.DecRef()
+ }
+ }
+ mns.mounts = nil
+
+ // Drop reference on the root.
+ mns.root.DecRef()
+
+ // Wait for asynchronous work (queued by dropping Dirent references
+ // above) to complete before destroying this MountNamespace.
+ AsyncBarrier()
+}
+
+// DecRef implements RefCounter.DecRef with destructor mns.destroy.
+func (mns *MountNamespace) DecRef() {
+ mns.DecRefWithDestructor(mns.destroy)
+}
+
+// Freeze freezes the entire mount tree.
+func (mns *MountNamespace) Freeze() {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+
+ // We only want to freeze Dirents with active references, not Dirents referenced
+ // by a mount's MountSource.
+ mns.flushMountSourceRefsLocked()
+
+ // Freeze the entire shebang.
+ mns.root.Freeze()
+}
+
+// withMountLocked prevents further walks to `node`, because `node` is about to
+// be a mount point.
+func (mns *MountNamespace) withMountLocked(node *Dirent, fn func() error) error {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+
+ renameMu.Lock()
+ defer renameMu.Unlock()
+
+ // Linux allows mounting over the root (?). It comes with a strange set
+ // of semantics. We'll just not do this for now.
+ if node.parent == nil {
+ return syserror.EBUSY
+ }
+
+ // For both mount and unmount, we take this lock so we can swap out the
+ // appropriate child in parent.children.
+ //
+ // For unmount, this also ensures that if `node` is a mount point, the
+ // underlying mount's MountSource.direntRefs cannot increase by preventing
+ // walks to node.
+ node.parent.dirMu.Lock()
+ defer node.parent.dirMu.Unlock()
+
+ node.parent.mu.Lock()
+ defer node.parent.mu.Unlock()
+
+ // We need not take node.dirMu since we have parent.dirMu.
+
+ // We need to take node.mu, so that we can check for deletion.
+ node.mu.Lock()
+ defer node.mu.Unlock()
+
+ return fn()
+}
+
+// Mount mounts a `inode` over the subtree at `node`.
+func (mns *MountNamespace) Mount(ctx context.Context, mountPoint *Dirent, inode *Inode) error {
+ return mns.withMountLocked(mountPoint, func() error {
+ replacement, err := mountPoint.mount(ctx, inode)
+ if err != nil {
+ return err
+ }
+ defer replacement.DecRef()
+
+ // Set the mount's root dirent and id.
+ parentMnt := mns.findMountLocked(mountPoint)
+ childMnt := newMount(mns.mountID, parentMnt.ID, replacement)
+ mns.mountID++
+
+ // Drop mountPoint from its dirent cache.
+ mountPoint.dropExtendedReference()
+
+ // If mountPoint is already a mount, push mountPoint on the stack so it can
+ // be recovered on unmount.
+ if prev := mns.mounts[mountPoint]; prev != nil {
+ childMnt.previous = prev
+ mns.mounts[replacement] = childMnt
+ delete(mns.mounts, mountPoint)
+ return nil
+ }
+
+ // Was not already mounted, just add another mount point.
+ childMnt.previous = newUndoMount(mountPoint)
+ mns.mounts[replacement] = childMnt
+ return nil
+ })
+}
+
+// Unmount ensures no references to the MountSource remain and removes `node` from
+// this subtree. The subtree formerly mounted in `node`'s place will be
+// restored. node's MountSource will be destroyed as soon as the last reference to
+// `node` is dropped, as no references to Dirents within will remain.
+//
+// If detachOnly is set, Unmount merely removes `node` from the subtree, but
+// allows existing references to the MountSource remain. E.g. if an open file still
+// refers to Dirents in MountSource, the Unmount will succeed anyway and MountSource will
+// be destroyed at a later time when all references to Dirents within are
+// dropped.
+//
+// The caller must hold a reference to node from walking to it.
+func (mns *MountNamespace) Unmount(ctx context.Context, node *Dirent, detachOnly bool) error {
+ // This takes locks to prevent further walks to Dirents in this mount
+ // under the assumption that `node` is the root of the mount.
+ return mns.withMountLocked(node, func() error {
+ orig, ok := mns.mounts[node]
+ if !ok {
+ // node is not a mount point.
+ return syserror.EINVAL
+ }
+
+ if orig.previous == nil {
+ panic("cannot unmount initial dirent")
+ }
+
+ m := node.Inode.MountSource
+ if !detachOnly {
+ // Flush all references on the mounted node.
+ m.FlushDirentRefs()
+
+ // At this point, exactly two references must be held
+ // to mount: one mount reference on node, and one due
+ // to walking to node.
+ //
+ // We must also be guaranteed that no more references
+ // can be taken on mount. This is why withMountLocked
+ // must be held at this point to prevent any walks to
+ // and from node.
+ if refs := m.DirentRefs(); refs < 2 {
+ panic(fmt.Sprintf("have %d refs on unmount, expect 2 or more", refs))
+ } else if refs != 2 {
+ return syserror.EBUSY
+ }
+ }
+
+ prev := orig.previous
+ if err := node.unmount(ctx, prev.root); err != nil {
+ return err
+ }
+
+ if prev.previous == nil {
+ if !prev.IsUndo() {
+ panic(fmt.Sprintf("Last mount in the chain must be a undo mount: %+v", prev))
+ }
+ // Drop mount reference taken at the end of MountNamespace.Mount.
+ prev.root.DecRef()
+ } else {
+ mns.mounts[prev.root] = prev
+ }
+ delete(mns.mounts, node)
+
+ return nil
+ })
+}
+
+// FindMount returns the mount that 'd' belongs to. It walks the dirent back
+// until a mount is found. It may return nil if no mount was found.
+func (mns *MountNamespace) FindMount(d *Dirent) *Mount {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+ renameMu.Lock()
+ defer renameMu.Unlock()
+
+ return mns.findMountLocked(d)
+}
+
+func (mns *MountNamespace) findMountLocked(d *Dirent) *Mount {
+ for {
+ if mnt := mns.mounts[d]; mnt != nil {
+ return mnt
+ }
+ if d.parent == nil {
+ return nil
+ }
+ d = d.parent
+ }
+}
+
+// AllMountsUnder returns a slice of all mounts under the parent, including
+// itself.
+func (mns *MountNamespace) AllMountsUnder(parent *Mount) []*Mount {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+
+ var rv []*Mount
+ for _, mp := range mns.mounts {
+ if !mp.IsUndo() && mp.root.descendantOf(parent.root) {
+ rv = append(rv, mp)
+ }
+ }
+ return rv
+}
+
+// FindLink returns an Dirent from a given node, which may be a symlink.
+//
+// The root argument is treated as the root directory, and FindLink will not
+// return anything above that. The wd dirent provides the starting directory,
+// and may be nil which indicates the root should be used. You must call DecRef
+// on the resulting Dirent when you are no longer using the object.
+//
+// If wd is nil, then the root will be used as the working directory. If the
+// path is absolute, this has no functional impact.
+//
+// Precondition: root must be non-nil.
+// Precondition: the path must be non-empty.
+func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path string, remainingTraversals *uint) (*Dirent, error) {
+ if root == nil {
+ panic("MountNamespace.FindLink: root must not be nil")
+ }
+ if len(path) == 0 {
+ panic("MountNamespace.FindLink: path is empty")
+ }
+
+ // Split the path.
+ first, remainder := SplitFirst(path)
+
+ // Where does this walk originate?
+ current := wd
+ if current == nil {
+ current = root
+ }
+ for first == "/" {
+ // Special case: it's possible that we have nothing to walk at
+ // all. This is necessary since we're resplitting the path.
+ if remainder == "" {
+ root.IncRef()
+ return root, nil
+ }
+
+ // Start at the root and advance the path component so that the
+ // walk below can proceed. Note at this point, it handles the
+ // no-op walk case perfectly fine.
+ current = root
+ first, remainder = SplitFirst(remainder)
+ }
+
+ current.IncRef() // Transferred during walk.
+
+ for {
+ // Check that the file is a directory and that we have
+ // permissions to walk.
+ //
+ // Note that we elide this check for the root directory as an
+ // optimization; a non-executable root may still be walked. A
+ // non-directory root is hopeless.
+ if current != root {
+ if !IsDir(current.Inode.StableAttr) {
+ current.DecRef() // Drop reference from above.
+ return nil, syserror.ENOTDIR
+ }
+ if err := current.Inode.CheckPermission(ctx, PermMask{Execute: true}); err != nil {
+ current.DecRef() // Drop reference from above.
+ return nil, err
+ }
+ }
+
+ // Move to the next level.
+ next, err := current.Walk(ctx, root, first)
+ if err != nil {
+ // Allow failed walks to cache the dirent, because no
+ // children will acquire a reference at the end.
+ current.maybeExtendReference()
+ current.DecRef()
+ return nil, err
+ }
+
+ // Drop old reference.
+ current.DecRef()
+
+ if remainder != "" {
+ // Ensure it's resolved, unless it's the last level.
+ //
+ // See resolve for reference semantics; on err next
+ // will have one dropped.
+ current, err = mns.resolve(ctx, root, next, remainingTraversals)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ // Allow the file system to take an extra reference on the
+ // found child. This will hold a reference on the containing
+ // directory, so the whole tree will be implicitly cached.
+ next.maybeExtendReference()
+ return next, nil
+ }
+
+ // Move to the next element.
+ first, remainder = SplitFirst(remainder)
+ }
+}
+
+// FindInode is identical to FindLink except the return value is resolved.
+//
+//go:nosplit
+func (mns *MountNamespace) FindInode(ctx context.Context, root, wd *Dirent, path string, remainingTraversals *uint) (*Dirent, error) {
+ d, err := mns.FindLink(ctx, root, wd, path, remainingTraversals)
+ if err != nil {
+ return nil, err
+ }
+
+ // See resolve for reference semantics; on err d will have the
+ // reference dropped.
+ return mns.resolve(ctx, root, d, remainingTraversals)
+}
+
+// resolve resolves the given link.
+//
+// If successful, a reference is dropped on node and one is acquired on the
+// caller's behalf for the returned dirent.
+//
+// If not successful, a reference is _also_ dropped on the node and an error
+// returned. This is for convenience in using resolve directly as a return
+// value.
+func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, remainingTraversals *uint) (*Dirent, error) {
+ // Resolve the path.
+ target, err := node.Inode.Getlink(ctx)
+
+ switch err {
+ case nil:
+ // Make sure we didn't exhaust the traversal budget.
+ if *remainingTraversals == 0 {
+ target.DecRef()
+ return nil, syscall.ELOOP
+ }
+
+ node.DecRef() // Drop the original reference.
+ return target, nil
+
+ case syscall.ENOLINK:
+ // Not a symlink.
+ return node, nil
+
+ case ErrResolveViaReadlink:
+ defer node.DecRef() // See above.
+
+ // First, check if we should traverse.
+ if *remainingTraversals == 0 {
+ return nil, syscall.ELOOP
+ }
+
+ // Read the target path.
+ targetPath, err := node.Inode.Readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ // Find the node; we resolve relative to the current symlink's parent.
+ *remainingTraversals--
+ d, err := mns.FindInode(ctx, root, node.parent, targetPath, remainingTraversals)
+ if err != nil {
+ return nil, err
+ }
+
+ return d, err
+
+ default:
+ node.DecRef() // Drop for err; see above.
+
+ // Propagate the error.
+ return nil, err
+ }
+}
+
+// SyncAll calls Dirent.SyncAll on the root.
+func (mns *MountNamespace) SyncAll(ctx context.Context) {
+ mns.mu.Lock()
+ defer mns.mu.Unlock()
+ mns.root.SyncAll(ctx)
+}
+
+// ResolveExecutablePath resolves the given executable name given a set of
+// paths that might contain it.
+func (mns *MountNamespace) ResolveExecutablePath(ctx context.Context, wd, name string, paths []string) (string, error) {
+ // Absolute paths can be used directly.
+ if path.IsAbs(name) {
+ return name, nil
+ }
+
+ // Paths with '/' in them should be joined to the working directory, or
+ // to the root if working directory is not set.
+ if strings.IndexByte(name, '/') > 0 {
+ if wd == "" {
+ wd = "/"
+ }
+ if !path.IsAbs(wd) {
+ return "", fmt.Errorf("working directory %q must be absolute", wd)
+ }
+ return path.Join(wd, name), nil
+ }
+
+ // Otherwise, We must lookup the name in the paths, starting from the
+ // calling context's root directory.
+ root := RootFromContext(ctx)
+ if root == nil {
+ // Caller has no root. Don't bother traversing anything.
+ return "", syserror.ENOENT
+ }
+ defer root.DecRef()
+ for _, p := range paths {
+ binPath := path.Join(p, name)
+ traversals := uint(linux.MaxSymlinkTraversals)
+ d, err := mns.FindInode(ctx, root, nil, binPath, &traversals)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ // Didn't find it here.
+ continue
+ }
+ if err != nil {
+ return "", err
+ }
+ defer d.DecRef()
+
+ // Check whether we can read and execute the found file.
+ if err := d.Inode.CheckPermission(ctx, PermMask{Read: true, Execute: true}); err != nil {
+ log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err)
+ continue
+ }
+ return path.Join("/", p, name), nil
+ }
+ return "", syserror.ENOENT
+}
+
+// GetPath returns the PATH as a slice of strings given the environemnt
+// variables.
+func GetPath(env []string) []string {
+ const prefix = "PATH="
+ for _, e := range env {
+ if strings.HasPrefix(e, prefix) {
+ return strings.Split(strings.TrimPrefix(e, prefix), ":")
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/offset.go b/pkg/sentry/fs/offset.go
new file mode 100644
index 000000000..3f68da149
--- /dev/null
+++ b/pkg/sentry/fs/offset.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// OffsetPageEnd returns the file offset rounded up to the nearest
+// page boundary. OffsetPageEnd panics if rounding up causes overflow,
+// which shouldn't be possible given that offset is an int64.
+func OffsetPageEnd(offset int64) uint64 {
+ end, ok := usermem.Addr(offset).RoundUp()
+ if !ok {
+ panic("impossible overflow")
+ }
+ return uint64(end)
+}
+
+// ReadEndOffset returns an exclusive end offset for a read operation
+// so that the read does not overflow an int64 nor size.
+//
+// Parameters:
+// - offset: the starting offset of the read.
+// - length: the number of bytes to read.
+// - size: the size of the file.
+//
+// Postconditions: The returned offset is >= offset.
+func ReadEndOffset(offset int64, length int64, size int64) int64 {
+ if offset >= size {
+ return offset
+ }
+ end := offset + length
+ // Don't overflow.
+ if end < offset || end > size {
+ end = size
+ }
+ return end
+}
+
+// WriteEndOffset returns an exclusive end offset for a write operation
+// so that the write does not overflow an int64.
+//
+// Parameters:
+// - offset: the starting offset of the write.
+// - length: the number of bytes to write.
+//
+// Postconditions: The returned offset is >= offset.
+func WriteEndOffset(offset int64, length int64) int64 {
+ return ReadEndOffset(offset, length, math.MaxInt64)
+}
diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go
new file mode 100644
index 000000000..db89a5f70
--- /dev/null
+++ b/pkg/sentry/fs/overlay.go
@@ -0,0 +1,303 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// The virtual filesystem implements an overlay configuration. For a high-level
+// description, see README.md.
+//
+// Note on whiteouts:
+//
+// This implementation does not use the "Docker-style" whiteouts (symlinks with
+// ".wh." prefix). Instead upper filesystem directories support a set of extended
+// attributes to encode whiteouts: "trusted.overlay.whiteout.<filename>". This
+// gives flexibility to persist whiteouts independently of the filesystem layout
+// while additionally preventing name conflicts with files prefixed with ".wh.".
+//
+// Known deficiencies:
+//
+// - The device number of two files under the same overlay mount point may be
+// different. This can happen if a file is found in the lower filesystem (takes
+// the lower filesystem device) and another file is created in the upper
+// filesystem (takes the upper filesystem device). This may appear odd but
+// should not break applications.
+//
+// - Registered events on files (i.e. for notification of read/write readiness)
+// are not copied across copy up. This is fine in the common case of files that
+// do not block. For files that do block, like pipes and sockets, copy up is not
+// supported.
+//
+// - Hardlinks in a lower filesystem are broken by copy up. For this reason, no
+// attempt is made to preserve link count across copy up.
+//
+// - The maximum length of an extended attribute name is the same as the maximum
+// length of a file path in Linux (XATTR_NAME_MAX == NAME_MAX). This means that
+// whiteout attributes, if set directly on the host, are limited additionally by
+// the extra whiteout prefix length (file paths must be strictly shorter than
+// NAME_MAX). This is not a problem for in-memory filesystems which don't enforce
+// XATTR_NAME_MAX.
+
+const (
+ // XattrOverlayPrefix is the prefix for extended attributes that affect
+ // the behavior of an overlay.
+ XattrOverlayPrefix = "trusted.overlay."
+
+ // XattrOverlayWhiteoutPrefix is the prefix for extended attributes
+ // that indicate that a whiteout exists.
+ XattrOverlayWhiteoutPrefix = XattrOverlayPrefix + "whiteout."
+)
+
+// XattrOverlayWhiteout returns an extended attribute that indicates a
+// whiteout exists for name. It is supported by directories that wish to
+// mask the existence of name.
+func XattrOverlayWhiteout(name string) string {
+ return XattrOverlayWhiteoutPrefix + name
+}
+
+// isXattrOverlay returns whether the given extended attribute configures the
+// overlay.
+func isXattrOverlay(name string) bool {
+ return strings.HasPrefix(name, XattrOverlayPrefix)
+}
+
+// NewOverlayRoot produces the root of an overlay.
+//
+// Preconditions:
+//
+// - upper and lower must be non-nil.
+// - upper must not be an overlay.
+// - lower should not expose character devices, pipes, or sockets, because
+// copying up these types of files is not supported.
+// - lower must not require that file objects be revalidated.
+// - lower must not have dynamic file/directory content.
+func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags MountSourceFlags) (*Inode, error) {
+ if !IsDir(upper.StableAttr) {
+ return nil, fmt.Errorf("upper Inode is a %v, not a directory", upper.StableAttr.Type)
+ }
+ if !IsDir(lower.StableAttr) {
+ return nil, fmt.Errorf("lower Inode is a %v, not a directory", lower.StableAttr.Type)
+ }
+ if upper.overlay != nil {
+ return nil, fmt.Errorf("cannot nest overlay in upper file of another overlay")
+ }
+
+ msrc := newOverlayMountSource(upper.MountSource, lower.MountSource, flags)
+ overlay, err := newOverlayEntry(ctx, upper, lower, true)
+ if err != nil {
+ msrc.DecRef()
+ return nil, err
+ }
+
+ return newOverlayInode(ctx, overlay, msrc), nil
+}
+
+// NewOverlayRootFile produces the root of an overlay that points to a file.
+//
+// Preconditions:
+//
+// - lower must be non-nil.
+// - lower should not expose character devices, pipes, or sockets, because
+// copying up these types of files is not supported. Neither it can be a dir.
+// - lower must not require that file objects be revalidated.
+// - lower must not have dynamic file/directory content.
+func NewOverlayRootFile(ctx context.Context, upperMS *MountSource, lower *Inode, flags MountSourceFlags) (*Inode, error) {
+ if !IsRegular(lower.StableAttr) {
+ return nil, fmt.Errorf("lower Inode is not a regular file")
+ }
+ msrc := newOverlayMountSource(upperMS, lower.MountSource, flags)
+ overlay, err := newOverlayEntry(ctx, nil, lower, true)
+ if err != nil {
+ msrc.DecRef()
+ return nil, err
+ }
+ return newOverlayInode(ctx, overlay, msrc), nil
+}
+
+// newOverlayInode creates a new Inode for an overlay.
+func newOverlayInode(ctx context.Context, o *overlayEntry, msrc *MountSource) *Inode {
+ var inode *Inode
+ if o.upper != nil {
+ inode = NewInode(nil, msrc, o.upper.StableAttr)
+ } else {
+ inode = NewInode(nil, msrc, o.lower.StableAttr)
+ }
+ inode.overlay = o
+ return inode
+}
+
+// overlayEntry is the overlay metadata of an Inode. It implements Mappable.
+//
+// +stateify savable
+type overlayEntry struct {
+ // lowerExists is true if an Inode exists for this file in the lower
+ // filesystem. If lowerExists is true, then the overlay must create
+ // a whiteout entry when renaming and removing this entry to mask the
+ // lower Inode.
+ //
+ // Note that this is distinct from actually holding onto a non-nil
+ // lower Inode (below). The overlay does not need to keep a lower Inode
+ // around unless it needs to operate on it, but it always needs to know
+ // whether the lower Inode exists to correctly execute a rename or
+ // remove operation.
+ lowerExists bool
+
+ // lower is an Inode from a lower filesystem. Modifications are
+ // never made on this Inode.
+ lower *Inode
+
+ // copyMu serializes copy-up for operations above
+ // mm.MemoryManager.mappingMu in the lock order.
+ copyMu sync.RWMutex `state:"nosave"`
+
+ // mapsMu serializes copy-up for operations between
+ // mm.MemoryManager.mappingMu and mm.MemoryManager.activeMu in the lock
+ // order.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks memory mappings of this Mappable so they can be removed
+ // from the lower filesystem Mappable and added to the upper filesystem
+ // Mappable when copy up occurs. It is strictly unnecessary after copy-up.
+ //
+ // mappings is protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // dataMu serializes copy-up for operations below mm.MemoryManager.activeMu
+ // in the lock order.
+ dataMu sync.RWMutex `state:"nosave"`
+
+ // upper is an Inode from an upper filesystem. It is non-nil if
+ // the file exists in the upper filesystem. It becomes non-nil
+ // when the Inode that owns this overlayEntry is modified.
+ //
+ // upper is protected by all of copyMu, mapsMu, and dataMu. Holding any of
+ // these locks is sufficient to read upper; holding all three for writing
+ // is required to mutate it.
+ upper *Inode
+}
+
+// newOverlayEntry returns a new overlayEntry.
+func newOverlayEntry(ctx context.Context, upper *Inode, lower *Inode, lowerExists bool) (*overlayEntry, error) {
+ if upper == nil && lower == nil {
+ panic("invalid overlayEntry, needs at least one Inode")
+ }
+ if upper != nil && upper.overlay != nil {
+ panic("nested writable layers are not supported")
+ }
+ // Check for supported lower filesystem types.
+ if lower != nil {
+ switch lower.StableAttr.Type {
+ case RegularFile, Directory, Symlink, Socket:
+ default:
+ // We don't support copying up from character devices,
+ // named pipes, or anything weird (like proc files).
+ log.Warningf("%s not supported in lower filesytem", lower.StableAttr.Type)
+ return nil, syserror.EINVAL
+ }
+ }
+ return &overlayEntry{
+ lowerExists: lowerExists,
+ lower: lower,
+ upper: upper,
+ }, nil
+}
+
+func (o *overlayEntry) release() {
+ // We drop a reference on upper and lower file system Inodes
+ // rather than releasing them, because in-memory filesystems
+ // may hold an extra reference to these Inodes so that they
+ // stay in memory.
+ if o.upper != nil {
+ o.upper.DecRef()
+ }
+ if o.lower != nil {
+ o.lower.DecRef()
+ }
+}
+
+// overlayUpperMountSource gives the upper mount of an overlay mount.
+//
+// The caller may not use this MountSource past the lifetime of overlayMountSource and may
+// not call DecRef on it.
+func overlayUpperMountSource(overlayMountSource *MountSource) *MountSource {
+ return overlayMountSource.MountSourceOperations.(*overlayMountSourceOperations).upper
+}
+
+// Preconditions: At least one of o.copyMu, o.mapsMu, or o.dataMu must be locked.
+func (o *overlayEntry) inodeLocked() *Inode {
+ if o.upper != nil {
+ return o.upper
+ }
+ return o.lower
+}
+
+// Preconditions: At least one of o.copyMu, o.mapsMu, or o.dataMu must be locked.
+func (o *overlayEntry) isMappableLocked() bool {
+ return o.inodeLocked().Mappable() != nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (o *overlayEntry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ o.mapsMu.Lock()
+ defer o.mapsMu.Unlock()
+ if err := o.inodeLocked().Mappable().AddMapping(ctx, ms, ar, offset, writable); err != nil {
+ return err
+ }
+ o.mappings.AddMapping(ms, ar, offset, writable)
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (o *overlayEntry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ o.mapsMu.Lock()
+ defer o.mapsMu.Unlock()
+ o.inodeLocked().Mappable().RemoveMapping(ctx, ms, ar, offset, writable)
+ o.mappings.RemoveMapping(ms, ar, offset, writable)
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (o *overlayEntry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ o.mapsMu.Lock()
+ defer o.mapsMu.Unlock()
+ if err := o.inodeLocked().Mappable().CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil {
+ return err
+ }
+ o.mappings.AddMapping(ms, dstAR, offset, writable)
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (o *overlayEntry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ o.dataMu.RLock()
+ defer o.dataMu.RUnlock()
+ return o.inodeLocked().Mappable().Translate(ctx, required, optional, at)
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (o *overlayEntry) InvalidateUnsavable(ctx context.Context) error {
+ o.mapsMu.Lock()
+ defer o.mapsMu.Unlock()
+ return o.inodeLocked().Mappable().InvalidateUnsavable(ctx)
+}
diff --git a/pkg/sentry/fs/path.go b/pkg/sentry/fs/path.go
new file mode 100644
index 000000000..e4dc02dbb
--- /dev/null
+++ b/pkg/sentry/fs/path.go
@@ -0,0 +1,119 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "path/filepath"
+ "strings"
+)
+
+// TrimTrailingSlashes trims any trailing slashes.
+//
+// The returned boolean indicates whether any changes were made.
+//
+//go:nosplit
+func TrimTrailingSlashes(dir string) (trimmed string, changed bool) {
+ // Trim the trailing slash, except for root.
+ for len(dir) > 1 && dir[len(dir)-1] == '/' {
+ dir = dir[:len(dir)-1]
+ changed = true
+ }
+ return dir, changed
+}
+
+// SplitLast splits the given path into a directory and a file.
+//
+// The "absoluteness" of the path is preserved, but dir is always stripped of
+// trailing slashes.
+//
+//go:nosplit
+func SplitLast(path string) (dir, file string) {
+ path, _ = TrimTrailingSlashes(path)
+ if path == "" {
+ return ".", "."
+ } else if path == "/" {
+ return "/", "."
+ }
+
+ var slash int // Last location of slash in path.
+ for slash = len(path) - 1; slash >= 0 && path[slash] != '/'; slash-- {
+ }
+ switch {
+ case slash < 0:
+ return ".", path
+ case slash == 0:
+ // Directory of the form "/foo", or just "/". We need to
+ // preserve the first slash here, since it indicates an
+ // absolute path.
+ return "/", path[1:]
+ default:
+ // Drop the trailing slash.
+ dir, _ = TrimTrailingSlashes(path[:slash])
+ return dir, path[slash+1:]
+ }
+}
+
+// SplitFirst splits the given path into a first directory and the remainder.
+//
+// If remainder is empty, then the path is a single element.
+//
+//go:nosplit
+func SplitFirst(path string) (current, remainder string) {
+ path, _ = TrimTrailingSlashes(path)
+ if path == "" {
+ return ".", ""
+ }
+
+ var slash int // First location of slash in path.
+ for slash = 0; slash < len(path) && path[slash] != '/'; slash++ {
+ }
+ switch {
+ case slash >= len(path):
+ return path, ""
+ case slash == 0:
+ // See above.
+ return "/", path[1:]
+ default:
+ current = path[:slash]
+ remainder = path[slash+1:]
+ // Strip redundant slashes.
+ for len(remainder) > 0 && remainder[0] == '/' {
+ remainder = remainder[1:]
+ }
+ return current, remainder
+ }
+}
+
+// IsSubpath checks whether the first path is a (strict) descendent of the
+// second. If it is a subpath, then true is returned along with a clean
+// relative path from the second path to the first. Otherwise false is
+// returned.
+func IsSubpath(subpath, path string) (string, bool) {
+ cleanPath := filepath.Clean(path)
+ cleanSubpath := filepath.Clean(subpath)
+
+ // Add a trailing slash to the path if it does not already have one.
+ if len(cleanPath) == 0 || cleanPath[len(cleanPath)-1] != '/' {
+ cleanPath += "/"
+ }
+ if cleanPath == cleanSubpath {
+ // Paths are equal, thus not a strict subpath.
+ return "", false
+ }
+ if strings.HasPrefix(cleanSubpath, cleanPath) {
+ return strings.TrimPrefix(cleanSubpath, cleanPath), true
+ }
+ return "", false
+}
diff --git a/pkg/sentry/fs/proc/cgroup.go b/pkg/sentry/fs/proc/cgroup.go
new file mode 100644
index 000000000..1019f862a
--- /dev/null
+++ b/pkg/sentry/fs/proc/cgroup.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+func newCGroupInode(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string]string) *fs.Inode {
+ // From man 7 cgroups: "For each cgroup hierarchy of which the process
+ // is a member, there is one entry containing three colon-separated
+ // fields: hierarchy-ID:controller-list:cgroup-path"
+
+ // The hierarchy ids must be positive integers (for cgroup v1), but the
+ // exact number does not matter, so long as they are unique. We can
+ // just use a counter, but since linux sorts this file in descending
+ // order, we must count down to perserve this behavior.
+ i := len(cgroupControllers)
+ var data string
+ for name, dir := range cgroupControllers {
+ data += fmt.Sprintf("%d:%s:%s\n", i, name, dir)
+ i--
+ }
+
+ return newStaticProcInode(ctx, msrc, []byte(data))
+}
diff --git a/pkg/sentry/fs/proc/cpuinfo.go b/pkg/sentry/fs/proc/cpuinfo.go
new file mode 100644
index 000000000..15031234e
--- /dev/null
+++ b/pkg/sentry/fs/proc/cpuinfo.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+)
+
+func newCPUInfo(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ k := kernel.KernelFromContext(ctx)
+ features := k.FeatureSet()
+ if features == nil {
+ // Kernel is always initialized with a FeatureSet.
+ panic("cpuinfo read with nil FeatureSet")
+ }
+ contents := make([]byte, 0, 1024)
+ for i, max := uint(0), k.ApplicationCores(); i < max; i++ {
+ contents = append(contents, []byte(features.CPUInfo(i))...)
+ }
+ return newStaticProcInode(ctx, msrc, contents)
+}
diff --git a/pkg/sentry/fs/proc/device/device.go b/pkg/sentry/fs/proc/device/device.go
new file mode 100644
index 000000000..0de466c73
--- /dev/null
+++ b/pkg/sentry/fs/proc/device/device.go
@@ -0,0 +1,23 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package device contains the proc device to avoid dependency loops.
+package device
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+)
+
+// ProcDevice is the kernel proc device.
+var ProcDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/proc/device/device_state_autogen.go b/pkg/sentry/fs/proc/device/device_state_autogen.go
new file mode 100755
index 000000000..be407ac45
--- /dev/null
+++ b/pkg/sentry/fs/proc/device/device_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package device
+
diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go
new file mode 100644
index 000000000..cb28f6bc3
--- /dev/null
+++ b/pkg/sentry/fs/proc/exec_args.go
@@ -0,0 +1,203 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// execArgType enumerates the types of exec arguments that are exposed through
+// proc.
+type execArgType int
+
+const (
+ cmdlineExecArg execArgType = iota
+ environExecArg
+)
+
+// execArgInode is a inode containing the exec args (either cmdline or environ)
+// for a given task.
+//
+// +stateify savable
+type execArgInode struct {
+ fsutil.SimpleFileInode
+
+ // arg is the type of exec argument this file contains.
+ arg execArgType
+
+ // t is the Task to read the exec arg line from.
+ t *kernel.Task
+}
+
+var _ fs.InodeOperations = (*execArgInode)(nil)
+
+// newExecArgFile creates a file containing the exec args of the given type.
+func newExecArgInode(t *kernel.Task, msrc *fs.MountSource, arg execArgType) *fs.Inode {
+ if arg != cmdlineExecArg && arg != environExecArg {
+ panic(fmt.Sprintf("unknown exec arg type %v", arg))
+ }
+ f := &execArgInode{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ arg: arg,
+ t: t,
+ }
+ return newProcInode(f, msrc, fs.SpecialFile, t)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (i *execArgInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &execArgFile{
+ arg: i.arg,
+ t: i.t,
+ }), nil
+}
+
+// +stateify savable
+type execArgFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopWrite `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ // arg is the type of exec argument this file contains.
+ arg execArgType
+
+ // t is the Task to read the exec arg line from.
+ t *kernel.Task
+}
+
+var _ fs.FileOperations = (*execArgFile)(nil)
+
+// Read reads the exec arg from the process's address space..
+func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ m, err := getTaskMM(f.t)
+ if err != nil {
+ return 0, err
+ }
+ defer m.DecUsers(ctx)
+
+ // Figure out the bounds of the exec arg we are trying to read.
+ var execArgStart, execArgEnd usermem.Addr
+ switch f.arg {
+ case cmdlineExecArg:
+ execArgStart, execArgEnd = m.ArgvStart(), m.ArgvEnd()
+ case environExecArg:
+ execArgStart, execArgEnd = m.EnvvStart(), m.EnvvEnd()
+ default:
+ panic(fmt.Sprintf("unknown exec arg type %v", f.arg))
+ }
+ if execArgStart == 0 || execArgEnd == 0 {
+ // Don't attempt to read before the start/end are set up.
+ return 0, io.EOF
+ }
+
+ start, ok := execArgStart.AddLength(uint64(offset))
+ if !ok {
+ return 0, io.EOF
+ }
+ if start >= execArgEnd {
+ return 0, io.EOF
+ }
+
+ length := int(execArgEnd - start)
+ if dstlen := dst.NumBytes(); int64(length) > dstlen {
+ length = int(dstlen)
+ }
+
+ buf := make([]byte, length)
+ // N.B. Technically this should be usermem.IOOpts.IgnorePermissions = true
+ // until Linux 4.9 (272ddc8b3735 "proc: don't use FOLL_FORCE for reading
+ // cmdline and environment").
+ copyN, err := m.CopyIn(ctx, start, buf, usermem.IOOpts{})
+ if copyN == 0 {
+ // Nothing to copy.
+ return 0, err
+ }
+ buf = buf[:copyN]
+
+ // On Linux, if the NUL byte at the end of the argument vector has been
+ // overwritten, it continues reading the environment vector as part of
+ // the argument vector.
+
+ if f.arg == cmdlineExecArg && buf[copyN-1] != 0 {
+ // Linux will limit the return up to and including the first null character in argv
+
+ copyN = bytes.IndexByte(buf, 0)
+ if copyN == -1 {
+ copyN = len(buf)
+ }
+ // If we found a NUL character in argv, return upto and including that character.
+ if copyN < len(buf) {
+ buf = buf[:copyN]
+ } else { // Otherwise return into envp.
+ lengthEnvv := int(m.EnvvEnd() - m.EnvvStart())
+
+ // Upstream limits the returned amount to one page of slop.
+ // https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208
+ // we'll return one page total between argv and envp because of the
+ // above page restrictions.
+ if lengthEnvv > usermem.PageSize-len(buf) {
+ lengthEnvv = usermem.PageSize - len(buf)
+ }
+ // Make a new buffer to fit the whole thing
+ tmp := make([]byte, length+lengthEnvv)
+ copyNE, err := m.CopyIn(ctx, m.EnvvStart(), tmp[copyN:], usermem.IOOpts{})
+ if err != nil {
+ return 0, err
+ }
+
+ // Linux will return envp up to and including the first NUL character, so find it.
+ for i, c := range tmp[copyN:] {
+ if c == 0 {
+ copyNE = i
+ break
+ }
+ }
+
+ copy(tmp, buf)
+ buf = tmp[:copyN+copyNE]
+
+ }
+
+ }
+
+ n, dstErr := dst.CopyOut(ctx, buf)
+ if dstErr != nil {
+ return int64(n), dstErr
+ }
+ return int64(n), err
+}
diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go
new file mode 100644
index 000000000..744b31c74
--- /dev/null
+++ b/pkg/sentry/fs/proc/fds.go
@@ -0,0 +1,285 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// walkDescriptors finds the descriptor (file-flag pair) for the fd identified
+// by p, and calls the toInodeOperations callback with that descriptor. This is a helper
+// method for implementing fs.InodeOperations.Lookup.
+func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDFlags) *fs.Inode) (*fs.Inode, error) {
+ n, err := strconv.ParseUint(p, 10, 64)
+ if err != nil {
+ // Not found.
+ return nil, syserror.ENOENT
+ }
+
+ var file *fs.File
+ var fdFlags kernel.FDFlags
+ t.WithMuLocked(func(t *kernel.Task) {
+ if fdm := t.FDMap(); fdm != nil {
+ file, fdFlags = fdm.GetDescriptor(kdefs.FD(n))
+ }
+ })
+ if file == nil {
+ return nil, syserror.ENOENT
+ }
+ return toInode(file, fdFlags), nil
+}
+
+// readDescriptors reads fds in the task starting at offset, and calls the
+// toDentAttr callback for each to get a DentAttr, which it then emits. This is
+// a helper for implementing fs.InodeOperations.Readdir.
+func readDescriptors(t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) {
+ var fds kernel.FDs
+ t.WithMuLocked(func(t *kernel.Task) {
+ if fdm := t.FDMap(); fdm != nil {
+ fds = fdm.GetFDs()
+ }
+ })
+
+ fdInts := make([]int, 0, len(fds))
+ for _, fd := range fds {
+ fdInts = append(fdInts, int(fd))
+ }
+
+ // Find the fd to start at.
+ idx := sort.SearchInts(fdInts, int(offset))
+ if idx == len(fdInts) {
+ return offset, nil
+ }
+ fdInts = fdInts[idx:]
+
+ var fd int
+ for _, fd = range fdInts {
+ name := strconv.FormatUint(uint64(fd), 10)
+ if err := c.DirEmit(name, toDentAttr(fd)); err != nil {
+ // Returned offset is the next fd to serialize.
+ return int64(fd), err
+ }
+ }
+ // We serialized them all. Next offset should be higher than last
+ // serialized fd.
+ return int64(fd + 1), nil
+}
+
+// fd implements fs.InodeOperations for a file in /proc/TID/fd/.
+type fd struct {
+ ramfs.Symlink
+ file *fs.File
+}
+
+var _ fs.InodeOperations = (*fd)(nil)
+
+// newFd returns a new fd based on an existing file.
+//
+// This inherits one reference to the file.
+func newFd(t *kernel.Task, f *fs.File, msrc *fs.MountSource) *fs.Inode {
+ fd := &fd{
+ // RootOwner overridden by taskOwnedInodeOps.UnstableAttrs().
+ Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""),
+ file: f,
+ }
+ return newProcInode(fd, msrc, fs.Symlink, t)
+}
+
+// GetFile returns the fs.File backing this fd. The dirent and flags
+// arguments are ignored.
+func (f *fd) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error) {
+ // Take a reference on the fs.File.
+ f.file.IncRef()
+ return f.file, nil
+}
+
+// Readlink returns the current target.
+func (f *fd) Readlink(ctx context.Context, _ *fs.Inode) (string, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ n, _ := f.file.Dirent.FullName(root)
+ return n, nil
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (f *fd) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ f.file.Dirent.IncRef()
+ return f.file.Dirent, nil
+}
+
+// Truncate is ignored.
+func (f *fd) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+func (f *fd) Release(ctx context.Context) {
+ f.Symlink.Release(ctx)
+ f.file.DecRef()
+}
+
+// Close releases the reference on the file.
+func (f *fd) Close() error {
+ f.file.DecRef()
+ return nil
+}
+
+// fdDir is an InodeOperations for /proc/TID/fd.
+//
+// +stateify savable
+type fdDir struct {
+ ramfs.Dir
+
+ // We hold a reference on the task's fdmap but only keep an indirect
+ // task pointer to avoid Dirent loading circularity caused by fdmap's
+ // potential back pointers into the dirent tree.
+ t *kernel.Task
+}
+
+var _ fs.InodeOperations = (*fdDir)(nil)
+
+// newFdDir creates a new fdDir.
+func newFdDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ f := &fdDir{
+ Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermissions{User: fs.PermMask{Read: true, Execute: true}}),
+ t: t,
+ }
+ return newProcInode(f, msrc, fs.SpecialDirectory, t)
+}
+
+// Check implements InodeOperations.Check.
+//
+// This is to match Linux, which uses a special permission handler to guarantee
+// that a process can still access /proc/self/fd after it has executed
+// setuid. See fs/proc/fd.c:proc_fd_permission.
+func (f *fdDir) Check(ctx context.Context, inode *fs.Inode, req fs.PermMask) bool {
+ if fs.ContextCanAccessFile(ctx, inode, req) {
+ return true
+ }
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ // Allow access if the task trying to access it is in the
+ // thread group corresponding to this directory.
+ if f.t.ThreadGroup() == t.ThreadGroup() {
+ return true
+ }
+ }
+ return false
+}
+
+// Lookup loads an Inode in /proc/TID/fd into a Dirent.
+func (f *fdDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) {
+ n, err := walkDescriptors(f.t, p, func(file *fs.File, _ kernel.FDFlags) *fs.Inode {
+ return newFd(f.t, file, dir.MountSource)
+ })
+ if err != nil {
+ return nil, err
+ }
+ return fs.NewDirent(n, p), nil
+}
+
+// GetFile implements fs.FileOperations.GetFile.
+func (f *fdDir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ fops := &fdDirFile{
+ isInfoFile: false,
+ t: f.t,
+ }
+ return fs.NewFile(ctx, dirent, flags, fops), nil
+}
+
+// +stateify savable
+type fdDirFile struct {
+ fsutil.DirFileOperations `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ isInfoFile bool
+
+ t *kernel.Task
+}
+
+var _ fs.FileOperations = (*fdDirFile)(nil)
+
+// Readdir implements fs.FileOperations.Readdir.
+func (f *fdDirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) {
+ dirCtx := &fs.DirCtx{
+ Serializer: ser,
+ }
+ typ := fs.RegularFile
+ if f.isInfoFile {
+ typ = fs.Symlink
+ }
+ return readDescriptors(f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr {
+ return fs.GenericDentAttr(typ, device.ProcDevice)
+ })
+}
+
+// fdInfoDir implements /proc/TID/fdinfo. It embeds an fdDir, but overrides
+// Lookup and Readdir.
+//
+// +stateify savable
+type fdInfoDir struct {
+ ramfs.Dir
+
+ t *kernel.Task
+}
+
+// newFdInfoDir creates a new fdInfoDir.
+func newFdInfoDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ fdid := &fdInfoDir{
+ Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermsFromMode(0500)),
+ t: t,
+ }
+ return newProcInode(fdid, msrc, fs.SpecialDirectory, t)
+}
+
+// Lookup loads an fd in /proc/TID/fdinfo into a Dirent.
+func (fdid *fdInfoDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) {
+ inode, err := walkDescriptors(fdid.t, p, func(file *fs.File, fdFlags kernel.FDFlags) *fs.Inode {
+ // TODO(b/121266871): Using a static inode here means that the
+ // data can be out-of-date if, for instance, the flags on the
+ // FD change before we read this file. We should switch to
+ // generating the data on Read(). Also, we should include pos,
+ // locks, and other data. For now we only have flags.
+ // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
+ flags := file.Flags().ToLinux() | fdFlags.ToLinuxFileFlags()
+ file.DecRef()
+ contents := []byte(fmt.Sprintf("flags:\t0%o\n", flags))
+ return newStaticProcInode(ctx, dir.MountSource, contents)
+ })
+ if err != nil {
+ return nil, err
+ }
+ return fs.NewDirent(inode, p), nil
+}
+
+// GetFile implements fs.FileOperations.GetFile.
+func (fdid *fdInfoDir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ fops := &fdDirFile{
+ isInfoFile: true,
+ t: fdid.t,
+ }
+ return fs.NewFile(ctx, dirent, flags, fops), nil
+}
diff --git a/pkg/sentry/fs/proc/filesystems.go b/pkg/sentry/fs/proc/filesystems.go
new file mode 100644
index 000000000..7bb081d0e
--- /dev/null
+++ b/pkg/sentry/fs/proc/filesystems.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+)
+
+// filesystemsData backs /proc/filesystems.
+//
+// +stateify savable
+type filesystemsData struct{}
+
+// NeedsUpdate returns true on the first generation. The set of registered file
+// systems doesn't change so there's no need to generate SeqData more than once.
+func (*filesystemsData) NeedsUpdate(generation int64) bool {
+ return generation == 0
+}
+
+// ReadSeqFileData returns data for the SeqFile reader.
+// SeqData, the current generation and where in the file the handle corresponds to.
+func (*filesystemsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ // We don't ever expect to see a non-nil SeqHandle.
+ if h != nil {
+ return nil, 0
+ }
+
+ // Generate the file contents.
+ var buf bytes.Buffer
+ for _, sys := range fs.GetFilesystems() {
+ if !sys.AllowUserList() {
+ continue
+ }
+ nodev := "nodev"
+ if sys.Flags()&fs.FilesystemRequiresDev != 0 {
+ nodev = ""
+ }
+ // Matches the format of fs/filesystems.c:filesystems_proc_show.
+ fmt.Fprintf(&buf, "%s\t%s\n", nodev, sys.Name())
+ }
+
+ // Return the SeqData and advance the generation counter.
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*filesystemsData)(nil)}}, 1
+}
diff --git a/pkg/sentry/fs/proc/fs.go b/pkg/sentry/fs/proc/fs.go
new file mode 100644
index 000000000..d57d6cc5d
--- /dev/null
+++ b/pkg/sentry/fs/proc/fs.go
@@ -0,0 +1,81 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// filesystem is a procfs.
+//
+// +stateify savable
+type filesystem struct{}
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name underwhich the filesystem is registered.
+// Name matches fs/proc/root.c:proc_fs_type.name.
+const FilesystemName = "proc"
+
+// Name is the name of the file system.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*filesystem) AllowUserMount() bool {
+ return true
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+//
+// In Linux, proc returns FS_USERNS_VISIBLE | FS_USERNS_MOUNT, see fs/proc/root.c.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns the root of a procfs that can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, cgroupsInt interface{}) (*fs.Inode, error) {
+ // device is always ignored.
+
+ // Parse generic comma-separated key=value options, this file system expects them.
+ options := fs.GenericMountSourceOptions(data)
+
+ // Proc options parsing checks for either a gid= or hidepid= and barfs on
+ // anything else, see fs/proc/root.c:proc_parse_options. Since we don't know
+ // what to do with gid= or hidepid=, we blow up if we get any options.
+ if len(options) > 0 {
+ return nil, fmt.Errorf("unsupported mount options: %v", options)
+ }
+
+ var cgroups map[string]string
+ if cgroupsInt != nil {
+ cgroups = cgroupsInt.(map[string]string)
+ }
+
+ // Construct the procfs root. Since procfs files are all virtual, we
+ // never want them cached.
+ return New(ctx, fs.NewNonCachingMountSource(f, flags), cgroups)
+}
diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go
new file mode 100644
index 000000000..379569823
--- /dev/null
+++ b/pkg/sentry/fs/proc/inode.go
@@ -0,0 +1,97 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// taskOwnedInodeOps wraps an fs.InodeOperations and overrides the UnstableAttr
+// method to return the task as the owner.
+//
+// +stateify savable
+type taskOwnedInodeOps struct {
+ fs.InodeOperations
+
+ // t is the task that owns this file.
+ t *kernel.Task
+}
+
+// UnstableAttr implement fs.InodeOperations.UnstableAttr.
+func (i *taskOwnedInodeOps) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ uattr, err := i.InodeOperations.UnstableAttr(ctx, inode)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ // Set the task owner as the file owner.
+ creds := i.t.Credentials()
+ uattr.Owner = fs.FileOwner{creds.EffectiveKUID, creds.EffectiveKGID}
+ return uattr, nil
+}
+
+// staticFileInodeOps is an InodeOperations implementation that can be used to
+// return file contents which are constant. This file is not writable and will
+// always have mode 0444.
+//
+// +stateify savable
+type staticFileInodeOps struct {
+ fsutil.InodeDenyWriteChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeStaticFileGetter
+}
+
+var _ fs.InodeOperations = (*staticFileInodeOps)(nil)
+
+// newStaticFileInode returns a procfs InodeOperations with static contents.
+func newStaticProcInode(ctx context.Context, msrc *fs.MountSource, contents []byte) *fs.Inode {
+ iops := &staticFileInodeOps{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ InodeStaticFileGetter: fsutil.InodeStaticFileGetter{
+ Contents: contents,
+ },
+ }
+ return newProcInode(iops, msrc, fs.SpecialFile, nil)
+}
+
+// newProcInode creates a new inode from the given inode operations.
+func newProcInode(iops fs.InodeOperations, msrc *fs.MountSource, typ fs.InodeType, t *kernel.Task) *fs.Inode {
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: typ,
+ }
+ if t != nil {
+ iops = &taskOwnedInodeOps{iops, t}
+ }
+ return fs.NewInode(iops, msrc, sattr)
+}
diff --git a/pkg/sentry/fs/proc/loadavg.go b/pkg/sentry/fs/proc/loadavg.go
new file mode 100644
index 000000000..2dfe7089a
--- /dev/null
+++ b/pkg/sentry/fs/proc/loadavg.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+)
+
+// loadavgData backs /proc/loadavg.
+//
+// +stateify savable
+type loadavgData struct{}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*loadavgData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (d *loadavgData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+
+ // TODO(b/62345059): Include real data in fields.
+ // Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods.
+ // Column 4-5: currently running processes and the total number of processes.
+ // Column 6: the last process ID used.
+ fmt.Fprintf(&buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
+
+ return []seqfile.SeqData{
+ {
+ Buf: buf.Bytes(),
+ Handle: (*loadavgData)(nil),
+ },
+ }, 0
+}
diff --git a/pkg/sentry/fs/proc/meminfo.go b/pkg/sentry/fs/proc/meminfo.go
new file mode 100644
index 000000000..d2b9b92c7
--- /dev/null
+++ b/pkg/sentry/fs/proc/meminfo.go
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// meminfoData backs /proc/meminfo.
+//
+// +stateify savable
+type meminfoData struct {
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*meminfoData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (d *meminfoData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ mf := d.k.MemoryFile()
+ mf.UpdateUsage()
+ snapshot, totalUsage := usage.MemoryAccounting.Copy()
+ totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+ anon := snapshot.Anonymous + snapshot.Tmpfs
+ file := snapshot.PageCache + snapshot.Mapped
+ // We don't actually have active/inactive LRUs, so just make up numbers.
+ activeFile := (file / 2) &^ (usermem.PageSize - 1)
+ inactiveFile := file - activeFile
+
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "MemTotal: %8d kB\n", totalSize/1024)
+ memFree := (totalSize - totalUsage) / 1024
+ // We use MemFree as MemAvailable because we don't swap.
+ // TODO(rahat): When reclaim is implemented the value of MemAvailable
+ // should change.
+ fmt.Fprintf(&buf, "MemFree: %8d kB\n", memFree)
+ fmt.Fprintf(&buf, "MemAvailable: %8d kB\n", memFree)
+ fmt.Fprintf(&buf, "Buffers: 0 kB\n") // memory usage by block devices
+ fmt.Fprintf(&buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
+ // Emulate a system with no swap, which disables inactivation of anon pages.
+ fmt.Fprintf(&buf, "SwapCache: 0 kB\n")
+ fmt.Fprintf(&buf, "Active: %8d kB\n", (anon+activeFile)/1024)
+ fmt.Fprintf(&buf, "Inactive: %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(&buf, "Active(anon): %8d kB\n", anon/1024)
+ fmt.Fprintf(&buf, "Inactive(anon): 0 kB\n")
+ fmt.Fprintf(&buf, "Active(file): %8d kB\n", activeFile/1024)
+ fmt.Fprintf(&buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(&buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(&buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(&buf, "SwapTotal: 0 kB\n")
+ fmt.Fprintf(&buf, "SwapFree: 0 kB\n")
+ fmt.Fprintf(&buf, "Dirty: 0 kB\n")
+ fmt.Fprintf(&buf, "Writeback: 0 kB\n")
+ fmt.Fprintf(&buf, "AnonPages: %8d kB\n", anon/1024)
+ fmt.Fprintf(&buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
+ fmt.Fprintf(&buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*meminfoData)(nil)}}, 0
+}
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
new file mode 100644
index 000000000..1f7817947
--- /dev/null
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -0,0 +1,197 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+)
+
+// forEachMountSource runs f for the process root mount and each mount that is a
+// descendant of the root.
+func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
+ var fsctx *kernel.FSContext
+ t.WithMuLocked(func(t *kernel.Task) {
+ fsctx = t.FSContext()
+ })
+ if fsctx == nil {
+ // The task has been destroyed. Nothing to show here.
+ return
+ }
+
+ // All mount points must be relative to the rootDir, and mounts outside
+ // will be excluded.
+ rootDir := fsctx.RootDirectory()
+ if rootDir == nil {
+ // The task has been destroyed. Nothing to show here.
+ return
+ }
+ defer rootDir.DecRef()
+
+ mnt := t.MountNamespace().FindMount(rootDir)
+ if mnt == nil {
+ // Has it just been unmounted?
+ return
+ }
+ ms := t.MountNamespace().AllMountsUnder(mnt)
+ sort.Slice(ms, func(i, j int) bool {
+ return ms[i].ID < ms[j].ID
+ })
+ for _, m := range ms {
+ mroot := m.Root()
+ mountPath, desc := mroot.FullName(rootDir)
+ mroot.DecRef()
+ if !desc {
+ // MountSources that are not descendants of the chroot jail are ignored.
+ continue
+ }
+
+ fn(mountPath, m)
+ }
+}
+
+// mountInfoFile is used to implement /proc/[pid]/mountinfo.
+//
+// +stateify savable
+type mountInfoFile struct {
+ t *kernel.Task
+}
+
+// NeedsUpdate implements SeqSource.NeedsUpdate.
+func (mif *mountInfoFile) NeedsUpdate(_ int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements SeqSource.ReadSeqFileData.
+func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if handle != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ forEachMount(mif.t, func(mountPath string, m *fs.Mount) {
+ // Format:
+ // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
+ // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
+
+ // (1) MountSource ID.
+ fmt.Fprintf(&buf, "%d ", m.ID)
+
+ // (2) Parent ID (or this ID if there is no parent).
+ pID := m.ID
+ if !m.IsRoot() && !m.IsUndo() {
+ pID = m.ParentID
+ }
+ fmt.Fprintf(&buf, "%d ", pID)
+
+ // (3) Major:Minor device ID. We don't have a superblock, so we
+ // just use the root inode device number.
+ mroot := m.Root()
+ defer mroot.DecRef()
+
+ sa := mroot.Inode.StableAttr
+ fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor)
+
+ // (4) Root: the pathname of the directory in the filesystem
+ // which forms the root of this mount.
+ //
+ // NOTE(b/78135857): This will always be "/" until we implement
+ // bind mounts.
+ fmt.Fprintf(&buf, "/ ")
+
+ // (5) Mount point (relative to process root).
+ fmt.Fprintf(&buf, "%s ", mountPath)
+
+ // (6) Mount options.
+ flags := mroot.Inode.MountSource.Flags
+ opts := "rw"
+ if flags.ReadOnly {
+ opts = "ro"
+ }
+ if flags.NoAtime {
+ opts += ",noatime"
+ }
+ if flags.NoExec {
+ opts += ",noexec"
+ }
+ fmt.Fprintf(&buf, "%s ", opts)
+
+ // (7) Optional fields: zero or more fields of the form "tag[:value]".
+ // (8) Separator: the end of the optional fields is marked by a single hyphen.
+ fmt.Fprintf(&buf, "- ")
+
+ // (9) Filesystem type.
+ fmt.Fprintf(&buf, "%s ", mroot.Inode.MountSource.FilesystemType)
+
+ // (10) Mount source: filesystem-specific information or "none".
+ fmt.Fprintf(&buf, "none ")
+
+ // (11) Superblock options. Only "ro/rw" is supported for now,
+ // and is the same as the filesystem option.
+ fmt.Fprintf(&buf, "%s\n", opts)
+ })
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*mountInfoFile)(nil)}}, 0
+}
+
+// mountsFile is used to implement /proc/[pid]/mounts.
+//
+// +stateify savable
+type mountsFile struct {
+ t *kernel.Task
+}
+
+// NeedsUpdate implements SeqSource.NeedsUpdate.
+func (mf *mountsFile) NeedsUpdate(_ int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements SeqSource.ReadSeqFileData.
+func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if handle != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ forEachMount(mf.t, func(mountPath string, m *fs.Mount) {
+ // Format:
+ // <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order>
+ //
+ // We use the filesystem name as the first field, since there
+ // is no real block device we can point to, and we also should
+ // not expose anything about the remote filesystem.
+ //
+ // Only ro/rw option is supported for now.
+ //
+ // The "needs dump"and fsck flags are always 0, which is allowed.
+ root := m.Root()
+ defer root.DecRef()
+
+ flags := root.Inode.MountSource.Flags
+ opts := "rw"
+ if flags.ReadOnly {
+ opts = "ro"
+ }
+ fmt.Fprintf(&buf, "%s %s %s %s %d %d\n", "none", mountPath, root.Inode.MountSource.FilesystemType, opts, 0, 0)
+ })
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*mountsFile)(nil)}}, 0
+}
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
new file mode 100644
index 000000000..4a107c739
--- /dev/null
+++ b/pkg/sentry/fs/proc/net.go
@@ -0,0 +1,308 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+// newNet creates a new proc net entry.
+func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
+ var contents map[string]*fs.Inode
+ if s := p.k.NetworkStack(); s != nil {
+ contents = map[string]*fs.Inode{
+ "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
+
+ // The following files are simple stubs until they are
+ // implemented in netstack, if the file contains a
+ // header the stub is just the header otherwise it is
+ // an empty file.
+ "arp": newStaticProcInode(ctx, msrc, []byte("IP address HW type Flags HW address Mask Device")),
+
+ "netlink": newStaticProcInode(ctx, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode")),
+ "netstat": newStaticProcInode(ctx, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess")),
+ "packet": newStaticProcInode(ctx, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode")),
+ "protocols": newStaticProcInode(ctx, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em")),
+ // Linux sets psched values to: nsec per usec, psched
+ // tick in ns, 1000000, high res timer ticks per sec
+ // (ClockGetres returns 1ns resolution).
+ "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))),
+ "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")),
+ "route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")),
+ "tcp": newStaticProcInode(ctx, msrc, []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")),
+
+ "udp": newStaticProcInode(ctx, msrc, []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops")),
+
+ "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
+ }
+
+ if s.SupportsIPv6() {
+ contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc)
+ contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte(""))
+ contents["tcp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode"))
+ contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode"))
+ }
+ }
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
+
+// ifinet6 implements seqfile.SeqSource for /proc/net/if_inet6.
+//
+// +stateify savable
+type ifinet6 struct {
+ s inet.Stack
+}
+
+func (n *ifinet6) contents() []string {
+ var lines []string
+ nics := n.s.Interfaces()
+ for id, naddrs := range n.s.InterfaceAddrs() {
+ nic, ok := nics[id]
+ if !ok {
+ // NIC was added after NICNames was called. We'll just
+ // ignore it.
+ continue
+ }
+
+ for _, a := range naddrs {
+ // IPv6 only.
+ if a.Family != linux.AF_INET6 {
+ continue
+ }
+
+ // Fields:
+ // IPv6 address displayed in 32 hexadecimal chars without colons
+ // Netlink device number (interface index) in hexadecimal (use nic id)
+ // Prefix length in hexadecimal
+ // Scope value (use 0)
+ // Interface flags
+ // Device name
+ lines = append(lines, fmt.Sprintf("%032x %02x %02x %02x %02x %8s\n", a.Addr, id, a.PrefixLen, 0, a.Flags, nic.Name))
+ }
+ }
+ return lines
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*ifinet6) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *ifinet6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var data []seqfile.SeqData
+ for _, l := range n.contents() {
+ data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*ifinet6)(nil)})
+ }
+
+ return data, 0
+}
+
+// netDev implements seqfile.SeqSource for /proc/net/dev.
+//
+// +stateify savable
+type netDev struct {
+ s inet.Stack
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (n *netDev) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. See Linux's
+// net/core/net-procfs.c:dev_seq_show.
+func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ interfaces := n.s.Interfaces()
+ contents := make([]string, 2, 2+len(interfaces))
+ // Add the table header. From net/core/net-procfs.c:dev_seq_show.
+ contents[0] = "Inter-| Receive | Transmit\n"
+ contents[1] = " face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n"
+
+ for _, i := range interfaces {
+ // TODO(b/71872867): Collect stats from each inet.Stack
+ // implementation (hostinet, epsocket, and rpcinet).
+
+ // Implements the same format as
+ // net/core/net-procfs.c:dev_seq_printf_stats.
+ l := fmt.Sprintf("%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n",
+ i.Name,
+ // Received
+ 0, // bytes
+ 0, // packets
+ 0, // errors
+ 0, // dropped
+ 0, // fifo
+ 0, // frame
+ 0, // compressed
+ 0, // multicast
+ // Transmitted
+ 0, // bytes
+ 0, // packets
+ 0, // errors
+ 0, // dropped
+ 0, // fifo
+ 0, // frame
+ 0, // compressed
+ 0) // multicast
+ contents = append(contents, l)
+ }
+
+ var data []seqfile.SeqData
+ for _, l := range contents {
+ data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*ifinet6)(nil)})
+ }
+
+ return data, 0
+}
+
+// netUnix implements seqfile.SeqSource for /proc/net/unix.
+//
+// +stateify savable
+type netUnix struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netUnix) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return []seqfile.SeqData{}, 0
+ }
+
+ var buf bytes.Buffer
+ // Header
+ fmt.Fprintf(&buf, "Num RefCount Protocol Flags Type St Inode Path\n")
+
+ // Entries
+ for _, sref := range n.k.ListSockets(linux.AF_UNIX) {
+ s := sref.Get()
+ if s == nil {
+ log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", sref)
+ continue
+ }
+ sfile := s.(*fs.File)
+ sops, ok := sfile.FileOperations.(*unix.SocketOperations)
+ if !ok {
+ panic(fmt.Sprintf("Found non-unix socket file in unix socket table: %+v", sfile))
+ }
+
+ addr, err := sops.Endpoint().GetLocalAddress()
+ if err != nil {
+ log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err)
+ addr.Addr = "<unknown>"
+ }
+
+ sockFlags := 0
+ if ce, ok := sops.Endpoint().(transport.ConnectingEndpoint); ok {
+ if ce.Listening() {
+ // For unix domain sockets, linux reports a single flag
+ // value if the socket is listening, of __SO_ACCEPTCON.
+ sockFlags = linux.SO_ACCEPTCON
+ }
+ }
+
+ var sockState int
+ switch sops.Endpoint().Type() {
+ case linux.SOCK_DGRAM:
+ sockState = linux.SS_CONNECTING
+ // Unlike Linux, we don't have unbound connection-less sockets,
+ // so no SS_DISCONNECTING.
+
+ case linux.SOCK_SEQPACKET:
+ fallthrough
+ case linux.SOCK_STREAM:
+ // Connectioned.
+ if sops.Endpoint().(transport.ConnectingEndpoint).Connected() {
+ sockState = linux.SS_CONNECTED
+ } else {
+ sockState = linux.SS_UNCONNECTED
+ }
+ }
+
+ // In the socket entry below, the value for the 'Num' field requires
+ // some consideration. Linux prints the address to the struct
+ // unix_sock representing a socket in the kernel, but may redact the
+ // value for unprivileged users depending on the kptr_restrict
+ // sysctl.
+ //
+ // One use for this field is to allow a privileged user to
+ // introspect into the kernel memory to determine information about
+ // a socket not available through procfs, such as the socket's peer.
+ //
+ // On gvisor, returning a pointer to our internal structures would
+ // be pointless, as it wouldn't match the memory layout for struct
+ // unix_sock, making introspection difficult. We could populate a
+ // struct unix_sock with the appropriate data, but even that
+ // requires consideration for which kernel version to emulate, as
+ // the definition of this struct changes over time.
+ //
+ // For now, we always redact this pointer.
+ fmt.Fprintf(&buf, "%#016p: %08X %08X %08X %04X %02X %5d",
+ (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
+ sfile.ReadRefs()-1, // RefCount, don't count our own ref.
+ 0, // Protocol, always 0 for UDS.
+ sockFlags, // Flags.
+ sops.Endpoint().Type(), // Type.
+ sockState, // State.
+ sfile.InodeID(), // Inode.
+ )
+
+ // Path
+ if len(addr.Addr) != 0 {
+ if addr.Addr[0] == 0 {
+ // Abstract path.
+ fmt.Fprintf(&buf, " @%s", string(addr.Addr[1:]))
+ } else {
+ fmt.Fprintf(&buf, " %s", string(addr.Addr))
+ }
+ }
+ fmt.Fprintf(&buf, "\n")
+
+ sfile.DecRef()
+ }
+
+ data := []seqfile.SeqData{{
+ Buf: buf.Bytes(),
+ Handle: (*netUnix)(nil),
+ }}
+ return data, 0
+}
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
new file mode 100644
index 000000000..0e15894b4
--- /dev/null
+++ b/pkg/sentry/fs/proc/proc.go
@@ -0,0 +1,251 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package proc implements a partial in-memory file system for profs.
+package proc
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// proc is a root proc node.
+//
+// +stateify savable
+type proc struct {
+ ramfs.Dir
+
+ // k is the Kernel containing this proc node.
+ k *kernel.Kernel
+
+ // pidns is the PID namespace of the task that mounted the proc filesystem
+ // that this node represents.
+ pidns *kernel.PIDNamespace
+
+ // cgroupControllers is a map of controller name to directory in the
+ // cgroup hierarchy. These controllers are immutable and will be listed
+ // in /proc/pid/cgroup if not nil.
+ cgroupControllers map[string]string
+}
+
+// New returns the root node of a partial simple procfs.
+func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string]string) (*fs.Inode, error) {
+ k := kernel.KernelFromContext(ctx)
+ if k == nil {
+ return nil, fmt.Errorf("procfs requires a kernel")
+ }
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return nil, fmt.Errorf("procfs requires a PID namespace")
+ }
+
+ // Note that these are just the static members. There are dynamic
+ // members populated in Readdir and Lookup below.
+ contents := map[string]*fs.Inode{
+ "cpuinfo": newCPUInfo(ctx, msrc),
+ "filesystems": seqfile.NewSeqFileInode(ctx, &filesystemsData{}, msrc),
+ "loadavg": seqfile.NewSeqFileInode(ctx, &loadavgData{}, msrc),
+ "meminfo": seqfile.NewSeqFileInode(ctx, &meminfoData{k}, msrc),
+ "mounts": newProcInode(ramfs.NewSymlink(ctx, fs.RootOwner, "self/mounts"), msrc, fs.Symlink, nil),
+ "self": newSelf(ctx, pidns, msrc),
+ "stat": seqfile.NewSeqFileInode(ctx, &statData{k}, msrc),
+ "thread-self": newThreadSelf(ctx, pidns, msrc),
+ "uptime": newUptime(ctx, msrc),
+ "version": seqfile.NewSeqFileInode(ctx, &versionData{k}, msrc),
+ }
+
+ // Construct the proc InodeOperations.
+ p := &proc{
+ Dir: *ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)),
+ k: k,
+ pidns: pidns,
+ cgroupControllers: cgroupControllers,
+ }
+
+ // Add more contents that need proc to be initialized.
+ p.AddChild(ctx, "sys", p.newSysDir(ctx, msrc))
+
+ // If we're using rpcinet we will let it manage /proc/net.
+ if _, ok := p.k.NetworkStack().(*rpcinet.Stack); ok {
+ p.AddChild(ctx, "net", newRPCInetProcNet(ctx, msrc))
+ } else {
+ p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc))
+ }
+
+ return newProcInode(p, msrc, fs.SpecialDirectory, nil), nil
+}
+
+// self is a magical link.
+//
+// +stateify savable
+type self struct {
+ ramfs.Symlink
+
+ pidns *kernel.PIDNamespace
+}
+
+// newSelf returns a new "self" node.
+func newSelf(ctx context.Context, pidns *kernel.PIDNamespace, msrc *fs.MountSource) *fs.Inode {
+ s := &self{
+ Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""),
+ pidns: pidns,
+ }
+ return newProcInode(s, msrc, fs.Symlink, nil)
+}
+
+// newThreadSelf returns a new "threadSelf" node.
+func newThreadSelf(ctx context.Context, pidns *kernel.PIDNamespace, msrc *fs.MountSource) *fs.Inode {
+ s := &threadSelf{
+ Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""),
+ pidns: pidns,
+ }
+ return newProcInode(s, msrc, fs.Symlink, nil)
+}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (s *self) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ if tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return strconv.FormatUint(uint64(tgid), 10), nil
+ }
+
+ // Who is reading this link?
+ return "", syserror.EINVAL
+}
+
+// threadSelf is more magical than "self" link.
+//
+// +stateify savable
+type threadSelf struct {
+ ramfs.Symlink
+
+ pidns *kernel.PIDNamespace
+}
+
+// Readlink implements fs.InodeOperations.Readlink.
+func (s *threadSelf) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ tid := s.pidns.IDOfTask(t)
+ if tid == 0 || tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return fmt.Sprintf("%d/task/%d", tgid, tid), nil
+ }
+
+ // Who is reading this link?
+ return "", syserror.EINVAL
+}
+
+// Lookup loads an Inode at name into a Dirent.
+func (p *proc) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
+ dirent, walkErr := p.Dir.Lookup(ctx, dir, name)
+ if walkErr == nil {
+ return dirent, nil
+ }
+
+ // Try to lookup a corresponding task.
+ tid, err := strconv.ParseUint(name, 10, 64)
+ if err != nil {
+ // Ignore the parse error and return the original.
+ return nil, walkErr
+ }
+
+ // Grab the other task.
+ otherTask := p.pidns.TaskWithID(kernel.ThreadID(tid))
+ if otherTask == nil {
+ // Per above.
+ return nil, walkErr
+ }
+
+ // Wrap it in a taskDir.
+ td := p.newTaskDir(otherTask, dir.MountSource, true)
+ return fs.NewDirent(td, name), nil
+}
+
+// GetFile implements fs.InodeOperations.
+func (p *proc) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &rootProcFile{iops: p}), nil
+}
+
+// rootProcFile implements fs.FileOperations for the proc directory.
+//
+// +stateify savable
+type rootProcFile struct {
+ fsutil.DirFileOperations `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ iops *proc
+}
+
+var _ fs.FileOperations = (*rootProcFile)(nil)
+
+// Readdir implements fs.FileOperations.Readdir.
+func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) {
+ offset := file.Offset()
+ dirCtx := &fs.DirCtx{
+ Serializer: ser,
+ }
+
+ // Get normal directory contents from ramfs dir.
+ names, m := rpf.iops.Dir.Children()
+
+ // Add dot and dotdot.
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dot, dotdot := file.Dirent.GetDotAttrs(root)
+ names = append(names, ".", "..")
+ m["."] = dot
+ m[".."] = dotdot
+
+ // Collect tasks.
+ // Per linux we only include it in directory listings if it's the leader.
+ // But for whatever crazy reason, you can still walk to the given node.
+ for _, tg := range rpf.iops.pidns.ThreadGroups() {
+ if leader := tg.Leader(); leader != nil {
+ name := strconv.FormatUint(uint64(tg.ID()), 10)
+ m[name] = fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice)
+ names = append(names, name)
+ }
+ }
+
+ if offset >= int64(len(m)) {
+ return offset, nil
+ }
+ sort.Strings(names)
+ names = names[offset:]
+ for _, name := range names {
+ if err := dirCtx.DirEmit(name, m[name]); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
diff --git a/pkg/sentry/fs/proc/proc_state_autogen.go b/pkg/sentry/fs/proc/proc_state_autogen.go
new file mode 100755
index 000000000..788606f21
--- /dev/null
+++ b/pkg/sentry/fs/proc/proc_state_autogen.go
@@ -0,0 +1,657 @@
+// automatically generated by stateify.
+
+package proc
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *execArgInode) beforeSave() {}
+func (x *execArgInode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SimpleFileInode", &x.SimpleFileInode)
+ m.Save("arg", &x.arg)
+ m.Save("t", &x.t)
+}
+
+func (x *execArgInode) afterLoad() {}
+func (x *execArgInode) load(m state.Map) {
+ m.Load("SimpleFileInode", &x.SimpleFileInode)
+ m.Load("arg", &x.arg)
+ m.Load("t", &x.t)
+}
+
+func (x *execArgFile) beforeSave() {}
+func (x *execArgFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("arg", &x.arg)
+ m.Save("t", &x.t)
+}
+
+func (x *execArgFile) afterLoad() {}
+func (x *execArgFile) load(m state.Map) {
+ m.Load("arg", &x.arg)
+ m.Load("t", &x.t)
+}
+
+func (x *fdDir) beforeSave() {}
+func (x *fdDir) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Dir", &x.Dir)
+ m.Save("t", &x.t)
+}
+
+func (x *fdDir) afterLoad() {}
+func (x *fdDir) load(m state.Map) {
+ m.Load("Dir", &x.Dir)
+ m.Load("t", &x.t)
+}
+
+func (x *fdDirFile) beforeSave() {}
+func (x *fdDirFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("isInfoFile", &x.isInfoFile)
+ m.Save("t", &x.t)
+}
+
+func (x *fdDirFile) afterLoad() {}
+func (x *fdDirFile) load(m state.Map) {
+ m.Load("isInfoFile", &x.isInfoFile)
+ m.Load("t", &x.t)
+}
+
+func (x *fdInfoDir) beforeSave() {}
+func (x *fdInfoDir) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Dir", &x.Dir)
+ m.Save("t", &x.t)
+}
+
+func (x *fdInfoDir) afterLoad() {}
+func (x *fdInfoDir) load(m state.Map) {
+ m.Load("Dir", &x.Dir)
+ m.Load("t", &x.t)
+}
+
+func (x *filesystemsData) beforeSave() {}
+func (x *filesystemsData) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *filesystemsData) afterLoad() {}
+func (x *filesystemsData) load(m state.Map) {
+}
+
+func (x *filesystem) beforeSave() {}
+func (x *filesystem) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *filesystem) afterLoad() {}
+func (x *filesystem) load(m state.Map) {
+}
+
+func (x *taskOwnedInodeOps) beforeSave() {}
+func (x *taskOwnedInodeOps) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeOperations", &x.InodeOperations)
+ m.Save("t", &x.t)
+}
+
+func (x *taskOwnedInodeOps) afterLoad() {}
+func (x *taskOwnedInodeOps) load(m state.Map) {
+ m.Load("InodeOperations", &x.InodeOperations)
+ m.Load("t", &x.t)
+}
+
+func (x *staticFileInodeOps) beforeSave() {}
+func (x *staticFileInodeOps) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Save("InodeStaticFileGetter", &x.InodeStaticFileGetter)
+}
+
+func (x *staticFileInodeOps) afterLoad() {}
+func (x *staticFileInodeOps) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Load("InodeStaticFileGetter", &x.InodeStaticFileGetter)
+}
+
+func (x *loadavgData) beforeSave() {}
+func (x *loadavgData) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *loadavgData) afterLoad() {}
+func (x *loadavgData) load(m state.Map) {
+}
+
+func (x *meminfoData) beforeSave() {}
+func (x *meminfoData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("k", &x.k)
+}
+
+func (x *meminfoData) afterLoad() {}
+func (x *meminfoData) load(m state.Map) {
+ m.Load("k", &x.k)
+}
+
+func (x *mountInfoFile) beforeSave() {}
+func (x *mountInfoFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+}
+
+func (x *mountInfoFile) afterLoad() {}
+func (x *mountInfoFile) load(m state.Map) {
+ m.Load("t", &x.t)
+}
+
+func (x *mountsFile) beforeSave() {}
+func (x *mountsFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+}
+
+func (x *mountsFile) afterLoad() {}
+func (x *mountsFile) load(m state.Map) {
+ m.Load("t", &x.t)
+}
+
+func (x *ifinet6) beforeSave() {}
+func (x *ifinet6) save(m state.Map) {
+ x.beforeSave()
+ m.Save("s", &x.s)
+}
+
+func (x *ifinet6) afterLoad() {}
+func (x *ifinet6) load(m state.Map) {
+ m.Load("s", &x.s)
+}
+
+func (x *netDev) beforeSave() {}
+func (x *netDev) save(m state.Map) {
+ x.beforeSave()
+ m.Save("s", &x.s)
+}
+
+func (x *netDev) afterLoad() {}
+func (x *netDev) load(m state.Map) {
+ m.Load("s", &x.s)
+}
+
+func (x *netUnix) beforeSave() {}
+func (x *netUnix) save(m state.Map) {
+ x.beforeSave()
+ m.Save("k", &x.k)
+}
+
+func (x *netUnix) afterLoad() {}
+func (x *netUnix) load(m state.Map) {
+ m.Load("k", &x.k)
+}
+
+func (x *proc) beforeSave() {}
+func (x *proc) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Dir", &x.Dir)
+ m.Save("k", &x.k)
+ m.Save("pidns", &x.pidns)
+ m.Save("cgroupControllers", &x.cgroupControllers)
+}
+
+func (x *proc) afterLoad() {}
+func (x *proc) load(m state.Map) {
+ m.Load("Dir", &x.Dir)
+ m.Load("k", &x.k)
+ m.Load("pidns", &x.pidns)
+ m.Load("cgroupControllers", &x.cgroupControllers)
+}
+
+func (x *self) beforeSave() {}
+func (x *self) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Symlink", &x.Symlink)
+ m.Save("pidns", &x.pidns)
+}
+
+func (x *self) afterLoad() {}
+func (x *self) load(m state.Map) {
+ m.Load("Symlink", &x.Symlink)
+ m.Load("pidns", &x.pidns)
+}
+
+func (x *threadSelf) beforeSave() {}
+func (x *threadSelf) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Symlink", &x.Symlink)
+ m.Save("pidns", &x.pidns)
+}
+
+func (x *threadSelf) afterLoad() {}
+func (x *threadSelf) load(m state.Map) {
+ m.Load("Symlink", &x.Symlink)
+ m.Load("pidns", &x.pidns)
+}
+
+func (x *rootProcFile) beforeSave() {}
+func (x *rootProcFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("iops", &x.iops)
+}
+
+func (x *rootProcFile) afterLoad() {}
+func (x *rootProcFile) load(m state.Map) {
+ m.Load("iops", &x.iops)
+}
+
+func (x *statData) beforeSave() {}
+func (x *statData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("k", &x.k)
+}
+
+func (x *statData) afterLoad() {}
+func (x *statData) load(m state.Map) {
+ m.Load("k", &x.k)
+}
+
+func (x *mmapMinAddrData) beforeSave() {}
+func (x *mmapMinAddrData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("k", &x.k)
+}
+
+func (x *mmapMinAddrData) afterLoad() {}
+func (x *mmapMinAddrData) load(m state.Map) {
+ m.Load("k", &x.k)
+}
+
+func (x *overcommitMemory) beforeSave() {}
+func (x *overcommitMemory) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *overcommitMemory) afterLoad() {}
+func (x *overcommitMemory) load(m state.Map) {
+}
+
+func (x *hostname) beforeSave() {}
+func (x *hostname) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SimpleFileInode", &x.SimpleFileInode)
+}
+
+func (x *hostname) afterLoad() {}
+func (x *hostname) load(m state.Map) {
+ m.Load("SimpleFileInode", &x.SimpleFileInode)
+}
+
+func (x *hostnameFile) beforeSave() {}
+func (x *hostnameFile) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *hostnameFile) afterLoad() {}
+func (x *hostnameFile) load(m state.Map) {
+}
+
+func (x *tcpMemInode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SimpleFileInode", &x.SimpleFileInode)
+ m.Save("dir", &x.dir)
+ m.Save("s", &x.s)
+ m.Save("size", &x.size)
+}
+
+func (x *tcpMemInode) load(m state.Map) {
+ m.Load("SimpleFileInode", &x.SimpleFileInode)
+ m.Load("dir", &x.dir)
+ m.LoadWait("s", &x.s)
+ m.Load("size", &x.size)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *tcpMemFile) beforeSave() {}
+func (x *tcpMemFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("tcpMemInode", &x.tcpMemInode)
+}
+
+func (x *tcpMemFile) afterLoad() {}
+func (x *tcpMemFile) load(m state.Map) {
+ m.Load("tcpMemInode", &x.tcpMemInode)
+}
+
+func (x *tcpSack) beforeSave() {}
+func (x *tcpSack) save(m state.Map) {
+ x.beforeSave()
+ m.Save("stack", &x.stack)
+ m.Save("enabled", &x.enabled)
+ m.Save("SimpleFileInode", &x.SimpleFileInode)
+}
+
+func (x *tcpSack) load(m state.Map) {
+ m.LoadWait("stack", &x.stack)
+ m.Load("enabled", &x.enabled)
+ m.Load("SimpleFileInode", &x.SimpleFileInode)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *tcpSackFile) beforeSave() {}
+func (x *tcpSackFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("tcpSack", &x.tcpSack)
+ m.Save("stack", &x.stack)
+}
+
+func (x *tcpSackFile) afterLoad() {}
+func (x *tcpSackFile) load(m state.Map) {
+ m.Load("tcpSack", &x.tcpSack)
+ m.LoadWait("stack", &x.stack)
+}
+
+func (x *taskDir) beforeSave() {}
+func (x *taskDir) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Dir", &x.Dir)
+ m.Save("t", &x.t)
+ m.Save("pidns", &x.pidns)
+}
+
+func (x *taskDir) afterLoad() {}
+func (x *taskDir) load(m state.Map) {
+ m.Load("Dir", &x.Dir)
+ m.Load("t", &x.t)
+ m.Load("pidns", &x.pidns)
+}
+
+func (x *subtasks) beforeSave() {}
+func (x *subtasks) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Dir", &x.Dir)
+ m.Save("t", &x.t)
+ m.Save("p", &x.p)
+}
+
+func (x *subtasks) afterLoad() {}
+func (x *subtasks) load(m state.Map) {
+ m.Load("Dir", &x.Dir)
+ m.Load("t", &x.t)
+ m.Load("p", &x.p)
+}
+
+func (x *subtasksFile) beforeSave() {}
+func (x *subtasksFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+ m.Save("pidns", &x.pidns)
+}
+
+func (x *subtasksFile) afterLoad() {}
+func (x *subtasksFile) load(m state.Map) {
+ m.Load("t", &x.t)
+ m.Load("pidns", &x.pidns)
+}
+
+func (x *exe) beforeSave() {}
+func (x *exe) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Symlink", &x.Symlink)
+ m.Save("t", &x.t)
+}
+
+func (x *exe) afterLoad() {}
+func (x *exe) load(m state.Map) {
+ m.Load("Symlink", &x.Symlink)
+ m.Load("t", &x.t)
+}
+
+func (x *namespaceSymlink) beforeSave() {}
+func (x *namespaceSymlink) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Symlink", &x.Symlink)
+ m.Save("t", &x.t)
+}
+
+func (x *namespaceSymlink) afterLoad() {}
+func (x *namespaceSymlink) load(m state.Map) {
+ m.Load("Symlink", &x.Symlink)
+ m.Load("t", &x.t)
+}
+
+func (x *mapsData) beforeSave() {}
+func (x *mapsData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+}
+
+func (x *mapsData) afterLoad() {}
+func (x *mapsData) load(m state.Map) {
+ m.Load("t", &x.t)
+}
+
+func (x *smapsData) beforeSave() {}
+func (x *smapsData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+}
+
+func (x *smapsData) afterLoad() {}
+func (x *smapsData) load(m state.Map) {
+ m.Load("t", &x.t)
+}
+
+func (x *taskStatData) beforeSave() {}
+func (x *taskStatData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+ m.Save("tgstats", &x.tgstats)
+ m.Save("pidns", &x.pidns)
+}
+
+func (x *taskStatData) afterLoad() {}
+func (x *taskStatData) load(m state.Map) {
+ m.Load("t", &x.t)
+ m.Load("tgstats", &x.tgstats)
+ m.Load("pidns", &x.pidns)
+}
+
+func (x *statmData) beforeSave() {}
+func (x *statmData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+}
+
+func (x *statmData) afterLoad() {}
+func (x *statmData) load(m state.Map) {
+ m.Load("t", &x.t)
+}
+
+func (x *statusData) beforeSave() {}
+func (x *statusData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+ m.Save("pidns", &x.pidns)
+}
+
+func (x *statusData) afterLoad() {}
+func (x *statusData) load(m state.Map) {
+ m.Load("t", &x.t)
+ m.Load("pidns", &x.pidns)
+}
+
+func (x *ioData) beforeSave() {}
+func (x *ioData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ioUsage", &x.ioUsage)
+}
+
+func (x *ioData) afterLoad() {}
+func (x *ioData) load(m state.Map) {
+ m.Load("ioUsage", &x.ioUsage)
+}
+
+func (x *comm) beforeSave() {}
+func (x *comm) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SimpleFileInode", &x.SimpleFileInode)
+ m.Save("t", &x.t)
+}
+
+func (x *comm) afterLoad() {}
+func (x *comm) load(m state.Map) {
+ m.Load("SimpleFileInode", &x.SimpleFileInode)
+ m.Load("t", &x.t)
+}
+
+func (x *commFile) beforeSave() {}
+func (x *commFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+}
+
+func (x *commFile) afterLoad() {}
+func (x *commFile) load(m state.Map) {
+ m.Load("t", &x.t)
+}
+
+func (x *auxvec) beforeSave() {}
+func (x *auxvec) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SimpleFileInode", &x.SimpleFileInode)
+ m.Save("t", &x.t)
+}
+
+func (x *auxvec) afterLoad() {}
+func (x *auxvec) load(m state.Map) {
+ m.Load("SimpleFileInode", &x.SimpleFileInode)
+ m.Load("t", &x.t)
+}
+
+func (x *auxvecFile) beforeSave() {}
+func (x *auxvecFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+}
+
+func (x *auxvecFile) afterLoad() {}
+func (x *auxvecFile) load(m state.Map) {
+ m.Load("t", &x.t)
+}
+
+func (x *idMapInodeOperations) beforeSave() {}
+func (x *idMapInodeOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Save("t", &x.t)
+ m.Save("gids", &x.gids)
+}
+
+func (x *idMapInodeOperations) afterLoad() {}
+func (x *idMapInodeOperations) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Load("t", &x.t)
+ m.Load("gids", &x.gids)
+}
+
+func (x *idMapFileOperations) beforeSave() {}
+func (x *idMapFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("iops", &x.iops)
+}
+
+func (x *idMapFileOperations) afterLoad() {}
+func (x *idMapFileOperations) load(m state.Map) {
+ m.Load("iops", &x.iops)
+}
+
+func (x *uptime) beforeSave() {}
+func (x *uptime) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SimpleFileInode", &x.SimpleFileInode)
+ m.Save("startTime", &x.startTime)
+}
+
+func (x *uptime) afterLoad() {}
+func (x *uptime) load(m state.Map) {
+ m.Load("SimpleFileInode", &x.SimpleFileInode)
+ m.Load("startTime", &x.startTime)
+}
+
+func (x *uptimeFile) beforeSave() {}
+func (x *uptimeFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("startTime", &x.startTime)
+}
+
+func (x *uptimeFile) afterLoad() {}
+func (x *uptimeFile) load(m state.Map) {
+ m.Load("startTime", &x.startTime)
+}
+
+func (x *versionData) beforeSave() {}
+func (x *versionData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("k", &x.k)
+}
+
+func (x *versionData) afterLoad() {}
+func (x *versionData) load(m state.Map) {
+ m.Load("k", &x.k)
+}
+
+func init() {
+ state.Register("proc.execArgInode", (*execArgInode)(nil), state.Fns{Save: (*execArgInode).save, Load: (*execArgInode).load})
+ state.Register("proc.execArgFile", (*execArgFile)(nil), state.Fns{Save: (*execArgFile).save, Load: (*execArgFile).load})
+ state.Register("proc.fdDir", (*fdDir)(nil), state.Fns{Save: (*fdDir).save, Load: (*fdDir).load})
+ state.Register("proc.fdDirFile", (*fdDirFile)(nil), state.Fns{Save: (*fdDirFile).save, Load: (*fdDirFile).load})
+ state.Register("proc.fdInfoDir", (*fdInfoDir)(nil), state.Fns{Save: (*fdInfoDir).save, Load: (*fdInfoDir).load})
+ state.Register("proc.filesystemsData", (*filesystemsData)(nil), state.Fns{Save: (*filesystemsData).save, Load: (*filesystemsData).load})
+ state.Register("proc.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load})
+ state.Register("proc.taskOwnedInodeOps", (*taskOwnedInodeOps)(nil), state.Fns{Save: (*taskOwnedInodeOps).save, Load: (*taskOwnedInodeOps).load})
+ state.Register("proc.staticFileInodeOps", (*staticFileInodeOps)(nil), state.Fns{Save: (*staticFileInodeOps).save, Load: (*staticFileInodeOps).load})
+ state.Register("proc.loadavgData", (*loadavgData)(nil), state.Fns{Save: (*loadavgData).save, Load: (*loadavgData).load})
+ state.Register("proc.meminfoData", (*meminfoData)(nil), state.Fns{Save: (*meminfoData).save, Load: (*meminfoData).load})
+ state.Register("proc.mountInfoFile", (*mountInfoFile)(nil), state.Fns{Save: (*mountInfoFile).save, Load: (*mountInfoFile).load})
+ state.Register("proc.mountsFile", (*mountsFile)(nil), state.Fns{Save: (*mountsFile).save, Load: (*mountsFile).load})
+ state.Register("proc.ifinet6", (*ifinet6)(nil), state.Fns{Save: (*ifinet6).save, Load: (*ifinet6).load})
+ state.Register("proc.netDev", (*netDev)(nil), state.Fns{Save: (*netDev).save, Load: (*netDev).load})
+ state.Register("proc.netUnix", (*netUnix)(nil), state.Fns{Save: (*netUnix).save, Load: (*netUnix).load})
+ state.Register("proc.proc", (*proc)(nil), state.Fns{Save: (*proc).save, Load: (*proc).load})
+ state.Register("proc.self", (*self)(nil), state.Fns{Save: (*self).save, Load: (*self).load})
+ state.Register("proc.threadSelf", (*threadSelf)(nil), state.Fns{Save: (*threadSelf).save, Load: (*threadSelf).load})
+ state.Register("proc.rootProcFile", (*rootProcFile)(nil), state.Fns{Save: (*rootProcFile).save, Load: (*rootProcFile).load})
+ state.Register("proc.statData", (*statData)(nil), state.Fns{Save: (*statData).save, Load: (*statData).load})
+ state.Register("proc.mmapMinAddrData", (*mmapMinAddrData)(nil), state.Fns{Save: (*mmapMinAddrData).save, Load: (*mmapMinAddrData).load})
+ state.Register("proc.overcommitMemory", (*overcommitMemory)(nil), state.Fns{Save: (*overcommitMemory).save, Load: (*overcommitMemory).load})
+ state.Register("proc.hostname", (*hostname)(nil), state.Fns{Save: (*hostname).save, Load: (*hostname).load})
+ state.Register("proc.hostnameFile", (*hostnameFile)(nil), state.Fns{Save: (*hostnameFile).save, Load: (*hostnameFile).load})
+ state.Register("proc.tcpMemInode", (*tcpMemInode)(nil), state.Fns{Save: (*tcpMemInode).save, Load: (*tcpMemInode).load})
+ state.Register("proc.tcpMemFile", (*tcpMemFile)(nil), state.Fns{Save: (*tcpMemFile).save, Load: (*tcpMemFile).load})
+ state.Register("proc.tcpSack", (*tcpSack)(nil), state.Fns{Save: (*tcpSack).save, Load: (*tcpSack).load})
+ state.Register("proc.tcpSackFile", (*tcpSackFile)(nil), state.Fns{Save: (*tcpSackFile).save, Load: (*tcpSackFile).load})
+ state.Register("proc.taskDir", (*taskDir)(nil), state.Fns{Save: (*taskDir).save, Load: (*taskDir).load})
+ state.Register("proc.subtasks", (*subtasks)(nil), state.Fns{Save: (*subtasks).save, Load: (*subtasks).load})
+ state.Register("proc.subtasksFile", (*subtasksFile)(nil), state.Fns{Save: (*subtasksFile).save, Load: (*subtasksFile).load})
+ state.Register("proc.exe", (*exe)(nil), state.Fns{Save: (*exe).save, Load: (*exe).load})
+ state.Register("proc.namespaceSymlink", (*namespaceSymlink)(nil), state.Fns{Save: (*namespaceSymlink).save, Load: (*namespaceSymlink).load})
+ state.Register("proc.mapsData", (*mapsData)(nil), state.Fns{Save: (*mapsData).save, Load: (*mapsData).load})
+ state.Register("proc.smapsData", (*smapsData)(nil), state.Fns{Save: (*smapsData).save, Load: (*smapsData).load})
+ state.Register("proc.taskStatData", (*taskStatData)(nil), state.Fns{Save: (*taskStatData).save, Load: (*taskStatData).load})
+ state.Register("proc.statmData", (*statmData)(nil), state.Fns{Save: (*statmData).save, Load: (*statmData).load})
+ state.Register("proc.statusData", (*statusData)(nil), state.Fns{Save: (*statusData).save, Load: (*statusData).load})
+ state.Register("proc.ioData", (*ioData)(nil), state.Fns{Save: (*ioData).save, Load: (*ioData).load})
+ state.Register("proc.comm", (*comm)(nil), state.Fns{Save: (*comm).save, Load: (*comm).load})
+ state.Register("proc.commFile", (*commFile)(nil), state.Fns{Save: (*commFile).save, Load: (*commFile).load})
+ state.Register("proc.auxvec", (*auxvec)(nil), state.Fns{Save: (*auxvec).save, Load: (*auxvec).load})
+ state.Register("proc.auxvecFile", (*auxvecFile)(nil), state.Fns{Save: (*auxvecFile).save, Load: (*auxvecFile).load})
+ state.Register("proc.idMapInodeOperations", (*idMapInodeOperations)(nil), state.Fns{Save: (*idMapInodeOperations).save, Load: (*idMapInodeOperations).load})
+ state.Register("proc.idMapFileOperations", (*idMapFileOperations)(nil), state.Fns{Save: (*idMapFileOperations).save, Load: (*idMapFileOperations).load})
+ state.Register("proc.uptime", (*uptime)(nil), state.Fns{Save: (*uptime).save, Load: (*uptime).load})
+ state.Register("proc.uptimeFile", (*uptimeFile)(nil), state.Fns{Save: (*uptimeFile).save, Load: (*uptimeFile).load})
+ state.Register("proc.versionData", (*versionData)(nil), state.Fns{Save: (*versionData).save, Load: (*versionData).load})
+}
diff --git a/pkg/sentry/fs/proc/rpcinet_proc.go b/pkg/sentry/fs/proc/rpcinet_proc.go
new file mode 100644
index 000000000..e36c0bfa6
--- /dev/null
+++ b/pkg/sentry/fs/proc/rpcinet_proc.go
@@ -0,0 +1,217 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// rpcInetInode implments fs.InodeOperations.
+type rpcInetInode struct {
+ fsutil.SimpleFileInode
+
+ // filepath is the full path of this rpcInetInode.
+ filepath string
+
+ k *kernel.Kernel
+}
+
+func newRPCInetInode(ctx context.Context, msrc *fs.MountSource, filepath string, mode linux.FileMode) *fs.Inode {
+ f := &rpcInetInode{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(mode), linux.PROC_SUPER_MAGIC),
+ filepath: filepath,
+ k: kernel.KernelFromContext(ctx),
+ }
+ return newProcInode(f, msrc, fs.SpecialFile, nil)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (i *rpcInetInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+ fops := &rpcInetFile{
+ inode: i,
+ }
+ return fs.NewFile(ctx, dirent, flags, fops), nil
+}
+
+// rpcInetFile implements fs.FileOperations as RPCs.
+type rpcInetFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ inode *rpcInetInode
+}
+
+// Read implements fs.FileOperations.Read.
+//
+// This method can panic if an rpcInetInode was created without an rpcinet
+// stack.
+func (f *rpcInetFile) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ s, ok := f.inode.k.NetworkStack().(*rpcinet.Stack)
+ if !ok {
+ panic("Network stack is not a rpcinet.")
+ }
+
+ contents, se := s.RPCReadFile(f.inode.filepath)
+ if se != nil || offset >= int64(len(contents)) {
+ return 0, io.EOF
+ }
+
+ n, err := dst.CopyOut(ctx, contents[offset:])
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+//
+// This method can panic if an rpcInetInode was created without an rpcInet
+// stack.
+func (f *rpcInetFile) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ s, ok := f.inode.k.NetworkStack().(*rpcinet.Stack)
+ if !ok {
+ panic("Network stack is not a rpcinet.")
+ }
+
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ b := make([]byte, src.NumBytes(), src.NumBytes())
+ n, err := src.CopyIn(ctx, b)
+ if err != nil {
+ return int64(n), err
+ }
+
+ written, se := s.RPCWriteFile(f.inode.filepath, b)
+ return int64(written), se.ToError()
+}
+
+// newRPCInetProcNet will build an inode for /proc/net.
+func newRPCInetProcNet(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "arp": newRPCInetInode(ctx, msrc, "/proc/net/arp", 0444),
+ "dev": newRPCInetInode(ctx, msrc, "/proc/net/dev", 0444),
+ "if_inet6": newRPCInetInode(ctx, msrc, "/proc/net/if_inet6", 0444),
+ "ipv6_route": newRPCInetInode(ctx, msrc, "/proc/net/ipv6_route", 0444),
+ "netlink": newRPCInetInode(ctx, msrc, "/proc/net/netlink", 0444),
+ "netstat": newRPCInetInode(ctx, msrc, "/proc/net/netstat", 0444),
+ "packet": newRPCInetInode(ctx, msrc, "/proc/net/packet", 0444),
+ "protocols": newRPCInetInode(ctx, msrc, "/proc/net/protocols", 0444),
+ "psched": newRPCInetInode(ctx, msrc, "/proc/net/psched", 0444),
+ "ptype": newRPCInetInode(ctx, msrc, "/proc/net/ptype", 0444),
+ "route": newRPCInetInode(ctx, msrc, "/proc/net/route", 0444),
+ "tcp": newRPCInetInode(ctx, msrc, "/proc/net/tcp", 0444),
+ "tcp6": newRPCInetInode(ctx, msrc, "/proc/net/tcp6", 0444),
+ "udp": newRPCInetInode(ctx, msrc, "/proc/net/udp", 0444),
+ "udp6": newRPCInetInode(ctx, msrc, "/proc/net/udp6", 0444),
+ }
+
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
+
+// newRPCInetProcSysNet will build an inode for /proc/sys/net.
+func newRPCInetProcSysNet(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "ipv4": newRPCInetSysNetIPv4Dir(ctx, msrc),
+ "core": newRPCInetSysNetCore(ctx, msrc),
+ }
+
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
+
+// newRPCInetSysNetCore builds the /proc/sys/net/core directory.
+func newRPCInetSysNetCore(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "default_qdisc": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/default_qdisc", 0444),
+ "message_burst": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/message_burst", 0444),
+ "message_cost": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/message_cost", 0444),
+ "optmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/optmem_max", 0444),
+ "rmem_default": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/rmem_default", 0444),
+ "rmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/rmem_max", 0444),
+ "somaxconn": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/somaxconn", 0444),
+ "wmem_default": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/wmem_default", 0444),
+ "wmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/wmem_max", 0444),
+ }
+
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
+
+// newRPCInetSysNetIPv4Dir builds the /proc/sys/net/ipv4 directory.
+func newRPCInetSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "ip_local_port_range": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_local_port_range", 0444),
+ "ip_local_reserved_ports": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_local_reserved_ports", 0444),
+ "ipfrag_time": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ipfrag_time", 0444),
+ "ip_nonlocal_bind": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_nonlocal_bind", 0444),
+ "ip_no_pmtu_disc": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_no_pmtu_disc", 0444),
+ "tcp_allowed_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_allowed_congestion_control", 0444),
+ "tcp_available_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_available_congestion_control", 0444),
+ "tcp_base_mss": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_base_mss", 0444),
+ "tcp_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_congestion_control", 0644),
+ "tcp_dsack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_dsack", 0644),
+ "tcp_early_retrans": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_early_retrans", 0644),
+ "tcp_fack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fack", 0644),
+ "tcp_fastopen": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fastopen", 0644),
+ "tcp_fastopen_key": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fastopen_key", 0444),
+ "tcp_fin_timeout": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fin_timeout", 0644),
+ "tcp_invalid_ratelimit": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_invalid_ratelimit", 0444),
+ "tcp_keepalive_intvl": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_intvl", 0644),
+ "tcp_keepalive_probes": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_probes", 0644),
+ "tcp_keepalive_time": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_time", 0644),
+ "tcp_mem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_mem", 0444),
+ "tcp_mtu_probing": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_mtu_probing", 0644),
+ "tcp_no_metrics_save": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_no_metrics_save", 0444),
+ "tcp_probe_interval": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_probe_interval", 0444),
+ "tcp_probe_threshold": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_probe_threshold", 0444),
+ "tcp_retries1": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_retries1", 0644),
+ "tcp_retries2": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_retries2", 0644),
+ "tcp_rfc1337": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_rfc1337", 0444),
+ "tcp_rmem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_rmem", 0444),
+ "tcp_sack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_sack", 0644),
+ "tcp_slow_start_after_idle": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_slow_start_after_idle", 0644),
+ "tcp_synack_retries": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_synack_retries", 0644),
+ "tcp_syn_retries": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_syn_retries", 0644),
+ "tcp_timestamps": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_timestamps", 0644),
+ "tcp_wmem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_wmem", 0444),
+ }
+
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go
new file mode 100644
index 000000000..8364d86ed
--- /dev/null
+++ b/pkg/sentry/fs/proc/seqfile/seqfile.go
@@ -0,0 +1,282 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seqfile
+
+import (
+ "io"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// SeqHandle is a helper handle to seek in the file.
+type SeqHandle interface{}
+
+// SeqData holds the data for one unit in the file.
+//
+// +stateify savable
+type SeqData struct {
+ // The data to be returned to the user.
+ Buf []byte
+
+ // A seek handle used to find the next valid unit in ReadSeqFiledata.
+ Handle SeqHandle
+}
+
+// SeqSource is a data source for a SeqFile file.
+type SeqSource interface {
+ // NeedsUpdate returns true if the consumer of SeqData should call
+ // ReadSeqFileData again. Generation is the generation returned by
+ // ReadSeqFile or 0.
+ NeedsUpdate(generation int64) bool
+
+ // Returns a slice of SeqData ordered by unit and the current
+ // generation. The first entry in the slice is greater than the handle.
+ // If handle is nil then all known records are returned. Generation
+ // must always be greater than 0.
+ ReadSeqFileData(ctx context.Context, handle SeqHandle) ([]SeqData, int64)
+}
+
+// SeqGenerationCounter is a counter to keep track if the SeqSource should be
+// updated. SeqGenerationCounter is not thread-safe and should be protected
+// with a mutex.
+type SeqGenerationCounter struct {
+ // The generation that the SeqData is at.
+ generation int64
+}
+
+// SetGeneration sets the generation to the new value, be careful to not set it
+// to a value less than current.
+func (s *SeqGenerationCounter) SetGeneration(generation int64) {
+ s.generation = generation
+}
+
+// Update increments the current generation.
+func (s *SeqGenerationCounter) Update() {
+ s.generation++
+}
+
+// Generation returns the current generation counter.
+func (s *SeqGenerationCounter) Generation() int64 {
+ return s.generation
+}
+
+// IsCurrent returns whether the given generation is current or not.
+func (s *SeqGenerationCounter) IsCurrent(generation int64) bool {
+ return s.Generation() == generation
+}
+
+// SeqFile is used to provide dynamic files that can be ordered by record.
+//
+// +stateify savable
+type SeqFile struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleExtendedAttributes
+ fsutil.InodeSimpleAttributes
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ SeqSource
+
+ source []SeqData
+ generation int64
+ lastRead int64
+}
+
+var _ fs.InodeOperations = (*SeqFile)(nil)
+
+// NewSeqFile returns a seqfile suitable for use by external consumers.
+func NewSeqFile(ctx context.Context, source SeqSource) *SeqFile {
+ return &SeqFile{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ SeqSource: source,
+ }
+}
+
+// NewSeqFileInode returns an Inode with SeqFile InodeOperations.
+func NewSeqFileInode(ctx context.Context, source SeqSource, msrc *fs.MountSource) *fs.Inode {
+ iops := NewSeqFile(ctx, source)
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(iops, msrc, sattr)
+}
+
+// UnstableAttr returns unstable attributes of the SeqFile.
+func (s *SeqFile) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ uattr, err := s.InodeSimpleAttributes.UnstableAttr(ctx, inode)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ uattr.ModificationTime = ktime.NowFromContext(ctx)
+ return uattr, nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (s *SeqFile) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &seqFileOperations{seqFile: s}), nil
+}
+
+// findIndexAndOffset finds the unit that corresponds to a certain offset.
+// Returns the unit and the offset within the unit. If there are not enough
+// units len(data) and leftover offset is returned.
+func findIndexAndOffset(data []SeqData, offset int64) (int, int64) {
+ for i, buf := range data {
+ l := int64(len(buf.Buf))
+ if offset < l {
+ return i, offset
+ }
+ offset -= l
+ }
+ return len(data), offset
+}
+
+// updateSourceLocked requires that s.mu is held.
+func (s *SeqFile) updateSourceLocked(ctx context.Context, record int) {
+ var h SeqHandle
+ if record == 0 {
+ h = nil
+ } else {
+ h = s.source[record-1].Handle
+ }
+ // Save what we have previously read.
+ s.source = s.source[:record]
+ var newSource []SeqData
+ newSource, s.generation = s.SeqSource.ReadSeqFileData(ctx, h)
+ s.source = append(s.source, newSource...)
+}
+
+// seqFileOperations implements fs.FileOperations.
+//
+// +stateify savable
+type seqFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ seqFile *SeqFile
+}
+
+var _ fs.FileOperations = (*seqFileOperations)(nil)
+
+// Write implements fs.FileOperations.Write.
+func (*seqFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EACCES
+}
+
+// Read implements fs.FileOperations.Read.
+func (sfo *seqFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ sfo.seqFile.mu.Lock()
+ defer sfo.seqFile.mu.Unlock()
+
+ sfo.seqFile.NotifyAccess(ctx)
+ defer func() { sfo.seqFile.lastRead = offset }()
+
+ updated := false
+
+ // Try to find where we should start reading this file.
+ i, recordOffset := findIndexAndOffset(sfo.seqFile.source, offset)
+ if i == len(sfo.seqFile.source) {
+ // Ok, we're at EOF. Let's first check to see if there might be
+ // more data available to us. If there is more data, add it to
+ // the end and try reading again.
+ if !sfo.seqFile.SeqSource.NeedsUpdate(sfo.seqFile.generation) {
+ return 0, io.EOF
+ }
+ oldLen := len(sfo.seqFile.source)
+ sfo.seqFile.updateSourceLocked(ctx, len(sfo.seqFile.source))
+ updated = true
+ // We know that we had consumed everything up until this point
+ // so we search in the new slice instead of starting over.
+ i, recordOffset = findIndexAndOffset(sfo.seqFile.source[oldLen:], recordOffset)
+ i += oldLen
+ // i is at most the length of the slice which is
+ // len(sfo.seqFile.source) - oldLen. So at most i will be equal to
+ // len(sfo.seqFile.source).
+ if i == len(sfo.seqFile.source) {
+ return 0, io.EOF
+ }
+ }
+
+ var done int64
+ // We're reading parts of a record, finish reading the current object
+ // before continuing on to the next. We don't refresh our data source
+ // before this record is completed.
+ if recordOffset != 0 {
+ n, err := dst.CopyOut(ctx, sfo.seqFile.source[i].Buf[recordOffset:])
+ done += int64(n)
+ dst = dst.DropFirst(n)
+ if dst.NumBytes() == 0 || err != nil {
+ return done, err
+ }
+ i++
+ }
+
+ // Next/New unit, update the source file if necessary. Make an extra
+ // check to see if we've seeked backwards and if so always update our
+ // data source.
+ if !updated && (sfo.seqFile.SeqSource.NeedsUpdate(sfo.seqFile.generation) || sfo.seqFile.lastRead > offset) {
+ sfo.seqFile.updateSourceLocked(ctx, i)
+ // recordOffset is 0 here and we won't update records behind the
+ // current one so recordOffset is still 0 even though source
+ // just got updated. Just read the next record.
+ }
+
+ // Finish by reading all the available data.
+ for _, buf := range sfo.seqFile.source[i:] {
+ n, err := dst.CopyOut(ctx, buf.Buf)
+ done += int64(n)
+ dst = dst.DropFirst(n)
+ if dst.NumBytes() == 0 || err != nil {
+ return done, err
+ }
+ }
+
+ // If the file shrank (entries not yet read were removed above)
+ // while we tried to read we can end up with nothing read.
+ if done == 0 && dst.NumBytes() != 0 {
+ return 0, io.EOF
+ }
+ return done, nil
+}
diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go b/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go
new file mode 100755
index 000000000..c3b15d513
--- /dev/null
+++ b/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go
@@ -0,0 +1,58 @@
+// automatically generated by stateify.
+
+package seqfile
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *SeqData) beforeSave() {}
+func (x *SeqData) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Buf", &x.Buf)
+ m.Save("Handle", &x.Handle)
+}
+
+func (x *SeqData) afterLoad() {}
+func (x *SeqData) load(m state.Map) {
+ m.Load("Buf", &x.Buf)
+ m.Load("Handle", &x.Handle)
+}
+
+func (x *SeqFile) beforeSave() {}
+func (x *SeqFile) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Save("SeqSource", &x.SeqSource)
+ m.Save("source", &x.source)
+ m.Save("generation", &x.generation)
+ m.Save("lastRead", &x.lastRead)
+}
+
+func (x *SeqFile) afterLoad() {}
+func (x *SeqFile) load(m state.Map) {
+ m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Load("SeqSource", &x.SeqSource)
+ m.Load("source", &x.source)
+ m.Load("generation", &x.generation)
+ m.Load("lastRead", &x.lastRead)
+}
+
+func (x *seqFileOperations) beforeSave() {}
+func (x *seqFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("seqFile", &x.seqFile)
+}
+
+func (x *seqFileOperations) afterLoad() {}
+func (x *seqFileOperations) load(m state.Map) {
+ m.Load("seqFile", &x.seqFile)
+}
+
+func init() {
+ state.Register("seqfile.SeqData", (*SeqData)(nil), state.Fns{Save: (*SeqData).save, Load: (*SeqData).load})
+ state.Register("seqfile.SeqFile", (*SeqFile)(nil), state.Fns{Save: (*SeqFile).save, Load: (*SeqFile).load})
+ state.Register("seqfile.seqFileOperations", (*seqFileOperations)(nil), state.Fns{Save: (*seqFileOperations).save, Load: (*seqFileOperations).load})
+}
diff --git a/pkg/sentry/fs/proc/stat.go b/pkg/sentry/fs/proc/stat.go
new file mode 100644
index 000000000..397f9ec6b
--- /dev/null
+++ b/pkg/sentry/fs/proc/stat.go
@@ -0,0 +1,142 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+)
+
+// statData backs /proc/stat.
+//
+// +stateify savable
+type statData struct {
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*statData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// cpuStats contains the breakdown of CPU time for /proc/stat.
+type cpuStats struct {
+ // user is time spent in userspace tasks with non-positive niceness.
+ user uint64
+
+ // nice is time spent in userspace tasks with positive niceness.
+ nice uint64
+
+ // system is time spent in non-interrupt kernel context.
+ system uint64
+
+ // idle is time spent idle.
+ idle uint64
+
+ // ioWait is time spent waiting for IO.
+ ioWait uint64
+
+ // irq is time spent in interrupt context.
+ irq uint64
+
+ // softirq is time spent in software interrupt context.
+ softirq uint64
+
+ // steal is involuntary wait time.
+ steal uint64
+
+ // guest is time spent in guests with non-positive niceness.
+ guest uint64
+
+ // guestNice is time spent in guests with positive niceness.
+ guestNice uint64
+}
+
+// String implements fmt.Stringer.
+func (c cpuStats) String() string {
+ return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (s *statData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+
+ // TODO(b/37226836): We currently export only zero CPU stats. We could
+ // at least provide some aggregate stats.
+ var cpu cpuStats
+ fmt.Fprintf(&buf, "cpu %s\n", cpu)
+
+ for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
+ fmt.Fprintf(&buf, "cpu%d %s\n", c, cpu)
+ }
+
+ // The total number of interrupts is dependent on the CPUs and PCI
+ // devices on the system. See arch_probe_nr_irqs.
+ //
+ // Since we don't report real interrupt stats, just choose an arbitrary
+ // value from a representative VM.
+ const numInterrupts = 256
+
+ // The Kernel doesn't handle real interrupts, so report all zeroes.
+ // TODO(b/37226836): We could count page faults as #PF.
+ fmt.Fprintf(&buf, "intr 0") // total
+ for i := 0; i < numInterrupts; i++ {
+ fmt.Fprintf(&buf, " 0")
+ }
+ fmt.Fprintf(&buf, "\n")
+
+ // Total number of context switches.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(&buf, "ctxt 0\n")
+
+ // CLOCK_REALTIME timestamp from boot, in seconds.
+ fmt.Fprintf(&buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
+
+ // Total number of clones.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(&buf, "processes 0\n")
+
+ // Number of runnable tasks.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(&buf, "procs_running 0\n")
+
+ // Number of tasks waiting on IO.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(&buf, "procs_blocked 0\n")
+
+ // Number of each softirq handled.
+ fmt.Fprintf(&buf, "softirq 0") // total
+ for i := 0; i < linux.NumSoftIRQ; i++ {
+ fmt.Fprintf(&buf, " 0")
+ }
+ fmt.Fprintf(&buf, "\n")
+
+ return []seqfile.SeqData{
+ {
+ Buf: buf.Bytes(),
+ Handle: (*statData)(nil),
+ },
+ }, 0
+}
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
new file mode 100644
index 000000000..59846af4f
--- /dev/null
+++ b/pkg/sentry/fs/proc/sys.go
@@ -0,0 +1,162 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "io"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// mmapMinAddrData backs /proc/sys/vm/mmap_min_addr.
+//
+// +stateify savable
+type mmapMinAddrData struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*mmapMinAddrData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (d *mmapMinAddrData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+ return []seqfile.SeqData{
+ {
+ Buf: []byte(fmt.Sprintf("%d\n", d.k.Platform.MinUserAddress())),
+ Handle: (*mmapMinAddrData)(nil),
+ },
+ }, 0
+}
+
+// +stateify savable
+type overcommitMemory struct{}
+
+func (*overcommitMemory) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.
+func (*overcommitMemory) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+ return []seqfile.SeqData{
+ {
+ Buf: []byte("0\n"),
+ Handle: (*overcommitMemory)(nil),
+ },
+ }, 0
+}
+
+func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ h := hostname{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ }
+
+ children := map[string]*fs.Inode{
+ "hostname": newProcInode(&h, msrc, fs.SpecialFile, nil),
+ "shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))),
+ "shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))),
+ "shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))),
+ }
+
+ d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
+
+func (p *proc) newVMDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ children := map[string]*fs.Inode{
+ "mmap_min_addr": seqfile.NewSeqFileInode(ctx, &mmapMinAddrData{p.k}, msrc),
+ "overcommit_memory": seqfile.NewSeqFileInode(ctx, &overcommitMemory{}, msrc),
+ }
+ d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
+
+func (p *proc) newSysDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ children := map[string]*fs.Inode{
+ "kernel": p.newKernelDir(ctx, msrc),
+ "vm": p.newVMDir(ctx, msrc),
+ }
+
+ // If we're using rpcinet we will let it manage /proc/sys/net.
+ if _, ok := p.k.NetworkStack().(*rpcinet.Stack); ok {
+ children["net"] = newRPCInetProcSysNet(ctx, msrc)
+ } else {
+ children["net"] = p.newSysNetDir(ctx, msrc)
+ }
+
+ d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
+
+// hostname is the inode for a file containing the system hostname.
+//
+// +stateify savable
+type hostname struct {
+ fsutil.SimpleFileInode
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (h *hostname) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &hostnameFile{}), nil
+}
+
+var _ fs.InodeOperations = (*hostname)(nil)
+
+// +stateify savable
+type hostnameFile struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+// Read implements fs.FileOperations.Read.
+func (hf *hostnameFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ utsns := kernel.UTSNamespaceFromContext(ctx)
+ contents := []byte(utsns.HostName() + "\n")
+ if offset >= int64(len(contents)) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, contents[offset:])
+ return int64(n), err
+
+}
+
+var _ fs.FileOperations = (*hostnameFile)(nil)
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
new file mode 100644
index 000000000..dbf1a987c
--- /dev/null
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -0,0 +1,355 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "io"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+type tcpMemDir int
+
+const (
+ tcpRMem tcpMemDir = iota
+ tcpWMem
+)
+
+// tcpMemInode is used to read/write the size of netstack tcp buffers.
+//
+// TODO(b/121381035): If we have multiple proc mounts, concurrent writes can
+// leave netstack and the proc files in an inconsistent state. Since we set the
+// buffer size from these proc files on restore, we may also race and end up in
+// an inconsistent state on restore.
+//
+// +stateify savable
+type tcpMemInode struct {
+ fsutil.SimpleFileInode
+ dir tcpMemDir
+ s inet.Stack `state:"wait"`
+
+ // size stores the tcp buffer size during save, and sets the buffer
+ // size in netstack in restore. We must save/restore this here, since
+ // netstack itself is stateless.
+ size inet.TCPBufferSize
+
+ // mu protects against concurrent reads/writes to files based on this
+ // inode.
+ mu sync.Mutex `state:"nosave"`
+}
+
+var _ fs.InodeOperations = (*tcpMemInode)(nil)
+
+func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir tcpMemDir) *fs.Inode {
+ tm := &tcpMemInode{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ s: s,
+ dir: dir,
+ }
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(tm, msrc, sattr)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: m}), nil
+}
+
+// +stateify savable
+type tcpMemFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ tcpMemInode *tcpMemInode
+}
+
+var _ fs.FileOperations = (*tcpMemFile)(nil)
+
+// Read implements fs.FileOperations.Read.
+func (f *tcpMemFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, io.EOF
+ }
+ f.tcpMemInode.mu.Lock()
+ defer f.tcpMemInode.mu.Unlock()
+
+ size, err := readSize(f.tcpMemInode.dir, f.tcpMemInode.s)
+ if err != nil {
+ return 0, err
+ }
+ s := fmt.Sprintf("%d\t%d\t%d\n", size.Min, size.Default, size.Max)
+ n, err := dst.CopyOut(ctx, []byte(s))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *tcpMemFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+ f.tcpMemInode.mu.Lock()
+ defer f.tcpMemInode.mu.Unlock()
+
+ src = src.TakeFirst(usermem.PageSize - 1)
+ size, err := readSize(f.tcpMemInode.dir, f.tcpMemInode.s)
+ if err != nil {
+ return 0, err
+ }
+ buf := []int32{int32(size.Min), int32(size.Default), int32(size.Max)}
+ n, cperr := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, buf, src.Opts)
+ newSize := inet.TCPBufferSize{
+ Min: int(buf[0]),
+ Default: int(buf[1]),
+ Max: int(buf[2]),
+ }
+ if err := writeSize(f.tcpMemInode.dir, f.tcpMemInode.s, newSize); err != nil {
+ return n, err
+ }
+ return n, cperr
+}
+
+func readSize(dirType tcpMemDir, s inet.Stack) (inet.TCPBufferSize, error) {
+ switch dirType {
+ case tcpRMem:
+ return s.TCPReceiveBufferSize()
+ case tcpWMem:
+ return s.TCPSendBufferSize()
+ default:
+ panic(fmt.Sprintf("unknown tcpMemFile type: %v", dirType))
+ }
+}
+
+func writeSize(dirType tcpMemDir, s inet.Stack, size inet.TCPBufferSize) error {
+ switch dirType {
+ case tcpRMem:
+ return s.SetTCPReceiveBufferSize(size)
+ case tcpWMem:
+ return s.SetTCPSendBufferSize(size)
+ default:
+ panic(fmt.Sprintf("unknown tcpMemFile type: %v", dirType))
+ }
+}
+
+// +stateify savable
+type tcpSack struct {
+ stack inet.Stack `state:"wait"`
+ enabled *bool
+ fsutil.SimpleFileInode
+}
+
+func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ ts := &tcpSack{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ stack: s,
+ }
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ts, msrc, sattr)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (s *tcpSack) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+ return fs.NewFile(ctx, dirent, flags, &tcpSackFile{
+ tcpSack: s,
+ stack: s.stack,
+ }), nil
+}
+
+// +stateify savable
+type tcpSackFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ tcpSack *tcpSack
+
+ stack inet.Stack `state:"wait"`
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *tcpSackFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, io.EOF
+ }
+
+ if f.tcpSack.enabled == nil {
+ sack, err := f.stack.TCPSACKEnabled()
+ if err != nil {
+ return 0, err
+ }
+ f.tcpSack.enabled = &sack
+ }
+
+ val := "0\n"
+ if *f.tcpSack.enabled {
+ // Technically, this is not quite compatible with Linux. Linux
+ // stores these as an integer, so if you write "2" into
+ // tcp_sack, you should get 2 back. Tough luck.
+ val = "1\n"
+ }
+ n, err := dst.CopyOut(ctx, []byte(val))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return n, err
+ }
+ if f.tcpSack.enabled == nil {
+ f.tcpSack.enabled = new(bool)
+ }
+ *f.tcpSack.enabled = v != 0
+ return n, f.tcpSack.stack.SetTCPSACKEnabled(*f.tcpSack.enabled)
+}
+
+func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ // The following files are simple stubs until they are implemented in
+ // netstack, most of these files are configuration related. We use the
+ // value closest to the actual netstack behavior or any empty file,
+ // all of these files will have mode 0444 (read-only for all users).
+ contents := map[string]*fs.Inode{
+ "default_qdisc": newStaticProcInode(ctx, msrc, []byte("pfifo_fast")),
+ "message_burst": newStaticProcInode(ctx, msrc, []byte("10")),
+ "message_cost": newStaticProcInode(ctx, msrc, []byte("5")),
+ "optmem_max": newStaticProcInode(ctx, msrc, []byte("0")),
+ "rmem_default": newStaticProcInode(ctx, msrc, []byte("212992")),
+ "rmem_max": newStaticProcInode(ctx, msrc, []byte("212992")),
+ "somaxconn": newStaticProcInode(ctx, msrc, []byte("128")),
+ "wmem_default": newStaticProcInode(ctx, msrc, []byte("212992")),
+ "wmem_max": newStaticProcInode(ctx, msrc, []byte("212992")),
+ }
+
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
+
+func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ // Add tcp_sack.
+ "tcp_sack": newTCPSackInode(ctx, msrc, s),
+
+ // The following files are simple stubs until they are
+ // implemented in netstack, most of these files are
+ // configuration related. We use the value closest to the
+ // actual netstack behavior or any empty file, all of these
+ // files will have mode 0444 (read-only for all users).
+ "ip_local_port_range": newStaticProcInode(ctx, msrc, []byte("16000 65535")),
+ "ip_local_reserved_ports": newStaticProcInode(ctx, msrc, []byte("")),
+ "ipfrag_time": newStaticProcInode(ctx, msrc, []byte("30")),
+ "ip_nonlocal_bind": newStaticProcInode(ctx, msrc, []byte("0")),
+ "ip_no_pmtu_disc": newStaticProcInode(ctx, msrc, []byte("1")),
+
+ // tcp_allowed_congestion_control tell the user what they are
+ // able to do as an unprivledged process so we leave it empty.
+ "tcp_allowed_congestion_control": newStaticProcInode(ctx, msrc, []byte("")),
+ "tcp_available_congestion_control": newStaticProcInode(ctx, msrc, []byte("reno")),
+ "tcp_congestion_control": newStaticProcInode(ctx, msrc, []byte("reno")),
+
+ // Many of the following stub files are features netstack
+ // doesn't support. The unsupported features return "0" to
+ // indicate they are disabled.
+ "tcp_base_mss": newStaticProcInode(ctx, msrc, []byte("1280")),
+ "tcp_dsack": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_early_retrans": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_fack": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_fastopen": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_fastopen_key": newStaticProcInode(ctx, msrc, []byte("")),
+ "tcp_invalid_ratelimit": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_keepalive_intvl": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_keepalive_probes": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_keepalive_time": newStaticProcInode(ctx, msrc, []byte("7200")),
+ "tcp_mtu_probing": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_no_metrics_save": newStaticProcInode(ctx, msrc, []byte("1")),
+ "tcp_probe_interval": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_probe_threshold": newStaticProcInode(ctx, msrc, []byte("0")),
+ "tcp_retries1": newStaticProcInode(ctx, msrc, []byte("3")),
+ "tcp_retries2": newStaticProcInode(ctx, msrc, []byte("15")),
+ "tcp_rfc1337": newStaticProcInode(ctx, msrc, []byte("1")),
+ "tcp_slow_start_after_idle": newStaticProcInode(ctx, msrc, []byte("1")),
+ "tcp_synack_retries": newStaticProcInode(ctx, msrc, []byte("5")),
+ "tcp_syn_retries": newStaticProcInode(ctx, msrc, []byte("3")),
+ "tcp_timestamps": newStaticProcInode(ctx, msrc, []byte("1")),
+ }
+
+ // Add tcp_rmem.
+ if _, err := s.TCPReceiveBufferSize(); err == nil {
+ contents["tcp_rmem"] = newTCPMemInode(ctx, msrc, s, tcpRMem)
+ }
+
+ // Add tcp_wmem.
+ if _, err := s.TCPSendBufferSize(); err == nil {
+ contents["tcp_wmem"] = newTCPMemInode(ctx, msrc, s, tcpWMem)
+ }
+
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
+
+func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ var contents map[string]*fs.Inode
+ if s := p.k.NetworkStack(); s != nil {
+ contents = map[string]*fs.Inode{
+ "ipv4": p.newSysNetIPv4Dir(ctx, msrc, s),
+ "core": p.newSysNetCore(ctx, msrc, s),
+ }
+ }
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(d, msrc, fs.SpecialDirectory, nil)
+}
diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go
new file mode 100644
index 000000000..6eba709c6
--- /dev/null
+++ b/pkg/sentry/fs/proc/sys_net_state.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import "fmt"
+
+// beforeSave is invoked by stateify.
+func (t *tcpMemInode) beforeSave() {
+ size, err := readSize(t.dir, t.s)
+ if err != nil {
+ panic(fmt.Sprintf("failed to read TCP send / receive buffer sizes: %v", err))
+ }
+ t.size = size
+}
+
+// afterLoad is invoked by stateify.
+func (t *tcpMemInode) afterLoad() {
+ if err := writeSize(t.dir, t.s, t.size); err != nil {
+ panic(fmt.Sprintf("failed to write previous TCP send / receive buffer sizes [%v]: %v", t.size, err))
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (s *tcpSack) afterLoad() {
+ if s.enabled != nil {
+ if err := s.stack.SetTCPSACKEnabled(*s.enabled); err != nil {
+ panic(fmt.Sprintf("failed to set previous TCP sack configuration [%v]: %v", *s.enabled, err))
+ }
+ }
+}
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
new file mode 100644
index 000000000..77e03d349
--- /dev/null
+++ b/pkg/sentry/fs/proc/task.go
@@ -0,0 +1,776 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "sort"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// getTaskMM returns t's MemoryManager. If getTaskMM succeeds, the MemoryManager's
+// users count is incremented, and must be decremented by the caller when it is
+// no longer in use.
+func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) {
+ if t.ExitState() == kernel.TaskExitDead {
+ return nil, syserror.ESRCH
+ }
+ var m *mm.MemoryManager
+ t.WithMuLocked(func(t *kernel.Task) {
+ m = t.MemoryManager()
+ })
+ if m == nil || !m.IncUsers() {
+ return nil, io.EOF
+ }
+ return m, nil
+}
+
+// taskDir represents a task-level directory.
+//
+// +stateify savable
+type taskDir struct {
+ ramfs.Dir
+
+ t *kernel.Task
+ pidns *kernel.PIDNamespace
+}
+
+var _ fs.InodeOperations = (*taskDir)(nil)
+
+// newTaskDir creates a new proc task entry.
+func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "auxv": newAuxvec(t, msrc),
+ "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
+ "comm": newComm(t, msrc),
+ "environ": newExecArgInode(t, msrc, environExecArg),
+ "exe": newExe(t, msrc),
+ "fd": newFdDir(t, msrc),
+ "fdinfo": newFdInfoDir(t, msrc),
+ "gid_map": newGIDMap(t, msrc),
+ // FIXME(b/123511468): create the correct io file for threads.
+ "io": newIO(t, msrc),
+ "maps": newMaps(t, msrc),
+ "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
+ "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
+ "ns": newNamespaceDir(t, msrc),
+ "smaps": newSmaps(t, msrc),
+ "stat": newTaskStat(t, msrc, showSubtasks, p.pidns),
+ "statm": newStatm(t, msrc),
+ "status": newStatus(t, msrc, p.pidns),
+ "uid_map": newUIDMap(t, msrc),
+ }
+ if showSubtasks {
+ contents["task"] = p.newSubtasks(t, msrc)
+ }
+ if len(p.cgroupControllers) > 0 {
+ contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers)
+ }
+
+ // TODO(b/31916171): Set EUID/EGID based on dumpability.
+ d := &taskDir{
+ Dir: *ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)),
+ t: t,
+ }
+ return newProcInode(d, msrc, fs.SpecialDirectory, t)
+}
+
+// subtasks represents a /proc/TID/task directory.
+//
+// +stateify savable
+type subtasks struct {
+ ramfs.Dir
+
+ t *kernel.Task
+ p *proc
+}
+
+var _ fs.InodeOperations = (*subtasks)(nil)
+
+func (p *proc) newSubtasks(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ s := &subtasks{
+ Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermsFromMode(0555)),
+ t: t,
+ p: p,
+ }
+ return newProcInode(s, msrc, fs.SpecialDirectory, t)
+}
+
+// UnstableAttr returns unstable attributes of the subtasks.
+func (s *subtasks) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ uattr, err := s.Dir.UnstableAttr(ctx, inode)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ // We can't rely on ramfs' implementation because the task directories are
+ // generated dynamically.
+ uattr.Links = uint64(2 + s.t.ThreadGroup().Count())
+ return uattr, nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (s *subtasks) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &subtasksFile{t: s.t, pidns: s.p.pidns}), nil
+}
+
+// +stateify savable
+type subtasksFile struct {
+ fsutil.DirFileOperations `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ t *kernel.Task
+ pidns *kernel.PIDNamespace
+}
+
+// Readdir implements fs.FileOperations.Readdir.
+func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) {
+ dirCtx := fs.DirCtx{
+ Serializer: ser,
+ }
+
+ // Note that unlike most Readdir implementations, the offset here is
+ // not an index into the subtasks, but rather the TID of the next
+ // subtask to emit.
+ offset := file.Offset()
+
+ if offset == 0 {
+ // Serialize "." and "..".
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dot, dotdot := file.Dirent.GetDotAttrs(root)
+ if err := dirCtx.DirEmit(".", dot); err != nil {
+ return offset, err
+ }
+ if err := dirCtx.DirEmit("..", dotdot); err != nil {
+ return offset, err
+ }
+ }
+
+ // Serialize tasks.
+ tasks := f.t.ThreadGroup().MemberIDs(f.pidns)
+ taskInts := make([]int, 0, len(tasks))
+ for _, tid := range tasks {
+ taskInts = append(taskInts, int(tid))
+ }
+
+ // Find the task to start at.
+ idx := sort.SearchInts(taskInts, int(offset))
+ if idx == len(taskInts) {
+ return offset, nil
+ }
+ taskInts = taskInts[idx:]
+
+ var tid int
+ for _, tid = range taskInts {
+ name := strconv.FormatUint(uint64(tid), 10)
+ attr := fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice)
+ if err := dirCtx.DirEmit(name, attr); err != nil {
+ // Returned offset is next tid to serialize.
+ return int64(tid), err
+ }
+ }
+ // We serialized them all. Next offset should be higher than last
+ // serialized tid.
+ return int64(tid) + 1, nil
+}
+
+var _ fs.FileOperations = (*subtasksFile)(nil)
+
+// Lookup loads an Inode in a task's subtask directory into a Dirent.
+func (s *subtasks) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) {
+ tid, err := strconv.ParseUint(p, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+
+ task := s.p.pidns.TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return nil, syserror.ENOENT
+ }
+ if task.ThreadGroup() != s.t.ThreadGroup() {
+ return nil, syserror.ENOENT
+ }
+
+ td := s.p.newTaskDir(task, dir.MountSource, false)
+ return fs.NewDirent(td, p), nil
+}
+
+// exe is an fs.InodeOperations symlink for the /proc/PID/exe file.
+//
+// +stateify savable
+type exe struct {
+ ramfs.Symlink
+
+ t *kernel.Task
+}
+
+func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ exeSymlink := &exe{
+ Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""),
+ t: t,
+ }
+ return newProcInode(exeSymlink, msrc, fs.Symlink, t)
+}
+
+func (e *exe) executable() (d *fs.Dirent, err error) {
+ e.t.WithMuLocked(func(t *kernel.Task) {
+ mm := t.MemoryManager()
+ if mm == nil {
+ // TODO(b/34851096): Check shouldn't allow Readlink once the
+ // Task is zombied.
+ err = syserror.EACCES
+ return
+ }
+
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ d = mm.Executable()
+ if d == nil {
+ err = syserror.ENOENT
+ }
+ })
+ return
+}
+
+// Readlink implements fs.InodeOperations.
+func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if !kernel.ContextCanTrace(ctx, e.t, false) {
+ return "", syserror.EACCES
+ }
+
+ // Pull out the executable for /proc/TID/exe.
+ exec, err := e.executable()
+ if err != nil {
+ return "", err
+ }
+ defer exec.DecRef()
+
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // This doesn't correspond to anything in Linux because the vfs is
+ // global there.
+ return "", syserror.EINVAL
+ }
+ defer root.DecRef()
+ n, _ := exec.FullName(root)
+ return n, nil
+}
+
+// namespaceSymlink represents a symlink in the namespacefs, such as the files
+// in /proc/<pid>/ns.
+//
+// +stateify savable
+type namespaceSymlink struct {
+ ramfs.Symlink
+
+ t *kernel.Task
+}
+
+func newNamespaceSymlink(t *kernel.Task, msrc *fs.MountSource, name string) *fs.Inode {
+ // TODO(rahat): Namespace symlinks should contain the namespace name and the
+ // inode number for the namespace instance, so for example user:[123456]. We
+ // currently fake the inode number by sticking the symlink inode in its
+ // place.
+ target := fmt.Sprintf("%s:[%d]", name, device.ProcDevice.NextIno())
+ n := &namespaceSymlink{
+ Symlink: *ramfs.NewSymlink(t, fs.RootOwner, target),
+ t: t,
+ }
+ return newProcInode(n, msrc, fs.Symlink, t)
+}
+
+// Getlink implements fs.InodeOperations.Getlink.
+func (n *namespaceSymlink) Getlink(ctx context.Context, inode *fs.Inode) (*fs.Dirent, error) {
+ if !kernel.ContextCanTrace(ctx, n.t, false) {
+ return nil, syserror.EACCES
+ }
+
+ // Create a new regular file to fake the namespace file.
+ iops := fsutil.NewNoReadWriteFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0777), linux.PROC_SUPER_MAGIC)
+ return fs.NewDirent(newProcInode(iops, inode.MountSource, fs.RegularFile, nil), n.Symlink.Target), nil
+}
+
+func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ contents := map[string]*fs.Inode{
+ "net": newNamespaceSymlink(t, msrc, "net"),
+ "pid": newNamespaceSymlink(t, msrc, "pid"),
+ "user": newNamespaceSymlink(t, msrc, "user"),
+ }
+ d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0511))
+ return newProcInode(d, msrc, fs.SpecialDirectory, t)
+}
+
+// mapsData implements seqfile.SeqSource for /proc/[pid]/maps.
+//
+// +stateify savable
+type mapsData struct {
+ t *kernel.Task
+}
+
+func newMaps(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newProcInode(seqfile.NewSeqFile(t, &mapsData{t}), msrc, fs.SpecialFile, t)
+}
+
+func (md *mapsData) mm() *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ md.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ // No additional reference is taken on mm here. This is safe
+ // because MemoryManager.destroy is required to leave the
+ // MemoryManager in a state where it's still usable as a SeqSource.
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (md *mapsData) NeedsUpdate(generation int64) bool {
+ if mm := md.mm(); mm != nil {
+ return mm.NeedsUpdate(generation)
+ }
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (md *mapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if mm := md.mm(); mm != nil {
+ return mm.ReadMapsSeqFileData(ctx, h)
+ }
+ return []seqfile.SeqData{}, 0
+}
+
+// smapsData implements seqfile.SeqSource for /proc/[pid]/smaps.
+//
+// +stateify savable
+type smapsData struct {
+ t *kernel.Task
+}
+
+func newSmaps(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newProcInode(seqfile.NewSeqFile(t, &smapsData{t}), msrc, fs.SpecialFile, t)
+}
+
+func (sd *smapsData) mm() *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ sd.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ // No additional reference is taken on mm here. This is safe
+ // because MemoryManager.destroy is required to leave the
+ // MemoryManager in a state where it's still usable as a SeqSource.
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (sd *smapsData) NeedsUpdate(generation int64) bool {
+ if mm := sd.mm(); mm != nil {
+ return mm.NeedsUpdate(generation)
+ }
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (sd *smapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if mm := sd.mm(); mm != nil {
+ return mm.ReadSmapsSeqFileData(ctx, h)
+ }
+ return []seqfile.SeqData{}, 0
+}
+
+// +stateify savable
+type taskStatData struct {
+ t *kernel.Task
+
+ // If tgstats is true, accumulate fault stats (not implemented) and CPU
+ // time across all tasks in t's thread group.
+ tgstats bool
+
+ // pidns is the PID namespace associated with the proc filesystem that
+ // includes the file using this statData.
+ pidns *kernel.PIDNamespace
+}
+
+func newTaskStat(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool, pidns *kernel.PIDNamespace) *fs.Inode {
+ return newProcInode(seqfile.NewSeqFile(t, &taskStatData{t, showSubtasks /* tgstats */, pidns}), msrc, fs.SpecialFile, t)
+}
+
+// NeedsUpdate returns whether the generation is old or not.
+func (s *taskStatData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData returns data for the SeqFile reader.
+// SeqData, the current generation and where in the file the handle corresponds to.
+func (s *taskStatData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+
+ fmt.Fprintf(&buf, "%d ", s.pidns.IDOfTask(s.t))
+ fmt.Fprintf(&buf, "(%s) ", s.t.Name())
+ fmt.Fprintf(&buf, "%c ", s.t.StateStatus()[0])
+ ppid := kernel.ThreadID(0)
+ if parent := s.t.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(&buf, "%d ", ppid)
+ fmt.Fprintf(&buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup()))
+ fmt.Fprintf(&buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session()))
+ fmt.Fprintf(&buf, "0 0 " /* tty_nr tpgid */)
+ fmt.Fprintf(&buf, "0 " /* flags */)
+ fmt.Fprintf(&buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
+ var cputime usage.CPUStats
+ if s.tgstats {
+ cputime = s.t.ThreadGroup().CPUStats()
+ } else {
+ cputime = s.t.CPUStats()
+ }
+ fmt.Fprintf(&buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ cputime = s.t.ThreadGroup().JoinedChildCPUStats()
+ fmt.Fprintf(&buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ fmt.Fprintf(&buf, "%d %d ", s.t.Priority(), s.t.Niceness())
+ fmt.Fprintf(&buf, "%d ", s.t.ThreadGroup().Count())
+
+ // itrealvalue. Since kernel 2.6.17, this field is no longer
+ // maintained, and is hard coded as 0.
+ fmt.Fprintf(&buf, "0 ")
+
+ // Start time is relative to boot time, expressed in clock ticks.
+ fmt.Fprintf(&buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
+
+ var vss, rss uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+ fmt.Fprintf(&buf, "%d %d ", vss, rss/usermem.PageSize)
+
+ // rsslim.
+ fmt.Fprintf(&buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur)
+
+ fmt.Fprintf(&buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
+ fmt.Fprintf(&buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
+ fmt.Fprintf(&buf, "0 0 " /* nswap cnswap */)
+ terminationSignal := linux.Signal(0)
+ if s.t == s.t.ThreadGroup().Leader() {
+ terminationSignal = s.t.ThreadGroup().TerminationSignal()
+ }
+ fmt.Fprintf(&buf, "%d ", terminationSignal)
+ fmt.Fprintf(&buf, "0 0 0 " /* processor rt_priority policy */)
+ fmt.Fprintf(&buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
+ fmt.Fprintf(&buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
+ fmt.Fprintf(&buf, "0\n" /* exit_code */)
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*taskStatData)(nil)}}, 0
+}
+
+// statmData implements seqfile.SeqSource for /proc/[pid]/statm.
+//
+// +stateify savable
+type statmData struct {
+ t *kernel.Task
+}
+
+func newStatm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newProcInode(seqfile.NewSeqFile(t, &statmData{t}), msrc, fs.SpecialFile, t)
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (s *statmData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (s *statmData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var vss, rss uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statmData)(nil)}}, 0
+}
+
+// statusData implements seqfile.SeqSource for /proc/[pid]/status.
+//
+// +stateify savable
+type statusData struct {
+ t *kernel.Task
+ pidns *kernel.PIDNamespace
+}
+
+func newStatus(t *kernel.Task, msrc *fs.MountSource, pidns *kernel.PIDNamespace) *fs.Inode {
+ return newProcInode(seqfile.NewSeqFile(t, &statusData{t, pidns}), msrc, fs.SpecialFile, t)
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (s *statusData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "Name:\t%s\n", s.t.Name())
+ fmt.Fprintf(&buf, "State:\t%s\n", s.t.StateStatus())
+ fmt.Fprintf(&buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup()))
+ fmt.Fprintf(&buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t))
+ ppid := kernel.ThreadID(0)
+ if parent := s.t.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(&buf, "PPid:\t%d\n", ppid)
+ tpid := kernel.ThreadID(0)
+ if tracer := s.t.Tracer(); tracer != nil {
+ tpid = s.pidns.IDOfTask(tracer)
+ }
+ fmt.Fprintf(&buf, "TracerPid:\t%d\n", tpid)
+ var fds int
+ var vss, rss, data uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if fdm := t.FDMap(); fdm != nil {
+ fds = fdm.Size()
+ }
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
+ })
+ fmt.Fprintf(&buf, "FDSize:\t%d\n", fds)
+ fmt.Fprintf(&buf, "VmSize:\t%d kB\n", vss>>10)
+ fmt.Fprintf(&buf, "VmRSS:\t%d kB\n", rss>>10)
+ fmt.Fprintf(&buf, "VmData:\t%d kB\n", data>>10)
+ fmt.Fprintf(&buf, "Threads:\t%d\n", s.t.ThreadGroup().Count())
+ creds := s.t.Credentials()
+ fmt.Fprintf(&buf, "CapInh:\t%016x\n", creds.InheritableCaps)
+ fmt.Fprintf(&buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
+ fmt.Fprintf(&buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
+ fmt.Fprintf(&buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
+ fmt.Fprintf(&buf, "Seccomp:\t%d\n", s.t.SeccompMode())
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statusData)(nil)}}, 0
+}
+
+// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
+type ioUsage interface {
+ // IOUsage returns the io usage data.
+ IOUsage() *usage.IO
+}
+
+// +stateify savable
+type ioData struct {
+ ioUsage
+}
+
+func newIO(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newProcInode(seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t)
+}
+
+// NeedsUpdate returns whether the generation is old or not.
+func (i *ioData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData returns data for the SeqFile reader.
+// SeqData, the current generation and where in the file the handle corresponds to.
+func (i *ioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ io := usage.IO{}
+ io.Accumulate(i.IOUsage())
+
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "char: %d\n", io.CharsRead)
+ fmt.Fprintf(&buf, "wchar: %d\n", io.CharsWritten)
+ fmt.Fprintf(&buf, "syscr: %d\n", io.ReadSyscalls)
+ fmt.Fprintf(&buf, "syscw: %d\n", io.WriteSyscalls)
+ fmt.Fprintf(&buf, "read_bytes: %d\n", io.BytesRead)
+ fmt.Fprintf(&buf, "write_bytes: %d\n", io.BytesWritten)
+ fmt.Fprintf(&buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
+
+ return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*ioData)(nil)}}, 0
+}
+
+// comm is a file containing the command name for a task.
+//
+// On Linux, /proc/[pid]/comm is writable, and writing to the comm file changes
+// the thread name. We don't implement this yet as there are no known users of
+// this feature.
+//
+// +stateify savable
+type comm struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// newComm returns a new comm file.
+func newComm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ c := &comm{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(c, msrc, fs.SpecialFile, t)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (c *comm) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &commFile{t: c.t}), nil
+}
+
+// +stateify savable
+type commFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+var _ fs.FileOperations = (*commFile)(nil)
+
+// Read implements fs.FileOperations.Read.
+func (f *commFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ buf := []byte(f.t.Name() + "\n")
+ if offset >= int64(len(buf)) {
+ return 0, io.EOF
+ }
+
+ n, err := dst.CopyOut(ctx, buf[offset:])
+ return int64(n), err
+}
+
+// auxvec is a file containing the auxiliary vector for a task.
+//
+// +stateify savable
+type auxvec struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// newAuxvec returns a new auxvec file.
+func newAuxvec(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ a := &auxvec{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(a, msrc, fs.SpecialFile, t)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (a *auxvec) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &auxvecFile{t: a.t}), nil
+}
+
+// +stateify savable
+type auxvecFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ m, err := getTaskMM(f.t)
+ if err != nil {
+ return 0, err
+ }
+ defer m.DecUsers(ctx)
+ auxv := m.Auxv()
+
+ // Space for buffer with AT_NULL (0) terminator at the end.
+ size := (len(auxv) + 1) * 16
+ if offset >= int64(size) {
+ return 0, io.EOF
+ }
+
+ buf := make([]byte, size)
+ for i, e := range auxv {
+ usermem.ByteOrder.PutUint64(buf[16*i:], e.Key)
+ usermem.ByteOrder.PutUint64(buf[16*i+8:], uint64(e.Value))
+ }
+
+ n, err := dst.CopyOut(ctx, buf[offset:])
+ return int64(n), err
+}
diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go
new file mode 100644
index 000000000..a14b1b45f
--- /dev/null
+++ b/pkg/sentry/fs/proc/uid_gid_map.go
@@ -0,0 +1,179 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// idMapInodeOperations implements fs.InodeOperations for
+// /proc/[pid]/{uid,gid}_map.
+//
+// +stateify savable
+type idMapInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeSimpleExtendedAttributes
+
+ t *kernel.Task
+ gids bool
+}
+
+var _ fs.InodeOperations = (*idMapInodeOperations)(nil)
+
+// newUIDMap returns a new uid_map file.
+func newUIDMap(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newIDMap(t, msrc, false /* gids */)
+}
+
+// newGIDMap returns a new gid_map file.
+func newGIDMap(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newIDMap(t, msrc, true /* gids */)
+}
+
+func newIDMap(t *kernel.Task, msrc *fs.MountSource, gids bool) *fs.Inode {
+ return newProcInode(&idMapInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ t: t,
+ gids: gids,
+ }, msrc, fs.SpecialFile, t)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (imio *idMapInodeOperations) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &idMapFileOperations{
+ iops: imio,
+ }), nil
+}
+
+// +stateify savable
+type idMapFileOperations struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ iops *idMapInodeOperations
+}
+
+var _ fs.FileOperations = (*idMapFileOperations)(nil)
+
+// "There is an (arbitrary) limit on the number of lines in the file. As at
+// Linux 3.18, the limit is five lines." - user_namespaces(7)
+const maxIDMapLines = 5
+
+// Read implements fs.FileOperations.Read.
+func (imfo *idMapFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ var entries []auth.IDMapEntry
+ if imfo.iops.gids {
+ entries = imfo.iops.t.UserNamespace().GIDMap()
+ } else {
+ entries = imfo.iops.t.UserNamespace().UIDMap()
+ }
+ var buf bytes.Buffer
+ for _, e := range entries {
+ fmt.Fprintf(&buf, "%10d %10d %10d\n", e.FirstID, e.FirstParentID, e.Length)
+ }
+ if offset >= int64(buf.Len()) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, buf.Bytes()[offset:])
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ // "In addition, the number of bytes written to the file must be less than
+ // the system page size, and the write must be performed at the start of
+ // the file ..." - user_namespaces(7)
+ srclen := src.NumBytes()
+ if srclen >= usermem.PageSize || offset != 0 {
+ return 0, syserror.EINVAL
+ }
+ b := make([]byte, srclen)
+ if _, err := src.CopyIn(ctx, b); err != nil {
+ return 0, err
+ }
+
+ // Truncate from the first NULL byte.
+ var nul int64
+ nul = int64(bytes.IndexByte(b, 0))
+ if nul == -1 {
+ nul = srclen
+ }
+ b = b[:nul]
+ // Remove the last \n.
+ if nul >= 1 && b[nul-1] == '\n' {
+ b = b[:nul-1]
+ }
+ lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1)
+ if len(lines) > maxIDMapLines {
+ return 0, syserror.EINVAL
+ }
+
+ entries := make([]auth.IDMapEntry, len(lines))
+ for i, l := range lines {
+ var e auth.IDMapEntry
+ _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length)
+ if err != nil {
+ return 0, syserror.EINVAL
+ }
+ entries[i] = e
+ }
+ var err error
+ if imfo.iops.gids {
+ err = imfo.iops.t.UserNamespace().SetGIDMap(ctx, entries)
+ } else {
+ err = imfo.iops.t.UserNamespace().SetUIDMap(ctx, entries)
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ // On success, Linux's kernel/user_namespace.c:map_write() always returns
+ // count, even if fewer bytes were used.
+ return int64(srclen), nil
+}
diff --git a/pkg/sentry/fs/proc/uptime.go b/pkg/sentry/fs/proc/uptime.go
new file mode 100644
index 000000000..35c3851e1
--- /dev/null
+++ b/pkg/sentry/fs/proc/uptime.go
@@ -0,0 +1,87 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// uptime is a file containing the system uptime.
+//
+// +stateify savable
+type uptime struct {
+ fsutil.SimpleFileInode
+
+ // The "start time" of the sandbox.
+ startTime ktime.Time
+}
+
+// newUptime returns a new uptime file.
+func newUptime(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ u := &uptime{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ startTime: ktime.NowFromContext(ctx),
+ }
+ return newProcInode(u, msrc, fs.SpecialFile, nil)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (u *uptime) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &uptimeFile{startTime: u.startTime}), nil
+}
+
+// +stateify savable
+type uptimeFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ startTime ktime.Time
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *uptimeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ now := ktime.NowFromContext(ctx)
+ // Pretend that we've spent zero time sleeping (second number).
+ s := []byte(fmt.Sprintf("%.2f 0.00\n", now.Sub(f.startTime).Seconds()))
+ if offset >= int64(len(s)) {
+ return 0, io.EOF
+ }
+
+ n, err := dst.CopyOut(ctx, s[offset:])
+ return int64(n), err
+}
diff --git a/pkg/sentry/fs/proc/version.go b/pkg/sentry/fs/proc/version.go
new file mode 100644
index 000000000..a5479990c
--- /dev/null
+++ b/pkg/sentry/fs/proc/version.go
@@ -0,0 +1,78 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+)
+
+// versionData backs /proc/version.
+//
+// +stateify savable
+type versionData struct {
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*versionData) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (v *versionData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ init := v.k.GlobalInit()
+ if init == nil {
+ // Attempted to read before the init Task is created. This can
+ // only occur during startup, which should never need to read
+ // this file.
+ panic("Attempted to read version before initial Task is available")
+ }
+
+ // /proc/version takes the form:
+ //
+ // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
+ // (COMPILER_VERSION) VERSION"
+ //
+ // where:
+ // - SYSNAME, RELEASE, and VERSION are the same as returned by
+ // sys_utsname
+ // - COMPILE_USER is the user that build the kernel
+ // - COMPILE_HOST is the hostname of the machine on which the kernel
+ // was built
+ // - COMPILER_VERSION is the version reported by the building compiler
+ //
+ // Since we don't really want to expose build information to
+ // applications, those fields are omitted.
+ //
+ // FIXME(mpratt): Using Version from the init task SyscallTable
+ // disregards the different version a task may have (e.g., in a uts
+ // namespace).
+ ver := init.Leader().SyscallTable().Version
+ return []seqfile.SeqData{
+ {
+ Buf: []byte(fmt.Sprintf("%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)),
+ Handle: (*versionData)(nil),
+ },
+ }, 0
+}
diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go
new file mode 100644
index 000000000..cd6e03d66
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/dir.go
@@ -0,0 +1,534 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ramfs
+
+import (
+ "fmt"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// CreateOps represents operations to create different file types.
+type CreateOps struct {
+ // NewDir creates a new directory.
+ NewDir func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error)
+
+ // NewFile creates a new file.
+ NewFile func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error)
+
+ // NewSymlink creates a new symlink with permissions 0777.
+ NewSymlink func(ctx context.Context, dir *fs.Inode, target string) (*fs.Inode, error)
+
+ // NewBoundEndpoint creates a new socket.
+ NewBoundEndpoint func(ctx context.Context, dir *fs.Inode, ep transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Inode, error)
+
+ // NewFifo creates a new fifo.
+ NewFifo func(ctx context.Context, dir *fs.Inode, perm fs.FilePermissions) (*fs.Inode, error)
+}
+
+// Dir represents a single directory in the filesystem.
+//
+// +stateify savable
+type Dir struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeIsDirAllocate `state:"nosave"`
+ fsutil.InodeIsDirTruncate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeSimpleExtendedAttributes
+
+ // CreateOps may be provided.
+ //
+ // These may only be modified during initialization (while the application
+ // is not running). No sychronization is performed when accessing these
+ // operations during syscalls.
+ *CreateOps `state:"nosave"`
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // children are inodes that are in this directory. A reference is held
+ // on each inode while it is in the map.
+ children map[string]*fs.Inode
+
+ // dentryMap is a sortedDentryMap containing entries for all children.
+ // Its entries are kept up-to-date with d.children.
+ dentryMap *fs.SortedDentryMap
+}
+
+var _ fs.InodeOperations = (*Dir)(nil)
+
+// NewDir returns a new Dir with the given contents and attributes.
+func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions) *Dir {
+ d := &Dir{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, perms, linux.RAMFS_MAGIC),
+ }
+
+ if contents == nil {
+ contents = make(map[string]*fs.Inode)
+ }
+ d.children = contents
+
+ // Build the entries map ourselves, rather than calling addChildLocked,
+ // because it will be faster.
+ entries := make(map[string]fs.DentAttr, len(contents))
+ for name, inode := range contents {
+ entries[name] = fs.DentAttr{
+ Type: inode.StableAttr.Type,
+ InodeID: inode.StableAttr.InodeID,
+ }
+ }
+ d.dentryMap = fs.NewSortedDentryMap(entries)
+
+ // Directories have an extra link, corresponding to '.'.
+ d.AddLink()
+
+ return d
+}
+
+// addChildLocked add the child inode, inheriting its reference.
+func (d *Dir) addChildLocked(ctx context.Context, name string, inode *fs.Inode) {
+ d.children[name] = inode
+ d.dentryMap.Add(name, fs.DentAttr{
+ Type: inode.StableAttr.Type,
+ InodeID: inode.StableAttr.InodeID,
+ })
+
+ // If the child is a directory, increment this dir's link count,
+ // corresponding to '..' from the subdirectory.
+ if fs.IsDir(inode.StableAttr) {
+ d.AddLink()
+ // ctime updated below.
+ }
+
+ // Given we're now adding this inode to the directory we must also
+ // increase its link count. Similarly we decrement it in removeChildLocked.
+ //
+ // Changing link count updates ctime.
+ inode.AddLink()
+ inode.InodeOperations.NotifyStatusChange(ctx)
+
+ // We've change the directory. This always updates our mtime and ctime.
+ d.NotifyModificationAndStatusChange(ctx)
+}
+
+// AddChild adds a child to this dir.
+func (d *Dir) AddChild(ctx context.Context, name string, inode *fs.Inode) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.addChildLocked(ctx, name, inode)
+}
+
+// FindChild returns (child, true) if the directory contains name.
+func (d *Dir) FindChild(name string) (*fs.Inode, bool) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ child, ok := d.children[name]
+ return child, ok
+}
+
+// Children returns the names and DentAttrs of all children. It can be used to
+// implement Readdir for types that embed ramfs.Dir.
+func (d *Dir) Children() ([]string, map[string]fs.DentAttr) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Return a copy to prevent callers from modifying our children.
+ names, entries := d.dentryMap.GetAll()
+ namesCopy := make([]string, len(names))
+ copy(namesCopy, names)
+
+ entriesCopy := make(map[string]fs.DentAttr)
+ for k, v := range entries {
+ entriesCopy[k] = v
+ }
+
+ return namesCopy, entriesCopy
+}
+
+// removeChildLocked attempts to remove an entry from this directory.
+func (d *Dir) removeChildLocked(ctx context.Context, name string) (*fs.Inode, error) {
+ inode, ok := d.children[name]
+ if !ok {
+ return nil, syserror.EACCES
+ }
+
+ delete(d.children, name)
+ d.dentryMap.Remove(name)
+ d.NotifyModification(ctx)
+
+ // If the child was a subdirectory, then we must decrement this dir's
+ // link count which was the child's ".." directory entry.
+ if fs.IsDir(inode.StableAttr) {
+ d.DropLink()
+ // ctime changed below.
+ }
+
+ // Given we're now removing this inode to the directory we must also
+ // decrease its link count. Similarly it is increased in addChildLocked.
+ //
+ // Changing link count updates ctime.
+ inode.DropLink()
+ inode.InodeOperations.NotifyStatusChange(ctx)
+
+ // We've change the directory. This always updates our mtime and ctime.
+ d.NotifyModificationAndStatusChange(ctx)
+
+ return inode, nil
+}
+
+// Remove removes the named non-directory.
+func (d *Dir) Remove(ctx context.Context, _ *fs.Inode, name string) error {
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ inode, err := d.removeChildLocked(ctx, name)
+ if err != nil {
+ return err
+ }
+
+ // Remove our reference on the inode.
+ inode.DecRef()
+ return nil
+}
+
+// RemoveDirectory removes the named directory.
+func (d *Dir) RemoveDirectory(ctx context.Context, _ *fs.Inode, name string) error {
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Get the child and make sure it is not empty.
+ childInode, err := d.walkLocked(ctx, name)
+ if err != nil {
+ return err
+ }
+ if ok, err := hasChildren(ctx, childInode); err != nil {
+ return err
+ } else if ok {
+ return syserror.ENOTEMPTY
+ }
+
+ // Child was empty. Proceed with removal.
+ inode, err := d.removeChildLocked(ctx, name)
+ if err != nil {
+ return err
+ }
+
+ // Remove our reference on the inode.
+ inode.DecRef()
+
+ return nil
+}
+
+// Lookup loads an inode at p into a Dirent.
+func (d *Dir) Lookup(ctx context.Context, _ *fs.Inode, p string) (*fs.Dirent, error) {
+ if len(p) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ inode, err := d.walkLocked(ctx, p)
+ if err != nil {
+ return nil, err
+ }
+
+ // Take a reference on the inode before returning it. This reference
+ // is owned by the dirent we are about to create.
+ inode.IncRef()
+ return fs.NewDirent(inode, p), nil
+}
+
+// walkLocked must be called with d.mu held.
+func (d *Dir) walkLocked(ctx context.Context, p string) (*fs.Inode, error) {
+ // Lookup a child node.
+ if inode, ok := d.children[p]; ok {
+ return inode, nil
+ }
+
+ // fs.InodeOperations.Lookup returns syserror.ENOENT if p
+ // does not exist.
+ return nil, syserror.ENOENT
+}
+
+// createInodeOperationsCommon creates a new child node at this dir by calling
+// makeInodeOperations. It is the common logic for creating a new child.
+func (d *Dir) createInodeOperationsCommon(ctx context.Context, name string, makeInodeOperations func() (*fs.Inode, error)) (*fs.Inode, error) {
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ inode, err := makeInodeOperations()
+ if err != nil {
+ return nil, err
+ }
+
+ d.addChildLocked(ctx, name, inode)
+
+ return inode, nil
+}
+
+// Create creates a new Inode with the given name and returns its File.
+func (d *Dir) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perms fs.FilePermissions) (*fs.File, error) {
+ if d.CreateOps == nil || d.CreateOps.NewFile == nil {
+ return nil, syserror.EACCES
+ }
+
+ inode, err := d.createInodeOperationsCommon(ctx, name, func() (*fs.Inode, error) {
+ return d.NewFile(ctx, dir, perms)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Take an extra ref on inode, which will be owned by the dirent.
+ inode.IncRef()
+
+ // Create the Dirent and corresponding file.
+ created := fs.NewDirent(inode, name)
+ defer created.DecRef()
+ return created.Inode.GetFile(ctx, created, flags)
+}
+
+// CreateLink returns a new link.
+func (d *Dir) CreateLink(ctx context.Context, dir *fs.Inode, oldname, newname string) error {
+ if d.CreateOps == nil || d.CreateOps.NewSymlink == nil {
+ return syserror.EACCES
+ }
+ _, err := d.createInodeOperationsCommon(ctx, newname, func() (*fs.Inode, error) {
+ return d.NewSymlink(ctx, dir, oldname)
+ })
+ return err
+}
+
+// CreateHardLink creates a new hard link.
+func (d *Dir) CreateHardLink(ctx context.Context, dir *fs.Inode, target *fs.Inode, name string) error {
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Take an extra reference on the inode and add it to our children.
+ target.IncRef()
+
+ // The link count will be incremented in addChildLocked.
+ d.addChildLocked(ctx, name, target)
+
+ return nil
+}
+
+// CreateDirectory returns a new subdirectory.
+func (d *Dir) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perms fs.FilePermissions) error {
+ if d.CreateOps == nil || d.CreateOps.NewDir == nil {
+ return syserror.EACCES
+ }
+ _, err := d.createInodeOperationsCommon(ctx, name, func() (*fs.Inode, error) {
+ return d.NewDir(ctx, dir, perms)
+ })
+ return err
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (d *Dir) Bind(ctx context.Context, dir *fs.Inode, name string, ep transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Dirent, error) {
+ if d.CreateOps == nil || d.CreateOps.NewBoundEndpoint == nil {
+ return nil, syserror.EACCES
+ }
+ inode, err := d.createInodeOperationsCommon(ctx, name, func() (*fs.Inode, error) {
+ return d.NewBoundEndpoint(ctx, dir, ep, perms)
+ })
+ if err == syscall.EEXIST {
+ return nil, syscall.EADDRINUSE
+ }
+ if err != nil {
+ return nil, err
+ }
+ // Take another ref on inode which will be donated to the new dirent.
+ inode.IncRef()
+ return fs.NewDirent(inode, name), nil
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (d *Dir) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perms fs.FilePermissions) error {
+ if d.CreateOps == nil || d.CreateOps.NewFifo == nil {
+ return syserror.EACCES
+ }
+ _, err := d.createInodeOperationsCommon(ctx, name, func() (*fs.Inode, error) {
+ return d.NewFifo(ctx, dir, perms)
+ })
+ return err
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (d *Dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ return fs.NewFile(ctx, dirent, flags, &dirFileOperations{dir: d}), nil
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (*Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return Rename(ctx, oldParent.InodeOperations, oldName, newParent.InodeOperations, newName, replacement)
+}
+
+// dirFileOperations implements fs.FileOperations for a ramfs directory.
+//
+// +stateify savable
+type dirFileOperations struct {
+ fsutil.DirFileOperations `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // dirCursor contains the name of the last directory entry that was
+ // serialized.
+ dirCursor string
+
+ // dir is the ramfs dir that this file corresponds to.
+ dir *Dir
+}
+
+var _ fs.FileOperations = (*dirFileOperations)(nil)
+
+// Seek implements fs.FileOperations.Seek.
+func (dfo *dirFileOperations) Seek(ctx context.Context, file *fs.File, whence fs.SeekWhence, offset int64) (int64, error) {
+ return fsutil.SeekWithDirCursor(ctx, file, whence, offset, &dfo.dirCursor)
+}
+
+// IterateDir implements DirIterator.IterateDir.
+func (dfo *dirFileOperations) IterateDir(ctx context.Context, dirCtx *fs.DirCtx, offset int) (int, error) {
+ dfo.dir.mu.Lock()
+ defer dfo.dir.mu.Unlock()
+
+ n, err := fs.GenericReaddir(dirCtx, dfo.dir.dentryMap)
+ return offset + n, err
+}
+
+// Readdir implements FileOperations.Readdir.
+func (dfo *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &dfo.dirCursor,
+ }
+ dfo.dir.InodeSimpleAttributes.NotifyAccess(ctx)
+ return fs.DirentReaddir(ctx, file.Dirent, dfo, root, dirCtx, file.Offset())
+}
+
+// hasChildren is a helper method that determines whether an arbitrary inode
+// (not necessarily ramfs) has any children.
+func hasChildren(ctx context.Context, inode *fs.Inode) (bool, error) {
+ // Take an extra ref on inode which will be given to the dirent and
+ // dropped when that dirent is destroyed.
+ inode.IncRef()
+ d := fs.NewTransientDirent(inode)
+ defer d.DecRef()
+
+ file, err := inode.GetFile(ctx, d, fs.FileFlags{Read: true})
+ if err != nil {
+ return false, err
+ }
+ defer file.DecRef()
+
+ ser := &fs.CollectEntriesSerializer{}
+ if err := file.Readdir(ctx, ser); err != nil {
+ return false, err
+ }
+ // We will always write "." and "..", so ignore those two.
+ if ser.Written() > 2 {
+ return true, nil
+ }
+ return false, nil
+}
+
+// Rename renames from a *ramfs.Dir to another *ramfs.Dir.
+func Rename(ctx context.Context, oldParent fs.InodeOperations, oldName string, newParent fs.InodeOperations, newName string, replacement bool) error {
+ op, ok := oldParent.(*Dir)
+ if !ok {
+ return syserror.EXDEV
+ }
+ np, ok := newParent.(*Dir)
+ if !ok {
+ return syserror.EXDEV
+ }
+ if len(newName) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+
+ np.mu.Lock()
+ defer np.mu.Unlock()
+
+ // Is this is an overwriting rename?
+ if replacement {
+ replaced, ok := np.children[newName]
+ if !ok {
+ panic(fmt.Sprintf("Dirent claims rename is replacement, but %q is missing from %+v", newName, np))
+ }
+
+ // Non-empty directories cannot be replaced.
+ if fs.IsDir(replaced.StableAttr) {
+ if ok, err := hasChildren(ctx, replaced); err != nil {
+ return err
+ } else if ok {
+ return syserror.ENOTEMPTY
+ }
+ }
+
+ // Remove the replaced child and drop our reference on it.
+ inode, err := np.removeChildLocked(ctx, newName)
+ if err != nil {
+ return err
+ }
+ inode.DecRef()
+ }
+
+ // Be careful, we may have already grabbed this mutex above.
+ if op != np {
+ op.mu.Lock()
+ defer op.mu.Unlock()
+ }
+
+ // Do the swap.
+ n := op.children[oldName]
+ op.removeChildLocked(ctx, oldName)
+ np.addChildLocked(ctx, newName, n)
+
+ return nil
+}
diff --git a/pkg/sentry/fs/ramfs/ramfs_state_autogen.go b/pkg/sentry/fs/ramfs/ramfs_state_autogen.go
new file mode 100755
index 000000000..dde1765e4
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/ramfs_state_autogen.go
@@ -0,0 +1,94 @@
+// automatically generated by stateify.
+
+package ramfs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Dir) beforeSave() {}
+func (x *Dir) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Save("children", &x.children)
+ m.Save("dentryMap", &x.dentryMap)
+}
+
+func (x *Dir) afterLoad() {}
+func (x *Dir) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Load("children", &x.children)
+ m.Load("dentryMap", &x.dentryMap)
+}
+
+func (x *dirFileOperations) beforeSave() {}
+func (x *dirFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("dirCursor", &x.dirCursor)
+ m.Save("dir", &x.dir)
+}
+
+func (x *dirFileOperations) afterLoad() {}
+func (x *dirFileOperations) load(m state.Map) {
+ m.Load("dirCursor", &x.dirCursor)
+ m.Load("dir", &x.dir)
+}
+
+func (x *Socket) beforeSave() {}
+func (x *Socket) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Save("ep", &x.ep)
+}
+
+func (x *Socket) afterLoad() {}
+func (x *Socket) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Load("ep", &x.ep)
+}
+
+func (x *socketFileOperations) beforeSave() {}
+func (x *socketFileOperations) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *socketFileOperations) afterLoad() {}
+func (x *socketFileOperations) load(m state.Map) {
+}
+
+func (x *Symlink) beforeSave() {}
+func (x *Symlink) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Save("Target", &x.Target)
+}
+
+func (x *Symlink) afterLoad() {}
+func (x *Symlink) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Load("Target", &x.Target)
+}
+
+func (x *symlinkFileOperations) beforeSave() {}
+func (x *symlinkFileOperations) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *symlinkFileOperations) afterLoad() {}
+func (x *symlinkFileOperations) load(m state.Map) {
+}
+
+func init() {
+ state.Register("ramfs.Dir", (*Dir)(nil), state.Fns{Save: (*Dir).save, Load: (*Dir).load})
+ state.Register("ramfs.dirFileOperations", (*dirFileOperations)(nil), state.Fns{Save: (*dirFileOperations).save, Load: (*dirFileOperations).load})
+ state.Register("ramfs.Socket", (*Socket)(nil), state.Fns{Save: (*Socket).save, Load: (*Socket).load})
+ state.Register("ramfs.socketFileOperations", (*socketFileOperations)(nil), state.Fns{Save: (*socketFileOperations).save, Load: (*socketFileOperations).load})
+ state.Register("ramfs.Symlink", (*Symlink)(nil), state.Fns{Save: (*Symlink).save, Load: (*Symlink).load})
+ state.Register("ramfs.symlinkFileOperations", (*symlinkFileOperations)(nil), state.Fns{Save: (*symlinkFileOperations).save, Load: (*symlinkFileOperations).load})
+}
diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go
new file mode 100644
index 000000000..7d8bca70e
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/socket.go
@@ -0,0 +1,85 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ramfs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Socket represents a socket.
+//
+// +stateify savable
+type Socket struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeSimpleExtendedAttributes
+
+ // ep is the bound endpoint.
+ ep transport.BoundEndpoint
+}
+
+var _ fs.InodeOperations = (*Socket)(nil)
+
+// NewSocket returns a new Socket.
+func NewSocket(ctx context.Context, ep transport.BoundEndpoint, owner fs.FileOwner, perms fs.FilePermissions) *Socket {
+ return &Socket{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, perms, linux.SOCKFS_MAGIC),
+ ep: ep,
+ }
+}
+
+// BoundEndpoint returns the socket data.
+func (s *Socket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint {
+ // ramfs only supports stored sentry internal sockets. Only gofer sockets
+ // care about the path argument.
+ return s.ep
+}
+
+// GetFile implements fs.FileOperations.GetFile.
+func (s *Socket) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &socketFileOperations{}), nil
+}
+
+// +stateify savable
+type socketFileOperations struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoRead `state:"nosave"`
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*socketFileOperations)(nil)
diff --git a/pkg/sentry/fs/ramfs/symlink.go b/pkg/sentry/fs/ramfs/symlink.go
new file mode 100644
index 000000000..21c246169
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/symlink.go
@@ -0,0 +1,106 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ramfs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Symlink represents a symlink.
+//
+// +stateify savable
+type Symlink struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeSimpleExtendedAttributes
+
+ // Target is the symlink target.
+ Target string
+}
+
+var _ fs.InodeOperations = (*Symlink)(nil)
+
+// NewSymlink returns a new Symlink.
+func NewSymlink(ctx context.Context, owner fs.FileOwner, target string) *Symlink {
+ // A symlink is assumed to always have permissions 0777.
+ return &Symlink{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(0777), linux.RAMFS_MAGIC),
+ Target: target,
+ }
+}
+
+// UnstableAttr returns all attributes of this ramfs symlink.
+func (s *Symlink) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ uattr, err := s.InodeSimpleAttributes.UnstableAttr(ctx, inode)
+ if err != nil {
+ return fs.UnstableAttr{}, err
+ }
+ uattr.Size = int64(len(s.Target))
+ uattr.Usage = uattr.Size
+ return uattr, nil
+}
+
+// SetPermissions on a symlink is always rejected.
+func (s *Symlink) SetPermissions(context.Context, *fs.Inode, fs.FilePermissions) bool {
+ return false
+}
+
+// Readlink reads the symlink value.
+func (s *Symlink) Readlink(ctx context.Context, _ *fs.Inode) (string, error) {
+ s.NotifyAccess(ctx)
+ return s.Target, nil
+}
+
+// Getlink returns ErrResolveViaReadlink, falling back to walking to the result
+// of Readlink().
+func (*Symlink) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
+ return nil, fs.ErrResolveViaReadlink
+}
+
+// GetFile implements fs.FileOperations.GetFile.
+func (s *Symlink) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &symlinkFileOperations{}), nil
+}
+
+// +stateify savable
+type symlinkFileOperations struct {
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoRead `state:"nosave"`
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+}
+
+var _ fs.FileOperations = (*symlinkFileOperations)(nil)
diff --git a/pkg/sentry/fs/ramfs/tree.go b/pkg/sentry/fs/ramfs/tree.go
new file mode 100644
index 000000000..8c6b31f70
--- /dev/null
+++ b/pkg/sentry/fs/ramfs/tree.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ramfs
+
+import (
+ "fmt"
+ "path"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/anon"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// MakeDirectoryTree constructs a ramfs tree of all directories containing
+// subdirs. Each element of subdir must be a clean path, and cannot be empty or
+// "/".
+//
+// All directories in the created tree will have full (read-write-execute)
+// permissions, but note that file creation inside the directories is not
+// actually supported because ramfs.Dir.CreateOpts == nil. However, these
+// directory trees are normally "underlayed" under another filesystem (possibly
+// the root), and file creation inside these directories in the overlay will be
+// possible if the upper is writeable.
+func MakeDirectoryTree(ctx context.Context, msrc *fs.MountSource, subdirs []string) (*fs.Inode, error) {
+ root := emptyDir(ctx, msrc)
+ for _, subdir := range subdirs {
+ if path.Clean(subdir) != subdir {
+ return nil, fmt.Errorf("cannot add subdir at an unclean path: %q", subdir)
+ }
+ if subdir == "" || subdir == "/" {
+ return nil, fmt.Errorf("cannot add subdir at %q", subdir)
+ }
+ makeSubdir(ctx, msrc, root.InodeOperations.(*Dir), subdir)
+ }
+ return root, nil
+}
+
+// makeSubdir installs into root each component of subdir. The final component is
+// a *ramfs.Dir.
+func makeSubdir(ctx context.Context, msrc *fs.MountSource, root *Dir, subdir string) {
+ for _, c := range strings.Split(subdir, "/") {
+ if len(c) == 0 {
+ continue
+ }
+ child, ok := root.FindChild(c)
+ if !ok {
+ child = emptyDir(ctx, msrc)
+ root.AddChild(ctx, c, child)
+ }
+ root = child.InodeOperations.(*Dir)
+ }
+}
+
+// emptyDir returns an empty *ramfs.Dir with all permissions granted.
+func emptyDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ dir := NewDir(ctx, make(map[string]*fs.Inode), fs.RootOwner, fs.FilePermsFromMode(0777))
+ return fs.NewInode(dir, msrc, fs.StableAttr{
+ DeviceID: anon.PseudoDevice.DeviceID(),
+ InodeID: anon.PseudoDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
diff --git a/pkg/sentry/fs/restore.go b/pkg/sentry/fs/restore.go
new file mode 100644
index 000000000..f10168125
--- /dev/null
+++ b/pkg/sentry/fs/restore.go
@@ -0,0 +1,78 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "sync"
+)
+
+// RestoreEnvironment is the restore environment for file systems. It consists
+// of things that change across save and restore and therefore cannot be saved
+// in the object graph.
+type RestoreEnvironment struct {
+ // MountSources maps Filesystem.Name() to mount arguments.
+ MountSources map[string][]MountArgs
+
+ // ValidateFileSize indicates file size should not change across S/R.
+ ValidateFileSize bool
+
+ // ValidateFileTimestamp indicates file modification timestamp should
+ // not change across S/R.
+ ValidateFileTimestamp bool
+}
+
+// MountArgs holds arguments to Mount.
+type MountArgs struct {
+ // Dev corresponds to the devname argumnent of Mount.
+ Dev string
+
+ // Flags corresponds to the flags argument of Mount.
+ Flags MountSourceFlags
+
+ // DataString corresponds to the data argument of Mount.
+ DataString string
+
+ // DataObj corresponds to the data interface argument of Mount.
+ DataObj interface{}
+}
+
+// restoreEnv holds the fs package global RestoreEnvironment.
+var restoreEnv = struct {
+ mu sync.Mutex
+ env RestoreEnvironment
+ set bool
+}{}
+
+// SetRestoreEnvironment sets the RestoreEnvironment. Must be called before
+// state.Load and only once.
+func SetRestoreEnvironment(r RestoreEnvironment) {
+ restoreEnv.mu.Lock()
+ defer restoreEnv.mu.Unlock()
+ if restoreEnv.set {
+ panic("RestoreEnvironment may only be set once")
+ }
+ restoreEnv.env = r
+ restoreEnv.set = true
+}
+
+// CurrentRestoreEnvironment returns the current, read-only RestoreEnvironment.
+// If no RestoreEnvironment was ever set, returns (_, false).
+func CurrentRestoreEnvironment() (RestoreEnvironment, bool) {
+ restoreEnv.mu.Lock()
+ defer restoreEnv.mu.Unlock()
+ e := restoreEnv.env
+ set := restoreEnv.set
+ return e, set
+}
diff --git a/pkg/sentry/fs/save.go b/pkg/sentry/fs/save.go
new file mode 100644
index 000000000..2eaf6ab69
--- /dev/null
+++ b/pkg/sentry/fs/save.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+)
+
+// SaveInodeMappings saves a mapping of path -> inode ID for every
+// user-reachable Dirent.
+//
+// The entire kernel must be frozen to call this, and filesystem state must not
+// change between SaveInodeMappings and state.Save, otherwise the saved state
+// of any MountSource may be incoherent.
+func SaveInodeMappings() {
+ mountsSeen := make(map[*MountSource]struct{})
+ for dirent := range allDirents.dirents {
+ if _, ok := mountsSeen[dirent.Inode.MountSource]; !ok {
+ dirent.Inode.MountSource.ResetInodeMappings()
+ mountsSeen[dirent.Inode.MountSource] = struct{}{}
+ }
+ }
+
+ for dirent := range allDirents.dirents {
+ if dirent.Inode != nil {
+ // We cannot trust the root provided in the mount due
+ // to the overlay. We can trust the overlay to delegate
+ // SaveInodeMappings to the right underlying
+ // filesystems, though.
+ root := dirent
+ for !root.mounted && root.parent != nil {
+ root = root.parent
+ }
+
+ // Add the mapping.
+ n, reachable := dirent.FullName(root)
+ if !reachable {
+ // Something has gone seriously wrong if we can't reach our root.
+ panic(fmt.Sprintf("Unreachable root on dirent file %s", n))
+ }
+ dirent.Inode.MountSource.SaveInodeMapping(dirent.Inode, n)
+ }
+ }
+}
+
+// SaveFileFsyncError converts an fs.File.Fsync error to an error that
+// indicates that the fs.File was not synced sufficiently to be saved.
+func SaveFileFsyncError(err error) error {
+ switch err {
+ case nil:
+ // We succeeded, everything is great.
+ return nil
+ case syscall.EBADF, syscall.EINVAL, syscall.EROFS, syscall.ENOSYS, syscall.EPERM:
+ // These errors mean that the underlying node might not be syncable,
+ // which we expect to be reported as such even from the gofer.
+ log.Infof("failed to sync during save: %v", err)
+ return nil
+ default:
+ // We failed in some way that indicates potential data loss.
+ return fmt.Errorf("failed to sync: %v, data loss may occur", err)
+ }
+}
diff --git a/pkg/sentry/fs/seek.go b/pkg/sentry/fs/seek.go
new file mode 100644
index 000000000..0f43918ad
--- /dev/null
+++ b/pkg/sentry/fs/seek.go
@@ -0,0 +1,43 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+// SeekWhence determines seek direction.
+type SeekWhence int
+
+const (
+ // SeekSet sets the absolute offset.
+ SeekSet SeekWhence = iota
+
+ // SeekCurrent sets relative to the current position.
+ SeekCurrent
+
+ // SeekEnd sets relative to the end of the file.
+ SeekEnd
+)
+
+// String returns a human readable string for whence.
+func (s SeekWhence) String() string {
+ switch s {
+ case SeekSet:
+ return "Set"
+ case SeekCurrent:
+ return "Current"
+ case SeekEnd:
+ return "End"
+ default:
+ return "Unknown"
+ }
+}
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
new file mode 100644
index 000000000..65937f44d
--- /dev/null
+++ b/pkg/sentry/fs/splice.go
@@ -0,0 +1,187 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+import (
+ "io"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/secio"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Splice moves data to this file, directly from another.
+//
+// Offsets are updated only if DstOffset and SrcOffset are set.
+func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, error) {
+ // Verify basic file flag permissions.
+ if !dst.Flags().Write || !src.Flags().Read {
+ return 0, syserror.EBADF
+ }
+
+ // Check whether or not the objects being sliced are stream-oriented
+ // (i.e. pipes or sockets). If yes, we elide checks and offset locks.
+ srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr)
+ dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr)
+
+ if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset {
+ switch {
+ case dst.UniqueID < src.UniqueID:
+ // Acquire dst first.
+ if !dst.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer dst.mu.Unlock()
+ if !src.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer src.mu.Unlock()
+ case dst.UniqueID > src.UniqueID:
+ // Acquire src first.
+ if !src.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer src.mu.Unlock()
+ if !dst.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer dst.mu.Unlock()
+ case dst.UniqueID == src.UniqueID:
+ // Acquire only one lock; it's the same file. This is a
+ // bit of a edge case, but presumably it's possible.
+ if !dst.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer dst.mu.Unlock()
+ }
+ // Use both offsets (locked).
+ opts.DstStart = dst.offset
+ opts.SrcStart = src.offset
+ } else if !dstPipe && !opts.DstOffset {
+ // Acquire only dst.
+ if !dst.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer dst.mu.Unlock()
+ opts.DstStart = dst.offset // Safe: locked.
+ } else if !srcPipe && !opts.SrcOffset {
+ // Acquire only src.
+ if !src.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer src.mu.Unlock()
+ opts.SrcStart = src.offset // Safe: locked.
+ }
+
+ // Check append-only mode and the limit.
+ if !dstPipe {
+ if dst.Flags().Append {
+ if opts.DstOffset {
+ // We need to acquire the lock.
+ if !dst.mu.Lock(ctx) {
+ return 0, syserror.ErrInterrupted
+ }
+ defer dst.mu.Unlock()
+ }
+ // Figure out the appropriate offset to use.
+ if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil {
+ return 0, err
+ }
+ }
+
+ // Enforce file limits.
+ limit, ok := dst.checkLimit(ctx, opts.DstStart)
+ switch {
+ case ok && limit == 0:
+ return 0, syserror.ErrExceedsFileSizeLimit
+ case ok && limit < opts.Length:
+ opts.Length = limit // Cap the write.
+ }
+ }
+
+ // Attempt to do a WriteTo; this is likely the most efficient.
+ //
+ // The underlying implementation may be able to donate buffers.
+ newOpts := SpliceOpts{
+ Length: opts.Length,
+ SrcStart: opts.SrcStart,
+ SrcOffset: !srcPipe,
+ Dup: opts.Dup,
+ DstStart: opts.DstStart,
+ DstOffset: !dstPipe,
+ }
+ n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts)
+ if n == 0 && err != nil {
+ // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also
+ // be more efficient than a copy if buffers are cached or readily
+ // available. (It's unlikely that they can actually be donate
+ n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts)
+ }
+ if n == 0 && err != nil {
+ // If we've failed up to here, and at least one of the sources
+ // is a pipe or socket, then we can't properly support dup.
+ // Return an error indicating that this operation is not
+ // supported.
+ if (srcPipe || dstPipe) && newOpts.Dup {
+ return 0, syserror.EINVAL
+ }
+
+ // We failed to splice the files. But that's fine; we just fall
+ // back to a slow path in this case. This copies without doing
+ // any mode changes, so should still be more efficient.
+ var (
+ r io.Reader
+ w io.Writer
+ )
+ fw := &lockedWriter{
+ Ctx: ctx,
+ File: dst,
+ }
+ if newOpts.DstOffset {
+ // Use the provided offset.
+ w = secio.NewOffsetWriter(fw, newOpts.DstStart)
+ } else {
+ // Writes will proceed with no offset.
+ w = fw
+ }
+ fr := &lockedReader{
+ Ctx: ctx,
+ File: src,
+ }
+ if newOpts.SrcOffset {
+ // Limit to the given offset and length.
+ r = io.NewSectionReader(fr, opts.SrcStart, opts.Length)
+ } else {
+ // Limit just to the given length.
+ r = &io.LimitedReader{fr, opts.Length}
+ }
+
+ // Copy between the two.
+ n, err = io.Copy(w, r)
+ }
+
+ // Update offsets, if required.
+ if n > 0 {
+ if !dstPipe && !opts.DstOffset {
+ atomic.StoreInt64(&dst.offset, dst.offset+n)
+ }
+ if !srcPipe && !opts.SrcOffset {
+ atomic.StoreInt64(&src.offset, src.offset+n)
+ }
+ }
+
+ return n, err
+}
diff --git a/pkg/sentry/fs/sync.go b/pkg/sentry/fs/sync.go
new file mode 100644
index 000000000..1fff8059c
--- /dev/null
+++ b/pkg/sentry/fs/sync.go
@@ -0,0 +1,43 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fs
+
+// SyncType enumerates ways in which a File can be synced.
+type SyncType int
+
+const (
+ // SyncAll indicates that modified in-memory metadata and data should
+ // be written to backing storage. SyncAll implies SyncBackingStorage.
+ SyncAll SyncType = iota
+
+ // SyncData indicates that along with modified in-memory data, only
+ // metadata needed to access that data needs to be written.
+ //
+ // For example, changes to access time or modification time do not
+ // need to be written because they are not necessary for a data read
+ // to be handled correctly, unlike the file size.
+ //
+ // The aim of SyncData is to reduce disk activity for applications
+ // that do not require all metadata to be synchronized with the disk,
+ // see fdatasync(2). File systems that implement SyncData as SyncAll
+ // do not support this optimization.
+ //
+ // SyncData implies SyncBackingStorage.
+ SyncData
+
+ // SyncBackingStorage indicates that in-flight write operations to
+ // backing storage should be flushed.
+ SyncBackingStorage
+)
diff --git a/pkg/sentry/fs/sys/device.go b/pkg/sentry/fs/sys/device.go
new file mode 100644
index 000000000..128d3a9d9
--- /dev/null
+++ b/pkg/sentry/fs/sys/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sys
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// sysfsDevice is the sysfs virtual device.
+var sysfsDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/sys/devices.go b/pkg/sentry/fs/sys/devices.go
new file mode 100644
index 000000000..54f35c6a0
--- /dev/null
+++ b/pkg/sentry/fs/sys/devices.go
@@ -0,0 +1,91 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sys
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+)
+
+// +stateify savable
+type cpunum struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeNotVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+ fsutil.InodeStaticFileGetter
+}
+
+var _ fs.InodeOperations = (*cpunum)(nil)
+
+func newPossible(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ var maxCore uint
+ k := kernel.KernelFromContext(ctx)
+ if k != nil {
+ maxCore = k.ApplicationCores() - 1
+ }
+ contents := []byte(fmt.Sprintf("0-%d\n", maxCore))
+
+ c := &cpunum{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.SYSFS_MAGIC),
+ InodeStaticFileGetter: fsutil.InodeStaticFileGetter{
+ Contents: contents,
+ },
+ }
+ return newFile(c, msrc)
+}
+
+func newCPU(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ m := map[string]*fs.Inode{
+ "online": newPossible(ctx, msrc),
+ "possible": newPossible(ctx, msrc),
+ "present": newPossible(ctx, msrc),
+ }
+
+ // Add directories for each of the cpus.
+ if k := kernel.KernelFromContext(ctx); k != nil {
+ for i := 0; uint(i) < k.ApplicationCores(); i++ {
+ m[fmt.Sprintf("cpu%d", i)] = newDir(ctx, msrc, nil)
+ }
+ }
+
+ return newDir(ctx, msrc, m)
+}
+
+func newSystemDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ return newDir(ctx, msrc, map[string]*fs.Inode{
+ "cpu": newCPU(ctx, msrc),
+ })
+}
+
+func newDevicesDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ return newDir(ctx, msrc, map[string]*fs.Inode{
+ "system": newSystemDir(ctx, msrc),
+ })
+}
diff --git a/pkg/sentry/fs/sys/fs.go b/pkg/sentry/fs/sys/fs.go
new file mode 100644
index 000000000..f0c2322e0
--- /dev/null
+++ b/pkg/sentry/fs/sys/fs.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sys
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// filesystem is a sysfs.
+//
+// +stateify savable
+type filesystem struct{}
+
+var _ fs.Filesystem = (*filesystem)(nil)
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name underwhich the filesystem is registered.
+// Name matches fs/sysfs/mount.c:sysfs_fs_type.name.
+const FilesystemName = "sysfs"
+
+// Name is the name of the file system.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*filesystem) AllowUserMount() bool {
+ return true
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+//
+// In Linux, sysfs returns FS_USERNS_VISIBLE | FS_USERNS_MOUNT, see fs/sysfs/mount.c.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns a sysfs root which can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // device is always ignored.
+ // sysfs ignores data, see fs/sysfs/mount.c:sysfs_mount.
+
+ return New(ctx, fs.NewNonCachingMountSource(f, flags)), nil
+}
diff --git a/pkg/sentry/fs/sys/sys.go b/pkg/sentry/fs/sys/sys.go
new file mode 100644
index 000000000..d20ef91fa
--- /dev/null
+++ b/pkg/sentry/fs/sys/sys.go
@@ -0,0 +1,64 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sys implements a sysfs filesystem.
+package sys
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+func newFile(node fs.InodeOperations, msrc *fs.MountSource) *fs.Inode {
+ sattr := fs.StableAttr{
+ DeviceID: sysfsDevice.DeviceID(),
+ InodeID: sysfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(node, msrc, sattr)
+}
+
+func newDir(ctx context.Context, msrc *fs.MountSource, contents map[string]*fs.Inode) *fs.Inode {
+ d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return fs.NewInode(d, msrc, fs.StableAttr{
+ DeviceID: sysfsDevice.DeviceID(),
+ InodeID: sysfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialDirectory,
+ })
+}
+
+// New returns the root node of a partial simple sysfs.
+func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ return newDir(ctx, msrc, map[string]*fs.Inode{
+ // Add a basic set of top-level directories. In Linux, these
+ // are dynamically added depending on the KConfig. Here we just
+ // add the most common ones.
+ "block": newDir(ctx, msrc, nil),
+ "bus": newDir(ctx, msrc, nil),
+ "class": newDir(ctx, msrc, map[string]*fs.Inode{
+ "power_supply": newDir(ctx, msrc, nil),
+ }),
+ "dev": newDir(ctx, msrc, nil),
+ "devices": newDevicesDir(ctx, msrc),
+ "firmware": newDir(ctx, msrc, nil),
+ "fs": newDir(ctx, msrc, nil),
+ "kernel": newDir(ctx, msrc, nil),
+ "module": newDir(ctx, msrc, nil),
+ "power": newDir(ctx, msrc, nil),
+ })
+}
diff --git a/pkg/sentry/fs/sys/sys_state_autogen.go b/pkg/sentry/fs/sys/sys_state_autogen.go
new file mode 100755
index 000000000..84779f991
--- /dev/null
+++ b/pkg/sentry/fs/sys/sys_state_autogen.go
@@ -0,0 +1,34 @@
+// automatically generated by stateify.
+
+package sys
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *cpunum) beforeSave() {}
+func (x *cpunum) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Save("InodeStaticFileGetter", &x.InodeStaticFileGetter)
+}
+
+func (x *cpunum) afterLoad() {}
+func (x *cpunum) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Load("InodeStaticFileGetter", &x.InodeStaticFileGetter)
+}
+
+func (x *filesystem) beforeSave() {}
+func (x *filesystem) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *filesystem) afterLoad() {}
+func (x *filesystem) load(m state.Map) {
+}
+
+func init() {
+ state.Register("sys.cpunum", (*cpunum)(nil), state.Fns{Save: (*cpunum).save, Load: (*cpunum).load})
+ state.Register("sys.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load})
+}
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
new file mode 100644
index 000000000..bce5f091d
--- /dev/null
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -0,0 +1,148 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package timerfd implements the semantics of Linux timerfd objects as
+// described by timerfd_create(2).
+package timerfd
+
+import (
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/anon"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// TimerOperations implements fs.FileOperations for timerfds.
+//
+// +stateify savable
+type TimerOperations struct {
+ fsutil.FileZeroSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ events waiter.Queue `state:"zerovalue"`
+ timer *ktime.Timer
+
+ // val is the number of timer expirations since the last successful call to
+ // Readv, Preadv, or SetTime. val is accessed using atomic memory
+ // operations.
+ val uint64
+}
+
+// NewFile returns a timerfd File that receives time from c.
+func NewFile(ctx context.Context, c ktime.Clock) *fs.File {
+ dirent := fs.NewDirent(anon.NewInode(ctx), "anon_inode:[timerfd]")
+ tops := &TimerOperations{}
+ tops.timer = ktime.NewTimer(c, tops)
+ // Timerfds reject writes, but the Write flag must be set in order to
+ // ensure that our Writev/Pwritev methods actually get called to return
+ // the correct errors.
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, tops)
+}
+
+// Release implements fs.FileOperations.Release.
+func (t *TimerOperations) Release() {
+ t.timer.Destroy()
+}
+
+// PauseTimer pauses the associated Timer.
+func (t *TimerOperations) PauseTimer() {
+ t.timer.Pause()
+}
+
+// ResumeTimer resumes the associated Timer.
+func (t *TimerOperations) ResumeTimer() {
+ t.timer.Resume()
+}
+
+// Clock returns the associated Timer's Clock.
+func (t *TimerOperations) Clock() ktime.Clock {
+ return t.timer.Clock()
+}
+
+// GetTime returns the associated Timer's setting and the time at which it was
+// observed.
+func (t *TimerOperations) GetTime() (ktime.Time, ktime.Setting) {
+ return t.timer.Get()
+}
+
+// SetTime atomically changes the associated Timer's setting, resets the number
+// of expirations to 0, and returns the previous setting and the time at which
+// it was observed.
+func (t *TimerOperations) SetTime(s ktime.Setting) (ktime.Time, ktime.Setting) {
+ return t.timer.SwapAnd(s, func() { atomic.StoreUint64(&t.val, 0) })
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (t *TimerOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ if atomic.LoadUint64(&t.val) != 0 {
+ ready |= waiter.EventIn
+ }
+ return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (t *TimerOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ t.events.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (t *TimerOperations) EventUnregister(e *waiter.Entry) {
+ t.events.EventUnregister(e)
+}
+
+// Read implements fs.FileOperations.Read.
+func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ const sizeofUint64 = 8
+ if dst.NumBytes() < sizeofUint64 {
+ return 0, syserror.EINVAL
+ }
+ if val := atomic.SwapUint64(&t.val, 0); val != 0 {
+ var buf [sizeofUint64]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ if _, err := dst.CopyOut(ctx, buf[:]); err != nil {
+ // Linux does not undo consuming the number of expirations even if
+ // writing to userspace fails.
+ return 0, err
+ }
+ return sizeofUint64, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// Write implements fs.FileOperations.Write.
+func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EINVAL
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (t *TimerOperations) Notify(exp uint64) {
+ atomic.AddUint64(&t.val, exp)
+ t.events.Notify(waiter.EventIn)
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (t *TimerOperations) Destroy() {}
diff --git a/pkg/sentry/fs/timerfd/timerfd_state_autogen.go b/pkg/sentry/fs/timerfd/timerfd_state_autogen.go
new file mode 100755
index 000000000..bae449d97
--- /dev/null
+++ b/pkg/sentry/fs/timerfd/timerfd_state_autogen.go
@@ -0,0 +1,25 @@
+// automatically generated by stateify.
+
+package timerfd
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *TimerOperations) beforeSave() {}
+func (x *TimerOperations) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.events) { m.Failf("events is %v, expected zero", x.events) }
+ m.Save("timer", &x.timer)
+ m.Save("val", &x.val)
+}
+
+func (x *TimerOperations) afterLoad() {}
+func (x *TimerOperations) load(m state.Map) {
+ m.Load("timer", &x.timer)
+ m.Load("val", &x.val)
+}
+
+func init() {
+ state.Register("timerfd.TimerOperations", (*TimerOperations)(nil), state.Fns{Save: (*TimerOperations).save, Load: (*TimerOperations).load})
+}
diff --git a/pkg/sentry/fs/tmpfs/device.go b/pkg/sentry/fs/tmpfs/device.go
new file mode 100644
index 000000000..179c3a46f
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// tmpfsDevice is the kernel tmpfs device.
+var tmpfsDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/fs/tmpfs/file_regular.go b/pkg/sentry/fs/tmpfs/file_regular.go
new file mode 100644
index 000000000..d1c163879
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/file_regular.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// regularFileOperations implements fs.FileOperations for a regular
+// tmpfs file.
+//
+// +stateify savable
+type regularFileOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ // iops is the InodeOperations of a regular tmpfs file. It is
+ // guaranteed to be the same as file.Dirent.Inode.InodeOperations,
+ // see operations that take fs.File below.
+ iops *fileInodeOperations
+}
+
+// Read implements fs.FileOperations.Read.
+func (r *regularFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ return r.iops.read(ctx, file, dst, offset)
+}
+
+// Write implements fs.FileOperations.Write.
+func (r *regularFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ return r.iops.write(ctx, src, offset)
+}
+
+// ConfigureMMap implements fs.FileOperations.ConfigureMMap.
+func (r *regularFileOperations) ConfigureMMap(ctx context.Context, file *fs.File, opts *memmap.MMapOpts) error {
+ return fsutil.GenericConfigureMMap(file, r.iops, opts)
+}
diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go
new file mode 100644
index 000000000..b7c29a4d1
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/fs.go
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+)
+
+const (
+ // Set initial permissions for the root directory.
+ modeKey = "mode"
+
+ // UID for the root directory.
+ rootUIDKey = "uid"
+
+ // GID for the root directory.
+ rootGIDKey = "gid"
+
+ // TODO(edahlgren/mpratt): support a tmpfs size limit.
+ // size = "size"
+
+ // Permissions that exceed modeMask will be rejected.
+ modeMask = 01777
+
+ // Default permissions are read/write/execute.
+ defaultMode = 0777
+)
+
+// Filesystem is a tmpfs.
+//
+// +stateify savable
+type Filesystem struct{}
+
+var _ fs.Filesystem = (*Filesystem)(nil)
+
+func init() {
+ fs.RegisterFilesystem(&Filesystem{})
+}
+
+// FilesystemName is the name underwhich the filesystem is registered.
+// Name matches mm/shmem.c:shmem_fs_type.name.
+const FilesystemName = "tmpfs"
+
+// Name is the name of the file system.
+func (*Filesystem) Name() string {
+ return FilesystemName
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*Filesystem) AllowUserMount() bool {
+ return true
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*Filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+//
+// In Linux, tmpfs returns FS_USERNS_MOUNT, see mm/shmem.c.
+func (*Filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// Mount returns a tmpfs root that can be positioned in the vfs.
+func (f *Filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // device is always ignored.
+
+ // Parse generic comma-separated key=value options, this file system expects them.
+ options := fs.GenericMountSourceOptions(data)
+
+ // Parse the root directory permissions.
+ perms := fs.FilePermsFromMode(defaultMode)
+ if m, ok := options[modeKey]; ok {
+ i, err := strconv.ParseUint(m, 8, 32)
+ if err != nil {
+ return nil, fmt.Errorf("mode value not parsable 'mode=%s': %v", m, err)
+ }
+ if i&^modeMask != 0 {
+ return nil, fmt.Errorf("invalid mode %q: must be less than %o", m, modeMask)
+ }
+ perms = fs.FilePermsFromMode(linux.FileMode(i))
+ delete(options, modeKey)
+ }
+
+ creds := auth.CredentialsFromContext(ctx)
+ owner := fs.FileOwnerFromContext(ctx)
+ if uidstr, ok := options[rootUIDKey]; ok {
+ uid, err := strconv.ParseInt(uidstr, 10, 32)
+ if err != nil {
+ return nil, fmt.Errorf("uid value not parsable 'uid=%d': %v", uid, err)
+ }
+ owner.UID = creds.UserNamespace.MapToKUID(auth.UID(uid))
+ delete(options, rootUIDKey)
+ }
+
+ if gidstr, ok := options[rootGIDKey]; ok {
+ gid, err := strconv.ParseInt(gidstr, 10, 32)
+ if err != nil {
+ return nil, fmt.Errorf("gid value not parsable 'gid=%d': %v", gid, err)
+ }
+ owner.GID = creds.UserNamespace.MapToKGID(auth.GID(gid))
+ delete(options, rootGIDKey)
+ }
+
+ // Fail if the caller passed us more options than we can parse. They may be
+ // expecting us to set something we can't set.
+ if len(options) > 0 {
+ return nil, fmt.Errorf("unsupported mount options: %v", options)
+ }
+
+ // Construct a mount which will cache dirents.
+ msrc := fs.NewCachingMountSource(f, flags)
+
+ // Construct the tmpfs root.
+ return NewDir(ctx, nil, owner, perms, msrc), nil
+}
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
new file mode 100644
index 000000000..3fe659543
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -0,0 +1,681 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "io"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/metric"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+var (
+ opensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.")
+ opensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.")
+ reads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.")
+ readWait = metric.MustCreateNewUint64Metric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.")
+)
+
+// fileInodeOperations implements fs.InodeOperations for a regular tmpfs file.
+// These files are backed by pages allocated from a platform.Memory, and may be
+// directly mapped.
+//
+// Lock order: attrMu -> mapsMu -> dataMu.
+//
+// +stateify savable
+type fileInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+
+ fsutil.InodeSimpleExtendedAttributes
+
+ // kernel is used to allocate memory that stores the file's contents.
+ kernel *kernel.Kernel
+
+ // memUsage is the default memory usage that will be reported by this file.
+ memUsage usage.MemoryKind
+
+ attrMu sync.Mutex `state:"nosave"`
+
+ // attr contains the unstable metadata for the file.
+ //
+ // attr is protected by attrMu. attr.Size is protected by both attrMu
+ // and dataMu; reading it requires locking either mutex, while mutating
+ // it requires locking both.
+ attr fs.UnstableAttr
+
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the file into memmap.MappingSpaces.
+ //
+ // mappings is protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // writableMappingPages tracks how many pages of virtual memory are mapped
+ // as potentially writable from this file. If a page has multiple mappings,
+ // each mapping is counted separately.
+ //
+ // This counter is susceptible to overflow as we can potentially count
+ // mappings from many VMAs. We count pages rather than bytes to slightly
+ // mitigate this.
+ //
+ // Protected by mapsMu.
+ writableMappingPages uint64
+
+ dataMu sync.RWMutex `state:"nosave"`
+
+ // data maps offsets into the file to offsets into platform.Memory() that
+ // store the file's data.
+ //
+ // data is protected by dataMu.
+ data fsutil.FileRangeSet
+
+ // seals represents file seals on this inode.
+ //
+ // Protected by dataMu.
+ seals uint32
+}
+
+var _ fs.InodeOperations = (*fileInodeOperations)(nil)
+
+// NewInMemoryFile returns a new file backed by Kernel.MemoryFile().
+func NewInMemoryFile(ctx context.Context, usage usage.MemoryKind, uattr fs.UnstableAttr) fs.InodeOperations {
+ return &fileInodeOperations{
+ attr: uattr,
+ kernel: kernel.KernelFromContext(ctx),
+ memUsage: usage,
+ seals: linux.F_SEAL_SEAL,
+ }
+}
+
+// NewMemfdInode creates a new inode backing a memfd. Memory used by the memfd
+// is backed by platform memory.
+func NewMemfdInode(ctx context.Context, allowSeals bool) *fs.Inode {
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with
+ // S_IRWXUGO.
+ perms := fs.PermMask{Read: true, Write: true, Execute: true}
+ iops := NewInMemoryFile(ctx, usage.Tmpfs, fs.UnstableAttr{
+ Owner: fs.FileOwnerFromContext(ctx),
+ Perms: fs.FilePermissions{User: perms, Group: perms, Other: perms}}).(*fileInodeOperations)
+ if allowSeals {
+ iops.seals = 0
+ }
+ return fs.NewInode(iops, fs.NewNonCachingMountSource(nil, fs.MountSourceFlags{}), fs.StableAttr{
+ Type: fs.RegularFile,
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ })
+}
+
+// Release implements fs.InodeOperations.Release.
+func (f *fileInodeOperations) Release(context.Context) {
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+ f.data.DropAll(f.kernel.MemoryFile())
+}
+
+// Mappable implements fs.InodeOperations.Mappable.
+func (f *fileInodeOperations) Mappable(*fs.Inode) memmap.Mappable {
+ return f
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ if flags.Write {
+ opensW.Increment()
+ } else if flags.Read {
+ opensRO.Increment()
+ }
+ flags.Pread = true
+ flags.Pwrite = true
+ return fs.NewFile(ctx, d, flags, &regularFileOperations{iops: f}), nil
+}
+
+// UnstableAttr returns unstable attributes of this tmpfs file.
+func (f *fileInodeOperations) UnstableAttr(ctx context.Context, inode *fs.Inode) (fs.UnstableAttr, error) {
+ f.attrMu.Lock()
+ f.dataMu.RLock()
+ attr := f.attr
+ attr.Usage = int64(f.data.Span())
+ f.dataMu.RUnlock()
+ f.attrMu.Unlock()
+ return attr, nil
+}
+
+// Check implements fs.InodeOperations.Check.
+func (f *fileInodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (f *fileInodeOperations) SetPermissions(ctx context.Context, _ *fs.Inode, p fs.FilePermissions) bool {
+ f.attrMu.Lock()
+ f.attr.SetPermissions(ctx, p)
+ f.attrMu.Unlock()
+ return true
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (f *fileInodeOperations) SetTimestamps(ctx context.Context, _ *fs.Inode, ts fs.TimeSpec) error {
+ f.attrMu.Lock()
+ f.attr.SetTimestamps(ctx, ts)
+ f.attrMu.Unlock()
+ return nil
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (f *fileInodeOperations) SetOwner(ctx context.Context, _ *fs.Inode, owner fs.FileOwner) error {
+ f.attrMu.Lock()
+ f.attr.SetOwner(ctx, owner)
+ f.attrMu.Unlock()
+ return nil
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (f *fileInodeOperations) Truncate(ctx context.Context, _ *fs.Inode, size int64) error {
+ f.attrMu.Lock()
+ defer f.attrMu.Unlock()
+
+ f.dataMu.Lock()
+ oldSize := f.attr.Size
+
+ // Check if current seals allow truncation.
+ switch {
+ case size > oldSize && f.seals&linux.F_SEAL_GROW != 0: // Grow sealed
+ fallthrough
+ case oldSize > size && f.seals&linux.F_SEAL_SHRINK != 0: // Shrink sealed
+ f.dataMu.Unlock()
+ return syserror.EPERM
+ }
+
+ if oldSize != size {
+ f.attr.Size = size
+ // Update mtime and ctime.
+ now := ktime.NowFromContext(ctx)
+ f.attr.ModificationTime = now
+ f.attr.StatusChangeTime = now
+ }
+ f.dataMu.Unlock()
+
+ // Nothing left to do unless shrinking the file.
+ if oldSize <= size {
+ return nil
+ }
+
+ oldpgend := fs.OffsetPageEnd(oldSize)
+ newpgend := fs.OffsetPageEnd(size)
+
+ // Invalidate past translations of truncated pages.
+ if newpgend != oldpgend {
+ f.mapsMu.Lock()
+ f.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/shmem.c:shmem_setattr() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ f.mapsMu.Unlock()
+ }
+
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them.
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+ f.data.Truncate(uint64(size), f.kernel.MemoryFile())
+
+ return nil
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (f *fileInodeOperations) Allocate(ctx context.Context, _ *fs.Inode, offset, length int64) error {
+ newSize := offset + length
+
+ f.attrMu.Lock()
+ defer f.attrMu.Unlock()
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+
+ if newSize <= f.attr.Size {
+ return nil
+ }
+
+ // Check if current seals allow growth.
+ if f.seals&linux.F_SEAL_GROW != 0 {
+ return syserror.EPERM
+ }
+
+ f.attr.Size = newSize
+
+ now := ktime.NowFromContext(ctx)
+ f.attr.ModificationTime = now
+ f.attr.StatusChangeTime = now
+
+ return nil
+}
+
+// AddLink implements fs.InodeOperations.AddLink.
+func (f *fileInodeOperations) AddLink() {
+ f.attrMu.Lock()
+ f.attr.Links++
+ f.attrMu.Unlock()
+}
+
+// DropLink implements fs.InodeOperations.DropLink.
+func (f *fileInodeOperations) DropLink() {
+ f.attrMu.Lock()
+ f.attr.Links--
+ f.attrMu.Unlock()
+}
+
+// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
+func (f *fileInodeOperations) NotifyStatusChange(ctx context.Context) {
+ f.attrMu.Lock()
+ f.attr.StatusChangeTime = ktime.NowFromContext(ctx)
+ f.attrMu.Unlock()
+}
+
+// IsVirtual implements fs.InodeOperations.IsVirtual.
+func (*fileInodeOperations) IsVirtual() bool {
+ return true
+}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (*fileInodeOperations) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
+
+func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ var start time.Time
+ if fs.RecordWaitTime {
+ start = time.Now()
+ }
+ reads.Increment()
+ // Zero length reads for tmpfs are no-ops.
+ if dst.NumBytes() == 0 {
+ fs.IncrementWait(readWait, start)
+ return 0, nil
+ }
+
+ // Have we reached EOF? We check for this again in
+ // fileReadWriter.ReadToBlocks to avoid holding f.attrMu (which would
+ // serialize reads) or f.dataMu (which would violate lock ordering), but
+ // check here first (before calling into MM) since reading at EOF is
+ // common: getting a return value of 0 from a read syscall is the only way
+ // to detect EOF.
+ //
+ // TODO(jamieliu): Separate out f.attr.Size and use atomics instead of
+ // f.dataMu.
+ f.dataMu.RLock()
+ size := f.attr.Size
+ f.dataMu.RUnlock()
+ if offset >= size {
+ fs.IncrementWait(readWait, start)
+ return 0, io.EOF
+ }
+
+ n, err := dst.CopyOutFrom(ctx, &fileReadWriter{f, offset})
+ if !file.Dirent.Inode.MountSource.Flags.NoAtime {
+ // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed().
+ f.attrMu.Lock()
+ f.attr.AccessTime = ktime.NowFromContext(ctx)
+ f.attrMu.Unlock()
+ }
+ fs.IncrementWait(readWait, start)
+ return n, err
+}
+
+func (f *fileInodeOperations) write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // Zero length writes for tmpfs are no-ops.
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ f.attrMu.Lock()
+ defer f.attrMu.Unlock()
+ // Compare Linux's mm/filemap.c:__generic_file_write_iter() => file_update_time().
+ now := ktime.NowFromContext(ctx)
+ f.attr.ModificationTime = now
+ f.attr.StatusChangeTime = now
+ return src.CopyInTo(ctx, &fileReadWriter{f, offset})
+}
+
+type fileReadWriter struct {
+ f *fileInodeOperations
+ offset int64
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ rw.f.dataMu.RLock()
+ defer rw.f.dataMu.RUnlock()
+
+ // Compute the range to read.
+ if rw.offset >= rw.f.attr.Size {
+ return 0, io.EOF
+ }
+ end := fs.ReadEndOffset(rw.offset, int64(dsts.NumBytes()), rw.f.attr.Size)
+ if end == rw.offset { // dsts.NumBytes() == 0?
+ return 0, nil
+ }
+
+ mf := rw.f.kernel.MemoryFile()
+ var done uint64
+ seg, gap := rw.f.data.Find(uint64(rw.offset))
+ for rw.offset < end {
+ mr := memmap.MappableRange{uint64(rw.offset), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.offset += int64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Tmpfs holes are zero-filled.
+ gapmr := gap.Range().Intersect(mr)
+ dst := dsts.TakeFirst64(gapmr.Length())
+ n, err := safemem.ZeroSeq(dst)
+ done += n
+ rw.offset += int64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+
+ default:
+ break
+ }
+ }
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ rw.f.dataMu.Lock()
+ defer rw.f.dataMu.Unlock()
+
+ // Compute the range to write.
+ end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes()))
+ if end == rw.offset { // srcs.NumBytes() == 0?
+ return 0, nil
+ }
+
+ // Check if seals prevent either file growth or all writes.
+ switch {
+ case rw.f.seals&linux.F_SEAL_WRITE != 0: // Write sealed
+ return 0, syserror.EPERM
+ case end > rw.f.attr.Size && rw.f.seals&linux.F_SEAL_GROW != 0: // Grow sealed
+ // When growth is sealed, Linux effectively allows writes which would
+ // normally grow the file to partially succeed up to the current EOF,
+ // rounded down to the page boundary before the EOF.
+ //
+ // This happens because writes (and thus the growth check) for tmpfs
+ // files proceed page-by-page on Linux, and the final write to the page
+ // containing EOF fails, resulting in a partial write up to the start of
+ // that page.
+ //
+ // To emulate this behaviour, artifically truncate the write to the
+ // start of the page containing the current EOF.
+ //
+ // See Linux, mm/filemap.c:generic_perform_write() and
+ // mm/shmem.c:shmem_write_begin().
+ if pgstart := int64(usermem.Addr(rw.f.attr.Size).RoundDown()); end > pgstart {
+ end = pgstart
+ }
+ if end <= rw.offset {
+ // Truncation would result in no data being written.
+ return 0, syserror.EPERM
+ }
+ }
+
+ defer func() {
+ // If the write ends beyond the file's previous size, it causes the
+ // file to grow.
+ if rw.offset > rw.f.attr.Size {
+ rw.f.attr.Size = rw.offset
+ }
+ }()
+
+ mf := rw.f.kernel.MemoryFile()
+ // Page-aligned mr for when we need to allocate memory. RoundUp can't
+ // overflow since end is an int64.
+ pgstartaddr := usermem.Addr(rw.offset).RoundDown()
+ pgendaddr, _ := usermem.Addr(end).RoundUp()
+ pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)}
+
+ var done uint64
+ seg, gap := rw.f.data.Find(uint64(rw.offset))
+ for rw.offset < end {
+ mr := memmap.MappableRange{uint64(rw.offset), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write)
+ if err != nil {
+ return done, err
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.offset += int64(n)
+ srcs = srcs.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Allocate memory for the write.
+ gapMR := gap.Range().Intersect(pgMR)
+ fr, err := mf.Allocate(gapMR.Length(), rw.f.memUsage)
+ if err != nil {
+ return done, err
+ }
+
+ // Write to that memory as usual.
+ seg, gap = rw.f.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
+
+ default:
+ break
+ }
+ }
+ return done, nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+
+ f.dataMu.RLock()
+ defer f.dataMu.RUnlock()
+
+ // Reject writable mapping if F_SEAL_WRITE is set.
+ if f.seals&linux.F_SEAL_WRITE != 0 && writable {
+ return syserror.EPERM
+ }
+
+ f.mappings.AddMapping(ms, ar, offset, writable)
+ if writable {
+ pagesBefore := f.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ f.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+
+ if f.writableMappingPages < pagesBefore {
+ panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, f.writableMappingPages))
+ }
+ }
+
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+
+ f.mappings.RemoveMapping(ms, ar, offset, writable)
+
+ if writable {
+ pagesBefore := f.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ f.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+
+ if f.writableMappingPages > pagesBefore {
+ panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, f.writableMappingPages))
+ }
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (f *fileInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return f.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (f *fileInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+
+ // Constrain translations to f.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(f.attr.Size)
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ mf := f.kernel.MemoryFile()
+ cerr := f.data.Fill(ctx, required, optional, mf, f.memUsage, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ // Newly-allocated pages are zeroed, so we don't need to do anything.
+ return dsts.NumBytes(), nil
+ })
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := f.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: mf,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: usermem.AnyAccess,
+ })
+ translatedEnd = segMR.End
+ }
+
+ // Don't return the error returned by f.data.Fill if it occurred outside of
+ // required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (f *fileInodeOperations) InvalidateUnsavable(ctx context.Context) error {
+ return nil
+}
+
+// GetSeals returns the current set of seals on a memfd inode.
+func GetSeals(inode *fs.Inode) (uint32, error) {
+ if f, ok := inode.InodeOperations.(*fileInodeOperations); ok {
+ f.dataMu.RLock()
+ defer f.dataMu.RUnlock()
+ return f.seals, nil
+ }
+ // Not a memfd inode.
+ return 0, syserror.EINVAL
+}
+
+// AddSeals adds new file seals to a memfd inode.
+func AddSeals(inode *fs.Inode, val uint32) error {
+ if f, ok := inode.InodeOperations.(*fileInodeOperations); ok {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+ f.dataMu.Lock()
+ defer f.dataMu.Unlock()
+
+ if f.seals&linux.F_SEAL_SEAL != 0 {
+ // Seal applied which prevents addition of any new seals.
+ return syserror.EPERM
+ }
+
+ // F_SEAL_WRITE can only be added if there are no active writable maps.
+ if f.seals&linux.F_SEAL_WRITE == 0 && val&linux.F_SEAL_WRITE != 0 {
+ if f.writableMappingPages > 0 {
+ return syserror.EBUSY
+ }
+ }
+
+ // Seals can only be added, never removed.
+ f.seals |= val
+ return nil
+ }
+ // Not a memfd inode.
+ return syserror.EINVAL
+}
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
new file mode 100644
index 000000000..263d10cfe
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -0,0 +1,348 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tmpfs is a filesystem implementation backed by memory.
+package tmpfs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+var fsInfo = fs.Info{
+ Type: linux.TMPFS_MAGIC,
+
+ // TODO(b/29637826): allow configuring a tmpfs size and enforce it.
+ TotalBlocks: 0,
+ FreeBlocks: 0,
+}
+
+// rename implements fs.InodeOperations.Rename for tmpfs nodes.
+func rename(ctx context.Context, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ op, ok := oldParent.InodeOperations.(*Dir)
+ if !ok {
+ return syserror.EXDEV
+ }
+ np, ok := newParent.InodeOperations.(*Dir)
+ if !ok {
+ return syserror.EXDEV
+ }
+ return ramfs.Rename(ctx, op.ramfsDir, oldName, np.ramfsDir, newName, replacement)
+}
+
+// Dir is a directory.
+//
+// +stateify savable
+type Dir struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeIsDirTruncate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ // Ideally this would be embedded, so that we "inherit" all of the
+ // InodeOperations implemented by ramfs.Dir for free.
+ //
+ // However, ramfs.dirFileOperations stores a pointer to a ramfs.Dir,
+ // and our save/restore package does not allow saving a pointer to an
+ // embedded field elsewhere.
+ //
+ // Thus, we must make the ramfs.Dir is a field, and we delegate all the
+ // InodeOperation methods to it.
+ ramfsDir *ramfs.Dir
+
+ // kernel is used to allocate memory as storage for tmpfs Files.
+ kernel *kernel.Kernel
+}
+
+var _ fs.InodeOperations = (*Dir)(nil)
+
+// NewDir returns a new directory.
+func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
+ d := &Dir{
+ ramfsDir: ramfs.NewDir(ctx, contents, owner, perms),
+ kernel: kernel.KernelFromContext(ctx),
+ }
+
+ // Manually set the CreateOps.
+ d.ramfsDir.CreateOps = d.newCreateOps()
+
+ return fs.NewInode(d, msrc, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
+
+// afterLoad is invoked by stateify.
+func (d *Dir) afterLoad() {
+ // Per NewDir, manually set the CreateOps.
+ d.ramfsDir.CreateOps = d.newCreateOps()
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (d *Dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return d.ramfsDir.GetFile(ctx, dirent, flags)
+}
+
+// AddLink implements fs.InodeOperations.AddLink.
+func (d *Dir) AddLink() {
+ d.ramfsDir.AddLink()
+}
+
+// DropLink implements fs.InodeOperations.DropLink.
+func (d *Dir) DropLink() {
+ d.ramfsDir.DropLink()
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (d *Dir) Bind(ctx context.Context, dir *fs.Inode, name string, ep transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Dirent, error) {
+ return d.ramfsDir.Bind(ctx, dir, name, ep, perms)
+}
+
+// Create implements fs.InodeOperations.Create.
+func (d *Dir) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perms fs.FilePermissions) (*fs.File, error) {
+ return d.ramfsDir.Create(ctx, dir, name, flags, perms)
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+func (d *Dir) CreateLink(ctx context.Context, dir *fs.Inode, oldname, newname string) error {
+ return d.ramfsDir.CreateLink(ctx, dir, oldname, newname)
+}
+
+// CreateHardLink implements fs.InodeOperations.CreateHardLink.
+func (d *Dir) CreateHardLink(ctx context.Context, dir *fs.Inode, target *fs.Inode, name string) error {
+ return d.ramfsDir.CreateHardLink(ctx, dir, target, name)
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+func (d *Dir) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perms fs.FilePermissions) error {
+ return d.ramfsDir.CreateDirectory(ctx, dir, name, perms)
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (d *Dir) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perms fs.FilePermissions) error {
+ return d.ramfsDir.CreateFifo(ctx, dir, name, perms)
+}
+
+// Getxattr implements fs.InodeOperations.Getxattr.
+func (d *Dir) Getxattr(i *fs.Inode, name string) (string, error) {
+ return d.ramfsDir.Getxattr(i, name)
+}
+
+// Setxattr implements fs.InodeOperations.Setxattr.
+func (d *Dir) Setxattr(i *fs.Inode, name, value string) error {
+ return d.ramfsDir.Setxattr(i, name, value)
+}
+
+// Listxattr implements fs.InodeOperations.Listxattr.
+func (d *Dir) Listxattr(i *fs.Inode) (map[string]struct{}, error) {
+ return d.ramfsDir.Listxattr(i)
+}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (d *Dir) Lookup(ctx context.Context, i *fs.Inode, p string) (*fs.Dirent, error) {
+ return d.ramfsDir.Lookup(ctx, i, p)
+}
+
+// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
+func (d *Dir) NotifyStatusChange(ctx context.Context) {
+ d.ramfsDir.NotifyStatusChange(ctx)
+}
+
+// Remove implements fs.InodeOperations.Remove.
+func (d *Dir) Remove(ctx context.Context, i *fs.Inode, name string) error {
+ return d.ramfsDir.Remove(ctx, i, name)
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+func (d *Dir) RemoveDirectory(ctx context.Context, i *fs.Inode, name string) error {
+ return d.ramfsDir.RemoveDirectory(ctx, i, name)
+}
+
+// UnstableAttr implements fs.InodeOperations.UnstableAttr.
+func (d *Dir) UnstableAttr(ctx context.Context, i *fs.Inode) (fs.UnstableAttr, error) {
+ return d.ramfsDir.UnstableAttr(ctx, i)
+}
+
+// SetPermissions implements fs.InodeOperations.SetPermissions.
+func (d *Dir) SetPermissions(ctx context.Context, i *fs.Inode, p fs.FilePermissions) bool {
+ return d.ramfsDir.SetPermissions(ctx, i, p)
+}
+
+// SetOwner implements fs.InodeOperations.SetOwner.
+func (d *Dir) SetOwner(ctx context.Context, i *fs.Inode, owner fs.FileOwner) error {
+ return d.ramfsDir.SetOwner(ctx, i, owner)
+}
+
+// SetTimestamps implements fs.InodeOperations.SetTimestamps.
+func (d *Dir) SetTimestamps(ctx context.Context, i *fs.Inode, ts fs.TimeSpec) error {
+ return d.ramfsDir.SetTimestamps(ctx, i, ts)
+}
+
+// newCreateOps builds the custom CreateOps for this Dir.
+func (d *Dir) newCreateOps() *ramfs.CreateOps {
+ return &ramfs.CreateOps{
+ NewDir: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
+ return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil
+ },
+ NewFile: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
+ uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{
+ Owner: fs.FileOwnerFromContext(ctx),
+ Perms: perms,
+ // Always start unlinked.
+ Links: 0,
+ })
+ iops := NewInMemoryFile(ctx, usage.Tmpfs, uattr)
+ return fs.NewInode(iops, dir.MountSource, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.RegularFile,
+ }), nil
+ },
+ NewSymlink: func(ctx context.Context, dir *fs.Inode, target string) (*fs.Inode, error) {
+ return NewSymlink(ctx, target, fs.FileOwnerFromContext(ctx), dir.MountSource), nil
+ },
+ NewBoundEndpoint: func(ctx context.Context, dir *fs.Inode, socket transport.BoundEndpoint, perms fs.FilePermissions) (*fs.Inode, error) {
+ return NewSocket(ctx, socket, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil
+ },
+ NewFifo: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
+ return NewFifo(ctx, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil
+ },
+ }
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (d *Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS implements fs.InodeOperations.StatFS.
+func (*Dir) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
+
+// Allocate implements fs.InodeOperations.Allocate.
+func (d *Dir) Allocate(ctx context.Context, node *fs.Inode, offset, length int64) error {
+ return d.ramfsDir.Allocate(ctx, node, offset, length)
+}
+
+// Symlink is a symlink.
+//
+// +stateify savable
+type Symlink struct {
+ ramfs.Symlink
+}
+
+// NewSymlink returns a new symlink with the provided permissions.
+func NewSymlink(ctx context.Context, target string, owner fs.FileOwner, msrc *fs.MountSource) *fs.Inode {
+ s := &Symlink{Symlink: *ramfs.NewSymlink(ctx, owner, target)}
+ return fs.NewInode(s, msrc, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Symlink,
+ })
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (s *Symlink) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS returns the tmpfs info.
+func (s *Symlink) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
+
+// Socket is a socket.
+//
+// +stateify savable
+type Socket struct {
+ ramfs.Socket
+ fsutil.InodeNotTruncatable `state:"nosave"`
+ fsutil.InodeNotAllocatable `state:"nosave"`
+}
+
+// NewSocket returns a new socket with the provided permissions.
+func NewSocket(ctx context.Context, socket transport.BoundEndpoint, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
+ s := &Socket{Socket: *ramfs.NewSocket(ctx, socket, owner, perms)}
+ return fs.NewInode(s, msrc, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Socket,
+ })
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (s *Socket) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS returns the tmpfs info.
+func (s *Socket) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
+
+// Fifo is a tmpfs named pipe.
+//
+// +stateify savable
+type Fifo struct {
+ fs.InodeOperations
+}
+
+// NewFifo creates a new named pipe.
+func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
+ // First create a pipe.
+ p := pipe.NewPipe(ctx, true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+
+ // Build pipe InodeOperations.
+ iops := pipe.NewInodeOperations(ctx, perms, p)
+
+ // Wrap the iops with our Fifo.
+ fifoIops := &Fifo{iops}
+
+ // Build a new Inode.
+ return fs.NewInode(fifoIops, msrc, fs.StableAttr{
+ DeviceID: tmpfsDevice.DeviceID(),
+ InodeID: tmpfsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Pipe,
+ })
+}
+
+// Rename implements fs.InodeOperations.Rename.
+func (f *Fifo) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return rename(ctx, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS returns the tmpfs info.
+func (*Fifo) StatFS(context.Context) (fs.Info, error) {
+ return fsInfo, nil
+}
diff --git a/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go b/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go
new file mode 100755
index 000000000..0fe2e2e93
--- /dev/null
+++ b/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go
@@ -0,0 +1,108 @@
+// automatically generated by stateify.
+
+package tmpfs
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *regularFileOperations) beforeSave() {}
+func (x *regularFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("iops", &x.iops)
+}
+
+func (x *regularFileOperations) afterLoad() {}
+func (x *regularFileOperations) load(m state.Map) {
+ m.Load("iops", &x.iops)
+}
+
+func (x *Filesystem) beforeSave() {}
+func (x *Filesystem) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *Filesystem) afterLoad() {}
+func (x *Filesystem) load(m state.Map) {
+}
+
+func (x *fileInodeOperations) beforeSave() {}
+func (x *fileInodeOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Save("kernel", &x.kernel)
+ m.Save("memUsage", &x.memUsage)
+ m.Save("attr", &x.attr)
+ m.Save("mappings", &x.mappings)
+ m.Save("writableMappingPages", &x.writableMappingPages)
+ m.Save("data", &x.data)
+ m.Save("seals", &x.seals)
+}
+
+func (x *fileInodeOperations) afterLoad() {}
+func (x *fileInodeOperations) load(m state.Map) {
+ m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes)
+ m.Load("kernel", &x.kernel)
+ m.Load("memUsage", &x.memUsage)
+ m.Load("attr", &x.attr)
+ m.Load("mappings", &x.mappings)
+ m.Load("writableMappingPages", &x.writableMappingPages)
+ m.Load("data", &x.data)
+ m.Load("seals", &x.seals)
+}
+
+func (x *Dir) beforeSave() {}
+func (x *Dir) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ramfsDir", &x.ramfsDir)
+ m.Save("kernel", &x.kernel)
+}
+
+func (x *Dir) load(m state.Map) {
+ m.Load("ramfsDir", &x.ramfsDir)
+ m.Load("kernel", &x.kernel)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *Symlink) beforeSave() {}
+func (x *Symlink) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Symlink", &x.Symlink)
+}
+
+func (x *Symlink) afterLoad() {}
+func (x *Symlink) load(m state.Map) {
+ m.Load("Symlink", &x.Symlink)
+}
+
+func (x *Socket) beforeSave() {}
+func (x *Socket) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Socket", &x.Socket)
+}
+
+func (x *Socket) afterLoad() {}
+func (x *Socket) load(m state.Map) {
+ m.Load("Socket", &x.Socket)
+}
+
+func (x *Fifo) beforeSave() {}
+func (x *Fifo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeOperations", &x.InodeOperations)
+}
+
+func (x *Fifo) afterLoad() {}
+func (x *Fifo) load(m state.Map) {
+ m.Load("InodeOperations", &x.InodeOperations)
+}
+
+func init() {
+ state.Register("tmpfs.regularFileOperations", (*regularFileOperations)(nil), state.Fns{Save: (*regularFileOperations).save, Load: (*regularFileOperations).load})
+ state.Register("tmpfs.Filesystem", (*Filesystem)(nil), state.Fns{Save: (*Filesystem).save, Load: (*Filesystem).load})
+ state.Register("tmpfs.fileInodeOperations", (*fileInodeOperations)(nil), state.Fns{Save: (*fileInodeOperations).save, Load: (*fileInodeOperations).load})
+ state.Register("tmpfs.Dir", (*Dir)(nil), state.Fns{Save: (*Dir).save, Load: (*Dir).load})
+ state.Register("tmpfs.Symlink", (*Symlink)(nil), state.Fns{Save: (*Symlink).save, Load: (*Symlink).load})
+ state.Register("tmpfs.Socket", (*Socket)(nil), state.Fns{Save: (*Socket).save, Load: (*Socket).load})
+ state.Register("tmpfs.Fifo", (*Fifo)(nil), state.Fns{Save: (*Fifo).save, Load: (*Fifo).load})
+}
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
new file mode 100644
index 000000000..2603354c4
--- /dev/null
+++ b/pkg/sentry/fs/tty/dir.go
@@ -0,0 +1,339 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tty provide pseudoterminals via a devpts filesystem.
+package tty
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// dirInodeOperations is the root of a devpts mount.
+//
+// This indirectly manages all terminals within the mount.
+//
+// New Terminals are created by masterInodeOperations.GetFile, which registers
+// the slave Inode in the this directory for discovery via Lookup/Readdir. The
+// slave inode is unregistered when the master file is Released, as the slave
+// is no longer discoverable at that point.
+//
+// References on the underlying Terminal are held by masterFileOperations and
+// slaveInodeOperations.
+//
+// masterInodeOperations and slaveInodeOperations hold a pointer to
+// dirInodeOperations, which is reference counted by the refcount their
+// corresponding Dirents hold on their parent (this directory).
+//
+// dirInodeOperations implements fs.InodeOperations.
+//
+// +stateify savable
+type dirInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeIsDirAllocate `state:"nosave"`
+ fsutil.InodeIsDirTruncate `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotRenameable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+
+ // msrc is the super block this directory is on.
+ //
+ // TODO(chrisko): Plumb this through instead of storing it here.
+ msrc *fs.MountSource
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // master is the master PTY inode.
+ master *fs.Inode
+
+ // slaves contains the slave inodes reachable from the directory.
+ //
+ // A new slave is added by allocateTerminal and is removed by
+ // masterFileOperations.Release.
+ //
+ // A reference is held on every slave in the map.
+ slaves map[uint32]*fs.Inode
+
+ // dentryMap is a SortedDentryMap used to implement Readdir containing
+ // the master and all entries in slaves.
+ dentryMap *fs.SortedDentryMap
+
+ // next is the next pty index to use.
+ //
+ // TODO(b/29356795): reuse indices when ptys are closed.
+ next uint32
+}
+
+var _ fs.InodeOperations = (*dirInodeOperations)(nil)
+
+// newDir creates a new dir with a ptmx file and no terminals.
+func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode {
+ d := &dirInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0555), linux.DEVPTS_SUPER_MAGIC),
+ msrc: m,
+ slaves: make(map[uint32]*fs.Inode),
+ dentryMap: fs.NewSortedDentryMap(nil),
+ }
+ // Linux devpts uses a default mode of 0000 for ptmx which can be
+ // changed with the ptmxmode mount option. However, that default is not
+ // useful here (since we'd *always* need the mount option, so it is
+ // accessible by default).
+ d.master = newMasterInode(ctx, d, fs.RootOwner, fs.FilePermsFromMode(0666))
+ d.dentryMap.Add("ptmx", fs.DentAttr{
+ Type: d.master.StableAttr.Type,
+ InodeID: d.master.StableAttr.InodeID,
+ })
+
+ return fs.NewInode(d, m, fs.StableAttr{
+ DeviceID: ptsDevice.DeviceID(),
+ // N.B. Linux always uses inode id 1 for the directory. See
+ // fs/devpts/inode.c:devpts_fill_super.
+ //
+ // TODO(b/75267214): Since ptsDevice must be shared between
+ // different mounts, we must not assign fixed numbers.
+ InodeID: ptsDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.Directory,
+ })
+}
+
+// Release implements fs.InodeOperations.Release.
+func (d *dirInodeOperations) Release(ctx context.Context) {
+ d.master.DecRef()
+ if len(d.slaves) != 0 {
+ panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d))
+ }
+}
+
+// Lookup implements fs.InodeOperations.Lookup.
+func (d *dirInodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Master?
+ if name == "ptmx" {
+ d.master.IncRef()
+ return fs.NewDirent(d.master, name), nil
+ }
+
+ // Slave number?
+ n, err := strconv.ParseUint(name, 10, 32)
+ if err != nil {
+ // Not found.
+ return nil, syserror.ENOENT
+ }
+
+ s, ok := d.slaves[uint32(n)]
+ if !ok {
+ return nil, syserror.ENOENT
+ }
+
+ s.IncRef()
+ return fs.NewDirent(s, name), nil
+}
+
+// Create implements fs.InodeOperations.Create.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) {
+ return nil, syserror.EACCES
+}
+
+// CreateDirectory implements fs.InodeOperations.CreateDirectory.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ return syserror.EACCES
+}
+
+// CreateLink implements fs.InodeOperations.CreateLink.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname, newname string) error {
+ return syserror.EACCES
+}
+
+// CreateHardLink implements fs.InodeOperations.CreateHardLink.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) CreateHardLink(ctx context.Context, dir *fs.Inode, target *fs.Inode, name string) error {
+ return syserror.EACCES
+}
+
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+//
+// Creation is never allowed.
+func (d *dirInodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ return syserror.EACCES
+}
+
+// Remove implements fs.InodeOperations.Remove.
+//
+// Removal is never allowed.
+func (d *dirInodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error {
+ return syserror.EPERM
+}
+
+// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
+//
+// Removal is never allowed.
+func (d *dirInodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error {
+ return syserror.EPERM
+}
+
+// Bind implements fs.InodeOperations.Bind.
+func (d *dirInodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, data transport.BoundEndpoint, perm fs.FilePermissions) (*fs.Dirent, error) {
+ return nil, syserror.EPERM
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (d *dirInodeOperations) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &dirFileOperations{di: d}), nil
+}
+
+// allocateTerminal creates a new Terminal and installs a pts node for it.
+//
+// The caller must call DecRef when done with the returned Terminal.
+func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, error) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ n := d.next
+ if n == math.MaxUint32 {
+ return nil, syserror.ENOMEM
+ }
+
+ if _, ok := d.slaves[n]; ok {
+ panic(fmt.Sprintf("pty index collision; index %d already exists", n))
+ }
+
+ t := newTerminal(ctx, d, n)
+ d.next++
+
+ // The reference returned by newTerminal is returned to the caller.
+ // Take another for the slave inode.
+ t.IncRef()
+
+ // Create a pts node. The owner is based on the context that opens
+ // ptmx.
+ creds := auth.CredentialsFromContext(ctx)
+ uid, gid := creds.EffectiveKUID, creds.EffectiveKGID
+ slave := newSlaveInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666))
+
+ d.slaves[n] = slave
+ d.dentryMap.Add(strconv.FormatUint(uint64(n), 10), fs.DentAttr{
+ Type: slave.StableAttr.Type,
+ InodeID: slave.StableAttr.InodeID,
+ })
+
+ return t, nil
+}
+
+// masterClose is called when the master end of t is closed.
+func (d *dirInodeOperations) masterClose(t *Terminal) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // The slave end disappears from the directory when the master end is
+ // closed, even if the slave end is open elsewhere.
+ //
+ // N.B. since we're using a backdoor method to remove a directory entry
+ // we won't properly fire inotify events like Linux would.
+ s, ok := d.slaves[t.n]
+ if !ok {
+ panic(fmt.Sprintf("Terminal %+v doesn't exist in %+v?", t, d))
+ }
+
+ s.DecRef()
+ delete(d.slaves, t.n)
+ d.dentryMap.Remove(strconv.FormatUint(uint64(t.n), 10))
+}
+
+// dirFileOperations are the fs.FileOperations for the directory.
+//
+// This is nearly identical to fsutil.DirFileOperations, except that it takes
+// df.di.mu in IterateDir.
+//
+// +stateify savable
+type dirFileOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ // di is the inode operations.
+ di *dirInodeOperations
+
+ // dirCursor contains the name of the last directory entry that was
+ // serialized.
+ dirCursor string
+}
+
+var _ fs.FileOperations = (*dirFileOperations)(nil)
+
+// IterateDir implements DirIterator.IterateDir.
+func (df *dirFileOperations) IterateDir(ctx context.Context, dirCtx *fs.DirCtx, offset int) (int, error) {
+ df.di.mu.Lock()
+ defer df.di.mu.Unlock()
+
+ n, err := fs.GenericReaddir(dirCtx, df.di.dentryMap)
+ return offset + n, err
+}
+
+// Readdir implements FileOperations.Readdir.
+func (df *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
+ root := fs.RootFromContext(ctx)
+ if root != nil {
+ defer root.DecRef()
+ }
+ dirCtx := &fs.DirCtx{
+ Serializer: serializer,
+ DirCursor: &df.dirCursor,
+ }
+ return fs.DirentReaddir(ctx, file.Dirent, df, root, dirCtx, file.Offset())
+}
+
+// Read implements FileOperations.Read
+func (df *dirFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// Write implements FileOperations.Write.
+func (df *dirFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syserror.EISDIR
+}
diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go
new file mode 100644
index 000000000..701b2f7d9
--- /dev/null
+++ b/pkg/sentry/fs/tty/fs.go
@@ -0,0 +1,104 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// ptsDevice is the pseudo-filesystem device.
+var ptsDevice = device.NewAnonDevice()
+
+// filesystem is a devpts filesystem.
+//
+// This devpts is always in the new "multi-instance" mode. i.e., it contains a
+// ptmx device tied to this mount.
+//
+// +stateify savable
+type filesystem struct{}
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// Name matches drivers/devpts/indoe.c:devpts_fs_type.name.
+func (*filesystem) Name() string {
+ return "devpts"
+}
+
+// AllowUserMount allows users to mount(2) this file system.
+func (*filesystem) AllowUserMount() bool {
+ // TODO(b/29356795): Users may mount this once the terminals are in a
+ // usable state.
+ return false
+}
+
+// AllowUserList allows this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return true
+}
+
+// Flags returns that there is nothing special about this file system.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
+
+// MountSource returns a devpts root that can be positioned in the vfs.
+func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
+ // device is always ignored.
+
+ // No options are supported.
+ if data != "" {
+ return nil, syserror.EINVAL
+ }
+
+ return newDir(ctx, fs.NewMountSource(&superOperations{}, f, flags)), nil
+}
+
+// superOperations implements fs.MountSourceOperations, preventing caching.
+//
+// +stateify savable
+type superOperations struct{}
+
+// Revalidate implements fs.DirentOperations.Revalidate.
+//
+// It always returns true, forcing a Lookup for all entries.
+//
+// Slave entries are dropped from dir when their master is closed, so an
+// existing slave Dirent in the tree is not sufficient to guarantee that it
+// still exists on the filesystem.
+func (superOperations) Revalidate(context.Context, string, *fs.Inode, *fs.Inode) bool {
+ return true
+}
+
+// Keep implements fs.DirentOperations.Keep.
+//
+// Keep returns false because Revalidate would force a lookup on cached entries
+// anyways.
+func (superOperations) Keep(*fs.Dirent) bool {
+ return false
+}
+
+// ResetInodeMappings implements MountSourceOperations.ResetInodeMappings.
+func (superOperations) ResetInodeMappings() {}
+
+// SaveInodeMapping implements MountSourceOperations.SaveInodeMapping.
+func (superOperations) SaveInodeMapping(*fs.Inode, string) {}
+
+// Destroy implements MountSourceOperations.Destroy.
+func (superOperations) Destroy() {}
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
new file mode 100644
index 000000000..20d29d130
--- /dev/null
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -0,0 +1,443 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "bytes"
+ "sync"
+ "unicode/utf8"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // canonMaxBytes is the number of bytes that fit into a single line of
+ // terminal input in canonical mode. This corresponds to N_TTY_BUF_SIZE
+ // in include/linux/tty.h.
+ canonMaxBytes = 4096
+
+ // nonCanonMaxBytes is the maximum number of bytes that can be read at
+ // a time in noncanonical mode.
+ nonCanonMaxBytes = canonMaxBytes - 1
+
+ spacesPerTab = 8
+)
+
+// lineDiscipline dictates how input and output are handled between the
+// pseudoterminal (pty) master and slave. It can be configured to alter I/O,
+// modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man
+// pages are good resources for how to affect the line discipline:
+//
+// * termios(3)
+// * tty_ioctl(4)
+//
+// This file corresponds most closely to drivers/tty/n_tty.c.
+//
+// lineDiscipline has a simple structure but supports a multitude of options
+// (see the above man pages). It consists of two queues of bytes: one from the
+// terminal master to slave (the input queue) and one from slave to master (the
+// output queue). When bytes are written to one end of the pty, the line
+// discipline reads the bytes, modifies them or takes special action if
+// required, and enqueues them to be read by the other end of the pty:
+//
+// input from terminal +-------------+ input to process (e.g. bash)
+// +------------------------>| input queue |---------------------------+
+// | (inputQueueWrite) +-------------+ (inputQueueRead) |
+// | |
+// | v
+// masterFD slaveFD
+// ^ |
+// | |
+// | output to terminal +--------------+ output from process |
+// +------------------------| output queue |<--------------------------+
+// (outputQueueRead) +--------------+ (outputQueueWrite)
+//
+// Lock order:
+// termiosMu
+// inQueue.mu
+// outQueue.mu
+//
+// +stateify savable
+type lineDiscipline struct {
+ // sizeMu protects size.
+ sizeMu sync.Mutex `state:"nosave"`
+
+ // size is the terminal size (width and height).
+ size linux.WindowSize
+
+ // inQueue is the input queue of the terminal.
+ inQueue queue
+
+ // outQueue is the output queue of the terminal.
+ outQueue queue
+
+ // termiosMu protects termios.
+ termiosMu sync.RWMutex `state:"nosave"`
+
+ // termios is the terminal configuration used by the lineDiscipline.
+ termios linux.KernelTermios
+
+ // column is the location in a row of the cursor. This is important for
+ // handling certain special characters like backspace.
+ column int
+
+ // masterWaiter is used to wait on the master end of the TTY.
+ masterWaiter waiter.Queue `state:"zerovalue"`
+
+ // slaveWaiter is used to wait on the slave end of the TTY.
+ slaveWaiter waiter.Queue `state:"zerovalue"`
+}
+
+func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
+ ld := lineDiscipline{termios: termios}
+ ld.inQueue.transformer = &inputQueueTransformer{}
+ ld.outQueue.transformer = &outputQueueTransformer{}
+ return &ld
+}
+
+// getTermios gets the linux.Termios for the tty.
+func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ // We must copy a Termios struct, not KernelTermios.
+ t := l.termios.ToTermios()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// setTermios sets a linux.Termios for the tty.
+func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.Lock()
+ defer l.termiosMu.Unlock()
+ oldCanonEnabled := l.termios.LEnabled(linux.ICANON)
+ // We must copy a Termios struct, not KernelTermios.
+ var t linux.Termios
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ l.termios.FromTermios(t)
+
+ // If canonical mode is turned off, move bytes from inQueue's wait
+ // buffer to its read buffer. Anything already in the read buffer is
+ // now readable.
+ if oldCanonEnabled && !l.termios.LEnabled(linux.ICANON) {
+ l.inQueue.pushWaitBuf(l)
+ l.inQueue.readable = true
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+
+ return 0, err
+}
+
+func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) masterReadiness() waiter.EventMask {
+ // We don't have to lock a termios because the default master termios
+ // is immutable.
+ return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios)
+}
+
+func (l *lineDiscipline) slaveReadiness() waiter.EventMask {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios)
+}
+
+func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.inQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.inQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.inQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.outQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.outQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.outQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// transformer is a helper interface to make it easier to stateify queue.
+type transformer interface {
+ // transform functions require queue's mutex to be held.
+ transform(*lineDiscipline, *queue, []byte) int
+}
+
+// outputQueueTransformer implements transformer. It performs line discipline
+// transformations on the output queue.
+//
+// +stateify savable
+type outputQueueTransformer struct{}
+
+// transform does output processing for one end of the pty. See
+// drivers/tty/n_tty.c:do_output_char for an analogous kernel function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*outputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // transformOutput is effectively always in noncanonical mode, as the
+ // master termios never has ICANON set.
+
+ if !l.termios.OEnabled(linux.OPOST) {
+ q.readBuf = append(q.readBuf, buf...)
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return len(buf)
+ }
+
+ var ret int
+ for len(buf) > 0 {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ ret += size
+ buf = buf[size:]
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\n':
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ if l.termios.OEnabled(linux.ONLCR) {
+ q.readBuf = append(q.readBuf, '\r', '\n')
+ continue
+ }
+ case '\r':
+ if l.termios.OEnabled(linux.ONOCR) && l.column == 0 {
+ continue
+ }
+ if l.termios.OEnabled(linux.OCRNL) {
+ cBytes[0] = '\n'
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ break
+ }
+ l.column = 0
+ case '\t':
+ spaces := spacesPerTab - l.column%spacesPerTab
+ if l.termios.OutputFlags&linux.TABDLY == linux.XTABS {
+ l.column += spaces
+ q.readBuf = append(q.readBuf, bytes.Repeat([]byte{' '}, spacesPerTab)...)
+ continue
+ }
+ l.column += spaces
+ case '\b':
+ if l.column > 0 {
+ l.column--
+ }
+ default:
+ l.column++
+ }
+ q.readBuf = append(q.readBuf, cBytes...)
+ }
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return ret
+}
+
+// inputQueueTransformer implements transformer. It performs line discipline
+// transformations on the input queue.
+//
+// +stateify savable
+type inputQueueTransformer struct{}
+
+// transform does input processing for one end of the pty. Characters read are
+// transformed according to flags set in the termios struct. See
+// drivers/tty/n_tty.c:n_tty_receive_char_special for an analogous kernel
+// function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*inputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // If there's a line waiting to be read in canonical mode, don't write
+ // anything else to the read buffer.
+ if l.termios.LEnabled(linux.ICANON) && q.readable {
+ return 0
+ }
+
+ maxBytes := nonCanonMaxBytes
+ if l.termios.LEnabled(linux.ICANON) {
+ maxBytes = canonMaxBytes
+ }
+
+ var ret int
+ for len(buf) > 0 && len(q.readBuf) < canonMaxBytes {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\r':
+ if l.termios.IEnabled(linux.IGNCR) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+ if l.termios.IEnabled(linux.ICRNL) {
+ cBytes[0] = '\n'
+ }
+ case '\n':
+ if l.termios.IEnabled(linux.INLCR) {
+ cBytes[0] = '\r'
+ }
+ }
+
+ // In canonical mode, we discard non-terminating characters
+ // after the first 4095.
+ if l.shouldDiscard(q, cBytes) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+
+ // Stop if the buffer would be overfilled.
+ if len(q.readBuf)+size > maxBytes {
+ break
+ }
+ buf = buf[size:]
+ ret += size
+
+ // If we get EOF, make the buffer available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsEOF(cBytes[0]) {
+ q.readable = true
+ break
+ }
+
+ q.readBuf = append(q.readBuf, cBytes...)
+
+ // Anything written to the readBuf will have to be echoed.
+ if l.termios.LEnabled(linux.ECHO) {
+ l.outQueue.writeBytes(cBytes, l)
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+
+ // If we finish a line, make it available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsTerminating(cBytes) {
+ q.readable = true
+ break
+ }
+ }
+
+ // In noncanonical mode, everything is readable.
+ if !l.termios.LEnabled(linux.ICANON) && len(q.readBuf) > 0 {
+ q.readable = true
+ }
+
+ return ret
+}
+
+// shouldDiscard returns whether c should be discarded. In canonical mode, if
+// too many bytes are enqueued, we keep reading input and discarding it until
+// we find a terminating character. Signal/echo processing still occurs.
+//
+// Precondition:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (l *lineDiscipline) shouldDiscard(q *queue, cBytes []byte) bool {
+ return l.termios.LEnabled(linux.ICANON) && len(q.readBuf)+len(cBytes) >= canonMaxBytes && !l.termios.IsTerminating(cBytes)
+}
+
+// peek returns the size in bytes of the next character to process. As long as
+// b isn't empty, peek returns a value of at least 1.
+func (l *lineDiscipline) peek(b []byte) int {
+ size := 1
+ // If UTF-8 support is enabled, runes might be multiple bytes.
+ if l.termios.IEnabled(linux.IUTF8) {
+ _, size = utf8.DecodeRune(b)
+ }
+ return size
+}
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
new file mode 100644
index 000000000..afdf44cd1
--- /dev/null
+++ b/pkg/sentry/fs/tty/master.go
@@ -0,0 +1,220 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// masterInodeOperations are the fs.InodeOperations for the master end of the
+// Terminal (ptmx file).
+//
+// +stateify savable
+type masterInodeOperations struct {
+ fsutil.SimpleFileInode
+
+ // d is the containing dir.
+ d *dirInodeOperations
+}
+
+var _ fs.InodeOperations = (*masterInodeOperations)(nil)
+
+// newMasterInode creates an Inode for the master end of a terminal.
+func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode {
+ iops := &masterInodeOperations{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, owner, p, linux.DEVPTS_SUPER_MAGIC),
+ d: d,
+ }
+
+ return fs.NewInode(iops, d.msrc, fs.StableAttr{
+ DeviceID: ptsDevice.DeviceID(),
+ // N.B. Linux always uses inode id 2 for ptmx. See
+ // fs/devpts/inode.c:mknod_ptmx.
+ //
+ // TODO(b/75267214): Since ptsDevice must be shared between
+ // different mounts, we must not assign fixed numbers.
+ InodeID: ptsDevice.NextIno(),
+ Type: fs.CharacterDevice,
+ // See fs/devpts/inode.c:devpts_fill_super.
+ BlockSize: 1024,
+ // The PTY master effectively has two different major/minor
+ // device numbers.
+ //
+ // This one is returned by stat for both opened and unopened
+ // instances of this inode.
+ //
+ // When the inode is opened (GetFile), a new device number is
+ // allocated based on major UNIX98_PTY_MASTER_MAJOR and the tty
+ // index as minor number. However, this device number is only
+ // accessible via ioctl(TIOCGDEV) and /proc/TID/stat.
+ DeviceFileMajor: linux.TTYAUX_MAJOR,
+ DeviceFileMinor: linux.PTMX_MINOR,
+ })
+}
+
+// Release implements fs.InodeOperations.Release.
+func (mi *masterInodeOperations) Release(ctx context.Context) {
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+//
+// It allocates a new terminal.
+func (mi *masterInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ t, err := mi.d.allocateTerminal(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return fs.NewFile(ctx, d, flags, &masterFileOperations{
+ d: mi.d,
+ t: t,
+ }), nil
+}
+
+// masterFileOperations are the fs.FileOperations for the master end of a terminal.
+//
+// +stateify savable
+type masterFileOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // d is the containing dir.
+ d *dirInodeOperations
+
+ // t is the connected Terminal.
+ t *Terminal
+}
+
+var _ fs.FileOperations = (*masterFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (mf *masterFileOperations) Release() {
+ mf.d.masterClose(mf.t)
+ mf.t.DecRef()
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (mf *masterFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ mf.t.ld.masterWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (mf *masterFileOperations) EventUnregister(e *waiter.Entry) {
+ mf.t.ld.masterWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (mf *masterFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return mf.t.ld.masterReadiness()
+}
+
+// Read implements fs.FileOperations.Read.
+func (mf *masterFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ return mf.t.ld.outputQueueRead(ctx, dst)
+}
+
+// Write implements fs.FileOperations.Write.
+func (mf *masterFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ return mf.t.ld.inputQueueWrite(ctx, src)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (mf *masterFileOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the output queue read buffer.
+ return 0, mf.t.ld.outputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ // N.B. TCGETS on the master actually returns the configuration
+ // of the slave end.
+ return mf.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ // N.B. TCSETS on the master actually affects the configuration
+ // of the slave end.
+ return mf.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return mf.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mf.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCSPTLCK:
+ // TODO(b/29356795): Implement pty locking. For now just pretend we do.
+ return 0, nil
+ case linux.TIOCGWINSZ:
+ return 0, mf.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, mf.t.ld.setWindowSize(ctx, io, args)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
+
+// maybeEmitUnimplementedEvent emits unimplemented event if cmd is valid.
+func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
+ switch cmd {
+ case linux.TCGETS,
+ linux.TCSETS,
+ linux.TCSETSW,
+ linux.TCSETSF,
+ linux.TIOCGPGRP,
+ linux.TIOCSPGRP,
+ linux.TIOCGWINSZ,
+ linux.TIOCSWINSZ,
+ linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCNOTTY,
+ linux.TIOCSCTTY,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ }
+}
diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go
new file mode 100644
index 000000000..11fb92be3
--- /dev/null
+++ b/pkg/sentry/fs/tty/queue.go
@@ -0,0 +1,244 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// waitBufMaxBytes is the maximum size of a wait buffer. It is based on
+// TTYB_DEFAULT_MEM_LIMIT.
+const waitBufMaxBytes = 131072
+
+// queue represents one of the input or output queues between a pty master and
+// slave. Bytes written to a queue are added to the read buffer until it is
+// full, at which point they are written to the wait buffer. Bytes are
+// processed (i.e. undergo termios transformations) as they are added to the
+// read buffer. The read buffer is readable when its length is nonzero and
+// readable is true.
+//
+// +stateify savable
+type queue struct {
+ // mu protects everything in queue.
+ mu sync.Mutex `state:"nosave"`
+
+ // readBuf is buffer of data ready to be read when readable is true.
+ // This data has been processed.
+ readBuf []byte
+
+ // waitBuf contains data that can't fit into readBuf. It is put here
+ // until it can be loaded into the read buffer. waitBuf contains data
+ // that hasn't been processed.
+ waitBuf [][]byte
+ waitBufLen uint64
+
+ // readable indicates whether the read buffer can be read from. In
+ // canonical mode, there can be an unterminated line in the read buffer,
+ // so readable must be checked.
+ readable bool
+
+ // transform is the the queue's function for transforming bytes
+ // entering the queue. For example, transform might convert all '\r's
+ // entering the queue to '\n's.
+ transformer
+}
+
+// readReadiness returns whether q is ready to be read from.
+func (q *queue) readReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if len(q.readBuf) > 0 && q.readable {
+ return waiter.EventIn
+ }
+ return waiter.EventMask(0)
+}
+
+// writeReadiness returns whether q is ready to be written to.
+func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if q.waitBufLen < waitBufMaxBytes {
+ return waiter.EventOut
+ }
+ return waiter.EventMask(0)
+}
+
+// readableSize writes the number of readable bytes to userspace.
+func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ var size int32
+ if q.readable {
+ size = int32(len(q.readBuf))
+ }
+
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+
+}
+
+// read reads from q to userspace. It returns the number of bytes read as well
+// as whether the read caused more readable data to become available (whether
+// data was pushed from the wait buffer to the read buffer).
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipline) (int64, bool, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if !q.readable {
+ return 0, false, syserror.ErrWouldBlock
+ }
+
+ if dst.NumBytes() > canonMaxBytes {
+ dst = dst.TakeFirst(canonMaxBytes)
+ }
+
+ n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dst safemem.BlockSeq) (uint64, error) {
+ src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(q.readBuf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.readBuf = q.readBuf[n:]
+
+ // If we read everything, this queue is no longer readable.
+ if len(q.readBuf) == 0 {
+ q.readable = false
+ }
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, false, err
+ }
+
+ // Move data from the queue's wait buffer to its read buffer.
+ nPushed := q.pushWaitBufLocked(l)
+
+ return int64(n), nPushed > 0, nil
+}
+
+// write writes to q from userspace.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscipline) (int64, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Copy data into the wait buffer.
+ n, err := src.CopyInTo(ctx, safemem.WriterFunc(func(src safemem.BlockSeq) (uint64, error) {
+ copyLen := src.NumBytes()
+ room := waitBufMaxBytes - q.waitBufLen
+ // If out of room, return EAGAIN.
+ if room == 0 && copyLen > 0 {
+ return 0, syserror.ErrWouldBlock
+ }
+ // Cap the size of the wait buffer.
+ if copyLen > room {
+ copyLen = room
+ src = src.TakeFirst64(room)
+ }
+ buf := make([]byte, copyLen)
+
+ // Copy the data into the wait buffer.
+ dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.waitBufAppend(buf)
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, err
+ }
+
+ // Push data from the wait to the read buffer.
+ q.pushWaitBufLocked(l)
+
+ return n, nil
+}
+
+// writeBytes writes to q from b.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) writeBytes(b []byte, l *lineDiscipline) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Write to the wait buffer.
+ q.waitBufAppend(b)
+ q.pushWaitBufLocked(l)
+}
+
+// pushWaitBuf fills the queue's read buffer with data from the wait buffer.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) pushWaitBuf(l *lineDiscipline) int {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ return q.pushWaitBufLocked(l)
+}
+
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be locked.
+func (q *queue) pushWaitBufLocked(l *lineDiscipline) int {
+ if q.waitBufLen == 0 {
+ return 0
+ }
+
+ // Move data from the wait to the read buffer.
+ var total int
+ var i int
+ for i = 0; i < len(q.waitBuf); i++ {
+ n := q.transform(l, q, q.waitBuf[i])
+ total += n
+ if n != len(q.waitBuf[i]) {
+ // The read buffer filled up without consuming the
+ // entire buffer.
+ q.waitBuf[i] = q.waitBuf[i][n:]
+ break
+ }
+ }
+
+ // Update wait buffer based on consumed data.
+ q.waitBuf = q.waitBuf[i:]
+ q.waitBufLen -= uint64(total)
+
+ return total
+}
+
+// Precondition: q.mu must be locked.
+func (q *queue) waitBufAppend(b []byte) {
+ q.waitBuf = append(q.waitBuf, b)
+ q.waitBufLen += uint64(len(b))
+}
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
new file mode 100644
index 000000000..2abf32e57
--- /dev/null
+++ b/pkg/sentry/fs/tty/slave.go
@@ -0,0 +1,162 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// slaveInodeOperations are the fs.InodeOperations for the slave end of the
+// Terminal (pts file).
+//
+// +stateify savable
+type slaveInodeOperations struct {
+ fsutil.SimpleFileInode
+
+ // d is the containing dir.
+ d *dirInodeOperations
+
+ // t is the connected Terminal.
+ t *Terminal
+}
+
+var _ fs.InodeOperations = (*slaveInodeOperations)(nil)
+
+// newSlaveInode creates an fs.Inode for the slave end of a terminal.
+//
+// newSlaveInode takes ownership of t.
+func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode {
+ iops := &slaveInodeOperations{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, owner, p, linux.DEVPTS_SUPER_MAGIC),
+ d: d,
+ t: t,
+ }
+
+ return fs.NewInode(iops, d.msrc, fs.StableAttr{
+ DeviceID: ptsDevice.DeviceID(),
+ // N.B. Linux always uses inode id = tty index + 3. See
+ // fs/devpts/inode.c:devpts_pty_new.
+ //
+ // TODO(b/75267214): Since ptsDevice must be shared between
+ // different mounts, we must not assign fixed numbers.
+ InodeID: ptsDevice.NextIno(),
+ Type: fs.CharacterDevice,
+ // See fs/devpts/inode.c:devpts_fill_super.
+ BlockSize: 1024,
+ DeviceFileMajor: linux.UNIX98_PTY_SLAVE_MAJOR,
+ DeviceFileMinor: t.n,
+ })
+}
+
+// Release implements fs.InodeOperations.Release.
+func (si *slaveInodeOperations) Release(ctx context.Context) {
+ si.t.DecRef()
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+//
+// This may race with destruction of the terminal. If the terminal is gone, it
+// returns ENOENT.
+func (si *slaveInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &slaveFileOperations{si: si}), nil
+}
+
+// slaveFileOperations are the fs.FileOperations for the slave end of a terminal.
+//
+// +stateify savable
+type slaveFileOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // si is the inode operations.
+ si *slaveInodeOperations
+}
+
+var _ fs.FileOperations = (*slaveFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (sf *slaveFileOperations) Release() {
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (sf *slaveFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ sf.si.t.ld.slaveWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (sf *slaveFileOperations) EventUnregister(e *waiter.Entry) {
+ sf.si.t.ld.slaveWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (sf *slaveFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return sf.si.t.ld.slaveReadiness()
+}
+
+// Read implements fs.FileOperations.Read.
+func (sf *slaveFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ return sf.si.t.ld.inputQueueRead(ctx, dst)
+}
+
+// Write implements fs.FileOperations.Write.
+func (sf *slaveFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ return sf.si.t.ld.outputQueueWrite(ctx, src)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (sf *slaveFileOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the input queue read buffer.
+ return 0, sf.si.t.ld.inputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ return sf.si.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ return sf.si.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return sf.si.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sf.si.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCGWINSZ:
+ return 0, sf.si.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, sf.si.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ // TODO(b/129283598): Implement once we have support for job
+ // control.
+ return 0, nil
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
new file mode 100644
index 000000000..2b4160ba5
--- /dev/null
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tty
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// Terminal is a pseudoterminal.
+//
+// +stateify savable
+type Terminal struct {
+ refs.AtomicRefCount
+
+ // n is the terminal index.
+ n uint32
+
+ // d is the containing directory.
+ d *dirInodeOperations
+
+ // ld is the line discipline of the terminal.
+ ld *lineDiscipline
+}
+
+func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal {
+ termios := linux.DefaultSlaveTermios
+ return &Terminal{
+ d: d,
+ n: n,
+ ld: newLineDiscipline(termios),
+ }
+}
diff --git a/pkg/sentry/fs/tty/tty_state_autogen.go b/pkg/sentry/fs/tty/tty_state_autogen.go
new file mode 100755
index 000000000..303e2eef2
--- /dev/null
+++ b/pkg/sentry/fs/tty/tty_state_autogen.go
@@ -0,0 +1,202 @@
+// automatically generated by stateify.
+
+package tty
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *dirInodeOperations) beforeSave() {}
+func (x *dirInodeOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Save("msrc", &x.msrc)
+ m.Save("master", &x.master)
+ m.Save("slaves", &x.slaves)
+ m.Save("dentryMap", &x.dentryMap)
+ m.Save("next", &x.next)
+}
+
+func (x *dirInodeOperations) afterLoad() {}
+func (x *dirInodeOperations) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Load("msrc", &x.msrc)
+ m.Load("master", &x.master)
+ m.Load("slaves", &x.slaves)
+ m.Load("dentryMap", &x.dentryMap)
+ m.Load("next", &x.next)
+}
+
+func (x *dirFileOperations) beforeSave() {}
+func (x *dirFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("di", &x.di)
+ m.Save("dirCursor", &x.dirCursor)
+}
+
+func (x *dirFileOperations) afterLoad() {}
+func (x *dirFileOperations) load(m state.Map) {
+ m.Load("di", &x.di)
+ m.Load("dirCursor", &x.dirCursor)
+}
+
+func (x *filesystem) beforeSave() {}
+func (x *filesystem) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *filesystem) afterLoad() {}
+func (x *filesystem) load(m state.Map) {
+}
+
+func (x *superOperations) beforeSave() {}
+func (x *superOperations) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *superOperations) afterLoad() {}
+func (x *superOperations) load(m state.Map) {
+}
+
+func (x *lineDiscipline) beforeSave() {}
+func (x *lineDiscipline) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.masterWaiter) { m.Failf("masterWaiter is %v, expected zero", x.masterWaiter) }
+ if !state.IsZeroValue(x.slaveWaiter) { m.Failf("slaveWaiter is %v, expected zero", x.slaveWaiter) }
+ m.Save("size", &x.size)
+ m.Save("inQueue", &x.inQueue)
+ m.Save("outQueue", &x.outQueue)
+ m.Save("termios", &x.termios)
+ m.Save("column", &x.column)
+}
+
+func (x *lineDiscipline) afterLoad() {}
+func (x *lineDiscipline) load(m state.Map) {
+ m.Load("size", &x.size)
+ m.Load("inQueue", &x.inQueue)
+ m.Load("outQueue", &x.outQueue)
+ m.Load("termios", &x.termios)
+ m.Load("column", &x.column)
+}
+
+func (x *outputQueueTransformer) beforeSave() {}
+func (x *outputQueueTransformer) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *outputQueueTransformer) afterLoad() {}
+func (x *outputQueueTransformer) load(m state.Map) {
+}
+
+func (x *inputQueueTransformer) beforeSave() {}
+func (x *inputQueueTransformer) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *inputQueueTransformer) afterLoad() {}
+func (x *inputQueueTransformer) load(m state.Map) {
+}
+
+func (x *masterInodeOperations) beforeSave() {}
+func (x *masterInodeOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SimpleFileInode", &x.SimpleFileInode)
+ m.Save("d", &x.d)
+}
+
+func (x *masterInodeOperations) afterLoad() {}
+func (x *masterInodeOperations) load(m state.Map) {
+ m.Load("SimpleFileInode", &x.SimpleFileInode)
+ m.Load("d", &x.d)
+}
+
+func (x *masterFileOperations) beforeSave() {}
+func (x *masterFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("d", &x.d)
+ m.Save("t", &x.t)
+}
+
+func (x *masterFileOperations) afterLoad() {}
+func (x *masterFileOperations) load(m state.Map) {
+ m.Load("d", &x.d)
+ m.Load("t", &x.t)
+}
+
+func (x *queue) beforeSave() {}
+func (x *queue) save(m state.Map) {
+ x.beforeSave()
+ m.Save("readBuf", &x.readBuf)
+ m.Save("waitBuf", &x.waitBuf)
+ m.Save("waitBufLen", &x.waitBufLen)
+ m.Save("readable", &x.readable)
+ m.Save("transformer", &x.transformer)
+}
+
+func (x *queue) afterLoad() {}
+func (x *queue) load(m state.Map) {
+ m.Load("readBuf", &x.readBuf)
+ m.Load("waitBuf", &x.waitBuf)
+ m.Load("waitBufLen", &x.waitBufLen)
+ m.Load("readable", &x.readable)
+ m.Load("transformer", &x.transformer)
+}
+
+func (x *slaveInodeOperations) beforeSave() {}
+func (x *slaveInodeOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SimpleFileInode", &x.SimpleFileInode)
+ m.Save("d", &x.d)
+ m.Save("t", &x.t)
+}
+
+func (x *slaveInodeOperations) afterLoad() {}
+func (x *slaveInodeOperations) load(m state.Map) {
+ m.Load("SimpleFileInode", &x.SimpleFileInode)
+ m.Load("d", &x.d)
+ m.Load("t", &x.t)
+}
+
+func (x *slaveFileOperations) beforeSave() {}
+func (x *slaveFileOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("si", &x.si)
+}
+
+func (x *slaveFileOperations) afterLoad() {}
+func (x *slaveFileOperations) load(m state.Map) {
+ m.Load("si", &x.si)
+}
+
+func (x *Terminal) beforeSave() {}
+func (x *Terminal) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("n", &x.n)
+ m.Save("d", &x.d)
+ m.Save("ld", &x.ld)
+}
+
+func (x *Terminal) afterLoad() {}
+func (x *Terminal) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("n", &x.n)
+ m.Load("d", &x.d)
+ m.Load("ld", &x.ld)
+}
+
+func init() {
+ state.Register("tty.dirInodeOperations", (*dirInodeOperations)(nil), state.Fns{Save: (*dirInodeOperations).save, Load: (*dirInodeOperations).load})
+ state.Register("tty.dirFileOperations", (*dirFileOperations)(nil), state.Fns{Save: (*dirFileOperations).save, Load: (*dirFileOperations).load})
+ state.Register("tty.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load})
+ state.Register("tty.superOperations", (*superOperations)(nil), state.Fns{Save: (*superOperations).save, Load: (*superOperations).load})
+ state.Register("tty.lineDiscipline", (*lineDiscipline)(nil), state.Fns{Save: (*lineDiscipline).save, Load: (*lineDiscipline).load})
+ state.Register("tty.outputQueueTransformer", (*outputQueueTransformer)(nil), state.Fns{Save: (*outputQueueTransformer).save, Load: (*outputQueueTransformer).load})
+ state.Register("tty.inputQueueTransformer", (*inputQueueTransformer)(nil), state.Fns{Save: (*inputQueueTransformer).save, Load: (*inputQueueTransformer).load})
+ state.Register("tty.masterInodeOperations", (*masterInodeOperations)(nil), state.Fns{Save: (*masterInodeOperations).save, Load: (*masterInodeOperations).load})
+ state.Register("tty.masterFileOperations", (*masterFileOperations)(nil), state.Fns{Save: (*masterFileOperations).save, Load: (*masterFileOperations).load})
+ state.Register("tty.queue", (*queue)(nil), state.Fns{Save: (*queue).save, Load: (*queue).load})
+ state.Register("tty.slaveInodeOperations", (*slaveInodeOperations)(nil), state.Fns{Save: (*slaveInodeOperations).save, Load: (*slaveInodeOperations).load})
+ state.Register("tty.slaveFileOperations", (*slaveFileOperations)(nil), state.Fns{Save: (*slaveFileOperations).save, Load: (*slaveFileOperations).load})
+ state.Register("tty.Terminal", (*Terminal)(nil), state.Fns{Save: (*Terminal).save, Load: (*Terminal).load})
+}
diff --git a/pkg/sentry/hostcpu/getcpu_amd64.s b/pkg/sentry/hostcpu/getcpu_amd64.s
new file mode 100644
index 000000000..aa00316da
--- /dev/null
+++ b/pkg/sentry/hostcpu/getcpu_amd64.s
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// func GetCPU() (cpu uint32)
+TEXT ·GetCPU(SB), NOSPLIT, $0-4
+ BYTE $0x0f; BYTE $0x01; BYTE $0xf9; // RDTSCP
+ // On Linux, the bottom 12 bits of IA32_TSC_AUX are CPU and the upper 20
+ // are node. See arch/x86/entry/vdso/vma.c:vgetcpu_cpu_init().
+ ANDL $0xfff, CX
+ MOVL CX, cpu+0(FP)
+ RET
diff --git a/pkg/sentry/hostcpu/hostcpu.go b/pkg/sentry/hostcpu/hostcpu.go
new file mode 100644
index 000000000..d78f78402
--- /dev/null
+++ b/pkg/sentry/hostcpu/hostcpu.go
@@ -0,0 +1,67 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostcpu provides utilities for working with CPU information provided
+// by a host Linux kernel.
+package hostcpu
+
+import (
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+ "unicode"
+)
+
+// GetCPU returns the caller's current CPU number, without using the Linux VDSO
+// (which is not available to the sentry) or the getcpu(2) system call (which
+// is relatively slow).
+func GetCPU() uint32
+
+// MaxPossibleCPU returns the highest possible CPU number, which is guaranteed
+// not to change for the lifetime of the host kernel.
+func MaxPossibleCPU() (uint32, error) {
+ const path = "/sys/devices/system/cpu/possible"
+ data, err := ioutil.ReadFile(path)
+ if err != nil {
+ return 0, err
+ }
+ str := string(data)
+ // Linux: drivers/base/cpu.c:show_cpus_attr() =>
+ // include/linux/cpumask.h:cpumask_print_to_pagebuf() =>
+ // lib/bitmap.c:bitmap_print_to_pagebuf()
+ i, err := maxValueInLinuxBitmap(str)
+ if err != nil {
+ return 0, fmt.Errorf("invalid %s (%q): %v", path, str, err)
+ }
+ return uint32(i), nil
+}
+
+// maxValueInLinuxBitmap returns the maximum value specified in str, which is a
+// string emitted by Linux's lib/bitmap.c:bitmap_print_to_pagebuf(list=true).
+func maxValueInLinuxBitmap(str string) (uint64, error) {
+ str = strings.TrimSpace(str)
+ // Find the last decimal number in str.
+ idx := strings.LastIndexFunc(str, func(c rune) bool {
+ return !unicode.IsDigit(c)
+ })
+ if idx != -1 {
+ str = str[idx+1:]
+ }
+ i, err := strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ return 0, err
+ }
+ return i, nil
+}
diff --git a/pkg/sentry/hostcpu/hostcpu_state_autogen.go b/pkg/sentry/hostcpu/hostcpu_state_autogen.go
new file mode 100755
index 000000000..f04a56ec0
--- /dev/null
+++ b/pkg/sentry/hostcpu/hostcpu_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package hostcpu
+
diff --git a/pkg/sentry/inet/context.go b/pkg/sentry/inet/context.go
new file mode 100644
index 000000000..8550c4793
--- /dev/null
+++ b/pkg/sentry/inet/context.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// contextID is the inet package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxStack is a Context.Value key for a network stack.
+ CtxStack contextID = iota
+)
+
+// StackFromContext returns the network stack associated with ctx.
+func StackFromContext(ctx context.Context) Stack {
+ if v := ctx.Value(CtxStack); v != nil {
+ return v.(Stack)
+ }
+ return nil
+}
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
new file mode 100644
index 000000000..7c104fd47
--- /dev/null
+++ b/pkg/sentry/inet/inet.go
@@ -0,0 +1,104 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package inet defines semantics for IP stacks.
+package inet
+
+// Stack represents a TCP/IP stack.
+type Stack interface {
+ // Interfaces returns all network interfaces as a mapping from interface
+ // indexes to interface properties. Interface indices are strictly positive
+ // integers.
+ Interfaces() map[int32]Interface
+
+ // InterfaceAddrs returns all network interface addresses as a mapping from
+ // interface indexes to a slice of associated interface address properties.
+ InterfaceAddrs() map[int32][]InterfaceAddr
+
+ // SupportsIPv6 returns true if the stack supports IPv6 connectivity.
+ SupportsIPv6() bool
+
+ // TCPReceiveBufferSize returns TCP receive buffer size settings.
+ TCPReceiveBufferSize() (TCPBufferSize, error)
+
+ // SetTCPReceiveBufferSize attempts to change TCP receive buffer size
+ // settings.
+ SetTCPReceiveBufferSize(size TCPBufferSize) error
+
+ // TCPSendBufferSize returns TCP send buffer size settings.
+ TCPSendBufferSize() (TCPBufferSize, error)
+
+ // SetTCPSendBufferSize attempts to change TCP send buffer size settings.
+ SetTCPSendBufferSize(size TCPBufferSize) error
+
+ // TCPSACKEnabled returns true if RFC 2018 TCP Selective Acknowledgements
+ // are enabled.
+ TCPSACKEnabled() (bool, error)
+
+ // SetTCPSACKEnabled attempts to change TCP selective acknowledgement
+ // settings.
+ SetTCPSACKEnabled(enabled bool) error
+}
+
+// Interface contains information about a network interface.
+type Interface struct {
+ // Keep these fields sorted in the order they appear in rtnetlink(7).
+
+ // DeviceType is the device type, a Linux ARPHRD_* constant.
+ DeviceType uint16
+
+ // Flags is the device flags; see netdevice(7), under "Ioctls",
+ // "SIOCGIFFLAGS, SIOCSIFFLAGS".
+ Flags uint32
+
+ // Name is the device name.
+ Name string
+
+ // Addr is the hardware device address.
+ Addr []byte
+
+ // MTU is the maximum transmission unit.
+ MTU uint32
+}
+
+// InterfaceAddr contains information about a network interface address.
+type InterfaceAddr struct {
+ // Keep these fields sorted in the order they appear in rtnetlink(7).
+
+ // Family is the address family, a Linux AF_* constant.
+ Family uint8
+
+ // PrefixLen is the address prefix length.
+ PrefixLen uint8
+
+ // Flags is the address flags.
+ Flags uint8
+
+ // Addr is the actual address.
+ Addr []byte
+}
+
+// TCPBufferSize contains settings controlling TCP buffer sizing.
+//
+// +stateify savable
+type TCPBufferSize struct {
+ // Min is the minimum size.
+ Min int
+
+ // Default is the default size.
+ Default int
+
+ // Max is the maximum size.
+ Max int
+}
diff --git a/pkg/sentry/inet/inet_state_autogen.go b/pkg/sentry/inet/inet_state_autogen.go
new file mode 100755
index 000000000..9f8460ec0
--- /dev/null
+++ b/pkg/sentry/inet/inet_state_autogen.go
@@ -0,0 +1,26 @@
+// automatically generated by stateify.
+
+package inet
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *TCPBufferSize) beforeSave() {}
+func (x *TCPBufferSize) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Min", &x.Min)
+ m.Save("Default", &x.Default)
+ m.Save("Max", &x.Max)
+}
+
+func (x *TCPBufferSize) afterLoad() {}
+func (x *TCPBufferSize) load(m state.Map) {
+ m.Load("Min", &x.Min)
+ m.Load("Default", &x.Default)
+ m.Load("Max", &x.Max)
+}
+
+func init() {
+ state.Register("inet.TCPBufferSize", (*TCPBufferSize)(nil), state.Fns{Save: (*TCPBufferSize).save, Load: (*TCPBufferSize).load})
+}
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
new file mode 100644
index 000000000..624371eb6
--- /dev/null
+++ b/pkg/sentry/inet/test_stack.go
@@ -0,0 +1,83 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+// TestStack is a dummy implementation of Stack for tests.
+type TestStack struct {
+ InterfacesMap map[int32]Interface
+ InterfaceAddrsMap map[int32][]InterfaceAddr
+ SupportsIPv6Flag bool
+ TCPRecvBufSize TCPBufferSize
+ TCPSendBufSize TCPBufferSize
+ TCPSACKFlag bool
+}
+
+// NewTestStack returns a TestStack with no network interfaces. The value of
+// all other options is unspecified; tests that rely on specific values must
+// set them explicitly.
+func NewTestStack() *TestStack {
+ return &TestStack{
+ InterfacesMap: make(map[int32]Interface),
+ InterfaceAddrsMap: make(map[int32][]InterfaceAddr),
+ }
+}
+
+// Interfaces implements Stack.Interfaces.
+func (s *TestStack) Interfaces() map[int32]Interface {
+ return s.InterfacesMap
+}
+
+// InterfaceAddrs implements Stack.InterfaceAddrs.
+func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr {
+ return s.InterfaceAddrsMap
+}
+
+// SupportsIPv6 implements Stack.SupportsIPv6.
+func (s *TestStack) SupportsIPv6() bool {
+ return s.SupportsIPv6Flag
+}
+
+// TCPReceiveBufferSize implements Stack.TCPReceiveBufferSize.
+func (s *TestStack) TCPReceiveBufferSize() (TCPBufferSize, error) {
+ return s.TCPRecvBufSize, nil
+}
+
+// SetTCPReceiveBufferSize implements Stack.SetTCPReceiveBufferSize.
+func (s *TestStack) SetTCPReceiveBufferSize(size TCPBufferSize) error {
+ s.TCPRecvBufSize = size
+ return nil
+}
+
+// TCPSendBufferSize implements Stack.TCPSendBufferSize.
+func (s *TestStack) TCPSendBufferSize() (TCPBufferSize, error) {
+ return s.TCPSendBufSize, nil
+}
+
+// SetTCPSendBufferSize implements Stack.SetTCPSendBufferSize.
+func (s *TestStack) SetTCPSendBufferSize(size TCPBufferSize) error {
+ s.TCPSendBufSize = size
+ return nil
+}
+
+// TCPSACKEnabled implements Stack.TCPSACKEnabled.
+func (s *TestStack) TCPSACKEnabled() (bool, error) {
+ return s.TCPSACKFlag, nil
+}
+
+// SetTCPSACKEnabled implements Stack.SetTCPSACKEnabled.
+func (s *TestStack) SetTCPSACKEnabled(enabled bool) error {
+ s.TCPSACKFlag = enabled
+ return nil
+}
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
new file mode 100644
index 000000000..5ce52e66c
--- /dev/null
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+// +stateify savable
+type abstractEndpoint struct {
+ ep transport.BoundEndpoint
+ wr *refs.WeakRef
+ name string
+ ns *AbstractSocketNamespace
+}
+
+// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
+func (e *abstractEndpoint) WeakRefGone() {
+ e.ns.mu.Lock()
+ if e.ns.endpoints[e.name].ep == e.ep {
+ delete(e.ns.endpoints, e.name)
+ }
+ e.ns.mu.Unlock()
+}
+
+// AbstractSocketNamespace is used to implement the Linux abstract socket functionality.
+//
+// +stateify savable
+type AbstractSocketNamespace struct {
+ mu sync.Mutex `state:"nosave"`
+
+ // Keeps mapping from name to endpoint.
+ endpoints map[string]abstractEndpoint
+}
+
+// NewAbstractSocketNamespace returns a new AbstractSocketNamespace.
+func NewAbstractSocketNamespace() *AbstractSocketNamespace {
+ return &AbstractSocketNamespace{
+ endpoints: make(map[string]abstractEndpoint),
+ }
+}
+
+// A boundEndpoint wraps a transport.BoundEndpoint to maintain a reference on
+// its backing object.
+type boundEndpoint struct {
+ transport.BoundEndpoint
+ rc refs.RefCounter
+}
+
+// Release implements transport.BoundEndpoint.Release.
+func (e *boundEndpoint) Release() {
+ e.rc.DecRef()
+ e.BoundEndpoint.Release()
+}
+
+// BoundEndpoint retrieves the endpoint bound to the given name. The return
+// value is nil if no endpoint was bound.
+func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndpoint {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ ep, ok := a.endpoints[name]
+ if !ok {
+ return nil
+ }
+
+ rc := ep.wr.Get()
+ if rc == nil {
+ delete(a.endpoints, name)
+ return nil
+ }
+
+ return &boundEndpoint{ep.ep, rc}
+}
+
+// Bind binds the given socket.
+//
+// When the last reference managed by rc is dropped, ep may be removed from the
+// namespace.
+func (a *AbstractSocketNamespace) Bind(name string, ep transport.BoundEndpoint, rc refs.RefCounter) error {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if ep, ok := a.endpoints[name]; ok {
+ if rc := ep.wr.Get(); rc != nil {
+ rc.DecRef()
+ return syscall.EADDRINUSE
+ }
+ }
+
+ ae := abstractEndpoint{ep: ep, name: name, ns: a}
+ ae.wr = refs.NewWeakRef(rc, &ae)
+ a.endpoints[name] = ae
+ return nil
+}
diff --git a/pkg/sentry/kernel/auth/auth.go b/pkg/sentry/kernel/auth/auth.go
new file mode 100644
index 000000000..847d121aa
--- /dev/null
+++ b/pkg/sentry/kernel/auth/auth.go
@@ -0,0 +1,22 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package auth implements an access control model that is a subset of Linux's.
+//
+// The auth package supports two kinds of access controls: user/group IDs and
+// capabilities. Each resource in the security model is associated with a user
+// namespace; "privileged" operations check that the operator's credentials
+// have the required user/group IDs or capabilities within the user namespace
+// of accessed resources.
+package auth
diff --git a/pkg/sentry/kernel/auth/auth_state_autogen.go b/pkg/sentry/kernel/auth/auth_state_autogen.go
new file mode 100755
index 000000000..6f80381c6
--- /dev/null
+++ b/pkg/sentry/kernel/auth/auth_state_autogen.go
@@ -0,0 +1,151 @@
+// automatically generated by stateify.
+
+package auth
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Credentials) beforeSave() {}
+func (x *Credentials) save(m state.Map) {
+ x.beforeSave()
+ m.Save("RealKUID", &x.RealKUID)
+ m.Save("EffectiveKUID", &x.EffectiveKUID)
+ m.Save("SavedKUID", &x.SavedKUID)
+ m.Save("RealKGID", &x.RealKGID)
+ m.Save("EffectiveKGID", &x.EffectiveKGID)
+ m.Save("SavedKGID", &x.SavedKGID)
+ m.Save("ExtraKGIDs", &x.ExtraKGIDs)
+ m.Save("PermittedCaps", &x.PermittedCaps)
+ m.Save("InheritableCaps", &x.InheritableCaps)
+ m.Save("EffectiveCaps", &x.EffectiveCaps)
+ m.Save("BoundingCaps", &x.BoundingCaps)
+ m.Save("KeepCaps", &x.KeepCaps)
+ m.Save("UserNamespace", &x.UserNamespace)
+}
+
+func (x *Credentials) afterLoad() {}
+func (x *Credentials) load(m state.Map) {
+ m.Load("RealKUID", &x.RealKUID)
+ m.Load("EffectiveKUID", &x.EffectiveKUID)
+ m.Load("SavedKUID", &x.SavedKUID)
+ m.Load("RealKGID", &x.RealKGID)
+ m.Load("EffectiveKGID", &x.EffectiveKGID)
+ m.Load("SavedKGID", &x.SavedKGID)
+ m.Load("ExtraKGIDs", &x.ExtraKGIDs)
+ m.Load("PermittedCaps", &x.PermittedCaps)
+ m.Load("InheritableCaps", &x.InheritableCaps)
+ m.Load("EffectiveCaps", &x.EffectiveCaps)
+ m.Load("BoundingCaps", &x.BoundingCaps)
+ m.Load("KeepCaps", &x.KeepCaps)
+ m.Load("UserNamespace", &x.UserNamespace)
+}
+
+func (x *IDMapEntry) beforeSave() {}
+func (x *IDMapEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("FirstID", &x.FirstID)
+ m.Save("FirstParentID", &x.FirstParentID)
+ m.Save("Length", &x.Length)
+}
+
+func (x *IDMapEntry) afterLoad() {}
+func (x *IDMapEntry) load(m state.Map) {
+ m.Load("FirstID", &x.FirstID)
+ m.Load("FirstParentID", &x.FirstParentID)
+ m.Load("Length", &x.Length)
+}
+
+func (x *idMapRange) beforeSave() {}
+func (x *idMapRange) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *idMapRange) afterLoad() {}
+func (x *idMapRange) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func (x *idMapSet) beforeSave() {}
+func (x *idMapSet) save(m state.Map) {
+ x.beforeSave()
+ var root *idMapSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *idMapSet) afterLoad() {}
+func (x *idMapSet) load(m state.Map) {
+ m.LoadValue("root", new(*idMapSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*idMapSegmentDataSlices)) })
+}
+
+func (x *idMapnode) beforeSave() {}
+func (x *idMapnode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *idMapnode) afterLoad() {}
+func (x *idMapnode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *idMapSegmentDataSlices) beforeSave() {}
+func (x *idMapSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *idMapSegmentDataSlices) afterLoad() {}
+func (x *idMapSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func (x *UserNamespace) beforeSave() {}
+func (x *UserNamespace) save(m state.Map) {
+ x.beforeSave()
+ m.Save("parent", &x.parent)
+ m.Save("owner", &x.owner)
+ m.Save("uidMapFromParent", &x.uidMapFromParent)
+ m.Save("uidMapToParent", &x.uidMapToParent)
+ m.Save("gidMapFromParent", &x.gidMapFromParent)
+ m.Save("gidMapToParent", &x.gidMapToParent)
+}
+
+func (x *UserNamespace) afterLoad() {}
+func (x *UserNamespace) load(m state.Map) {
+ m.Load("parent", &x.parent)
+ m.Load("owner", &x.owner)
+ m.Load("uidMapFromParent", &x.uidMapFromParent)
+ m.Load("uidMapToParent", &x.uidMapToParent)
+ m.Load("gidMapFromParent", &x.gidMapFromParent)
+ m.Load("gidMapToParent", &x.gidMapToParent)
+}
+
+func init() {
+ state.Register("auth.Credentials", (*Credentials)(nil), state.Fns{Save: (*Credentials).save, Load: (*Credentials).load})
+ state.Register("auth.IDMapEntry", (*IDMapEntry)(nil), state.Fns{Save: (*IDMapEntry).save, Load: (*IDMapEntry).load})
+ state.Register("auth.idMapRange", (*idMapRange)(nil), state.Fns{Save: (*idMapRange).save, Load: (*idMapRange).load})
+ state.Register("auth.idMapSet", (*idMapSet)(nil), state.Fns{Save: (*idMapSet).save, Load: (*idMapSet).load})
+ state.Register("auth.idMapnode", (*idMapnode)(nil), state.Fns{Save: (*idMapnode).save, Load: (*idMapnode).load})
+ state.Register("auth.idMapSegmentDataSlices", (*idMapSegmentDataSlices)(nil), state.Fns{Save: (*idMapSegmentDataSlices).save, Load: (*idMapSegmentDataSlices).load})
+ state.Register("auth.UserNamespace", (*UserNamespace)(nil), state.Fns{Save: (*UserNamespace).save, Load: (*UserNamespace).load})
+}
diff --git a/pkg/sentry/kernel/auth/capability_set.go b/pkg/sentry/kernel/auth/capability_set.go
new file mode 100644
index 000000000..7a0c967cd
--- /dev/null
+++ b/pkg/sentry/kernel/auth/capability_set.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/bits"
+)
+
+// A CapabilitySet is a set of capabilities implemented as a bitset. The zero
+// value of CapabilitySet is a set containing no capabilities.
+type CapabilitySet uint64
+
+// AllCapabilities is a CapabilitySet containing all valid capabilities.
+var AllCapabilities = CapabilitySetOf(linux.MaxCapability+1) - 1
+
+// CapabilitySetOf returns a CapabilitySet containing only the given
+// capability.
+func CapabilitySetOf(cp linux.Capability) CapabilitySet {
+ return CapabilitySet(bits.MaskOf64(int(cp)))
+}
+
+// CapabilitySetOfMany returns a CapabilitySet containing the given capabilities.
+func CapabilitySetOfMany(cps []linux.Capability) CapabilitySet {
+ var cs uint64
+ for _, cp := range cps {
+ cs |= bits.MaskOf64(int(cp))
+ }
+ return CapabilitySet(cs)
+}
+
+// TaskCapabilities represents all the capability sets for a task. Each of these
+// sets is explained in greater detail in capabilities(7).
+type TaskCapabilities struct {
+ // Permitted is a limiting superset for the effective capabilities that
+ // the thread may assume.
+ PermittedCaps CapabilitySet
+ // Inheritable is a set of capabilities preserved across an execve(2).
+ InheritableCaps CapabilitySet
+ // Effective is the set of capabilities used by the kernel to perform
+ // permission checks for the thread.
+ EffectiveCaps CapabilitySet
+ // Bounding is a limiting superset for the capabilities that a thread
+ // can add to its inheritable set using capset(2).
+ BoundingCaps CapabilitySet
+ // Ambient is a set of capabilities that are preserved across an
+ // execve(2) of a program that is not privileged.
+ AmbientCaps CapabilitySet
+}
diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go
new file mode 100644
index 000000000..16d110610
--- /dev/null
+++ b/pkg/sentry/kernel/auth/context.go
@@ -0,0 +1,36 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// contextID is the auth package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxCredentials is a Context.Value key for Credentials.
+ CtxCredentials contextID = iota
+)
+
+// CredentialsFromContext returns a copy of the Credentials used by ctx, or a
+// set of Credentials with no capabilities if ctx does not have Credentials.
+func CredentialsFromContext(ctx context.Context) *Credentials {
+ if v := ctx.Value(CtxCredentials); v != nil {
+ return v.(*Credentials)
+ }
+ return NewAnonymousCredentials()
+}
diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go
new file mode 100644
index 000000000..1511a0324
--- /dev/null
+++ b/pkg/sentry/kernel/auth/credentials.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Credentials contains information required to authorize privileged operations
+// in a user namespace.
+//
+// +stateify savable
+type Credentials struct {
+ // Real/effective/saved user/group IDs in the root user namespace. None of
+ // these should ever be NoID.
+ RealKUID KUID
+ EffectiveKUID KUID
+ SavedKUID KUID
+ RealKGID KGID
+ EffectiveKGID KGID
+ SavedKGID KGID
+
+ // Filesystem user/group IDs are not implemented. "... setfsuid() is
+ // nowadays unneeded and should be avoided in new applications (likewise
+ // for setfsgid(2))." - setfsuid(2)
+
+ // Supplementary groups used by set/getgroups.
+ //
+ // ExtraKGIDs slices are immutable, allowing multiple Credentials with the
+ // same ExtraKGIDs to share the same slice.
+ ExtraKGIDs []KGID
+
+ // The capability sets applicable to this set of credentials.
+ PermittedCaps CapabilitySet
+ InheritableCaps CapabilitySet
+ EffectiveCaps CapabilitySet
+ BoundingCaps CapabilitySet
+ // Ambient capabilities are not introduced until Linux 4.3.
+
+ // KeepCaps is the flag for PR_SET_KEEPCAPS which allow capabilities to be
+ // maintained after a switch from root user to non-root user via setuid().
+ KeepCaps bool
+
+ // The user namespace associated with the owner of the credentials.
+ UserNamespace *UserNamespace
+}
+
+// NewAnonymousCredentials returns a set of credentials with no capabilities in
+// any user namespace.
+func NewAnonymousCredentials() *Credentials {
+ // Create a new root user namespace. Since the new namespace's owner is
+ // KUID 0 and the returned credentials have non-zero KUID/KGID, the
+ // returned credentials do not have any capabilities in the new namespace.
+ // Since the new namespace is not part of any existing user namespace
+ // hierarchy, the returned credentials do not have any capabilities in any
+ // other namespace.
+ return &Credentials{
+ RealKUID: NobodyKUID,
+ EffectiveKUID: NobodyKUID,
+ SavedKUID: NobodyKUID,
+ RealKGID: NobodyKGID,
+ EffectiveKGID: NobodyKGID,
+ SavedKGID: NobodyKGID,
+ UserNamespace: NewRootUserNamespace(),
+ }
+}
+
+// NewRootCredentials returns a set of credentials with KUID and KGID 0 (i.e.
+// global root) in user namespace ns.
+func NewRootCredentials(ns *UserNamespace) *Credentials {
+ // I can't find documentation for this anywhere, but it's correct for the
+ // inheritable capability set to be initially empty (the capabilities test
+ // checks for this property).
+ return &Credentials{
+ RealKUID: RootKUID,
+ EffectiveKUID: RootKUID,
+ SavedKUID: RootKUID,
+ RealKGID: RootKGID,
+ EffectiveKGID: RootKGID,
+ SavedKGID: RootKGID,
+ PermittedCaps: AllCapabilities,
+ EffectiveCaps: AllCapabilities,
+ BoundingCaps: AllCapabilities,
+ UserNamespace: ns,
+ }
+}
+
+// NewUserCredentials returns a set of credentials based on the given UID, GIDs,
+// and capabilities in a given namespace. If all arguments are their zero
+// values, this returns the same credentials as NewRootCredentials.
+func NewUserCredentials(kuid KUID, kgid KGID, extraKGIDs []KGID, capabilities *TaskCapabilities, ns *UserNamespace) *Credentials {
+ creds := NewRootCredentials(ns)
+
+ // Set the UID.
+ uid := kuid
+ creds.RealKUID = uid
+ creds.EffectiveKUID = uid
+ creds.SavedKUID = uid
+
+ // Set GID.
+ gid := kgid
+ creds.RealKGID = gid
+ creds.EffectiveKGID = gid
+ creds.SavedKGID = gid
+
+ // Set additional GIDs.
+ creds.ExtraKGIDs = append(creds.ExtraKGIDs, extraKGIDs...)
+
+ // Set capabilities.
+ if capabilities != nil {
+ creds.PermittedCaps = capabilities.PermittedCaps
+ creds.EffectiveCaps = capabilities.EffectiveCaps
+ creds.BoundingCaps = capabilities.BoundingCaps
+ creds.InheritableCaps = capabilities.InheritableCaps
+ // TODO(nlacasse): Support ambient capabilities.
+ } else {
+ // If no capabilities are specified, grant capabilities consistent with
+ // setresuid + setresgid from NewRootCredentials to the given uid and
+ // gid.
+ if kuid == RootKUID {
+ creds.PermittedCaps = AllCapabilities
+ creds.EffectiveCaps = AllCapabilities
+ } else {
+ creds.PermittedCaps = 0
+ creds.EffectiveCaps = 0
+ }
+ creds.BoundingCaps = AllCapabilities
+ }
+
+ return creds
+}
+
+// Fork generates an identical copy of a set of credentials.
+func (c *Credentials) Fork() *Credentials {
+ nc := new(Credentials)
+ *nc = *c // Copy-by-value; this is legal for all fields.
+ return nc
+}
+
+// InGroup returns true if c is in group kgid. Compare Linux's
+// kernel/groups.c:in_group_p().
+func (c *Credentials) InGroup(kgid KGID) bool {
+ if c.EffectiveKGID == kgid {
+ return true
+ }
+ for _, extraKGID := range c.ExtraKGIDs {
+ if extraKGID == kgid {
+ return true
+ }
+ }
+ return false
+}
+
+// HasCapabilityIn returns true if c has capability cp in ns.
+func (c *Credentials) HasCapabilityIn(cp linux.Capability, ns *UserNamespace) bool {
+ for {
+ // "1. A process has a capability inside a user namespace if it is a member
+ // of that namespace and it has the capability in its effective capability
+ // set." - user_namespaces(7)
+ if c.UserNamespace == ns {
+ return CapabilitySetOf(cp)&c.EffectiveCaps != 0
+ }
+ // "3. ... A process that resides in the parent of the user namespace and
+ // whose effective user ID matches the owner of the namespace has all
+ // capabilities in the namespace."
+ if c.UserNamespace == ns.parent && c.EffectiveKUID == ns.owner {
+ return true
+ }
+ // "2. If a process has a capability in a user namespace, then it has that
+ // capability in all child (and further removed descendant) namespaces as
+ // well."
+ if ns.parent == nil {
+ return false
+ }
+ ns = ns.parent
+ }
+}
+
+// HasCapability returns true if c has capability cp in its user namespace.
+func (c *Credentials) HasCapability(cp linux.Capability) bool {
+ return c.HasCapabilityIn(cp, c.UserNamespace)
+}
+
+// UseUID checks that c can use uid in its user namespace, then translates it
+// to the root user namespace.
+//
+// The checks UseUID does are common, but you should verify that it's doing
+// exactly what you want.
+func (c *Credentials) UseUID(uid UID) (KUID, error) {
+ // uid must be mapped.
+ kuid := c.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return NoID, syserror.EINVAL
+ }
+ // If c has CAP_SETUID, then it can use any UID in its user namespace.
+ if c.HasCapability(linux.CAP_SETUID) {
+ return kuid, nil
+ }
+ // Otherwise, c must already have the UID as its real, effective, or saved
+ // set-user-ID.
+ if kuid == c.RealKUID || kuid == c.EffectiveKUID || kuid == c.SavedKUID {
+ return kuid, nil
+ }
+ return NoID, syserror.EPERM
+}
+
+// UseGID checks that c can use gid in its user namespace, then translates it
+// to the root user namespace.
+func (c *Credentials) UseGID(gid GID) (KGID, error) {
+ kgid := c.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return NoID, syserror.EINVAL
+ }
+ if c.HasCapability(linux.CAP_SETGID) {
+ return kgid, nil
+ }
+ if kgid == c.RealKGID || kgid == c.EffectiveKGID || kgid == c.SavedKGID {
+ return kgid, nil
+ }
+ return NoID, syserror.EPERM
+}
diff --git a/pkg/sentry/kernel/auth/id.go b/pkg/sentry/kernel/auth/id.go
new file mode 100644
index 000000000..0a58ba17c
--- /dev/null
+++ b/pkg/sentry/kernel/auth/id.go
@@ -0,0 +1,121 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "math"
+)
+
+// UID is a user ID in an unspecified user namespace.
+type UID uint32
+
+// GID is a group ID in an unspecified user namespace.
+type GID uint32
+
+// In the root user namespace, user/group IDs have a 1-to-1 relationship with
+// the users/groups they represent. In other user namespaces, this is not the
+// case; for example, two different unmapped users may both "have" the overflow
+// UID. This means that it is generally only valid to compare user and group
+// IDs in the root user namespace. We assign distinct types, KUID/KGID, to such
+// IDs to emphasize this distinction. ("k" is for "key", as in "unique key".
+// Linux also uses the prefix "k", but I think they mean "kernel".)
+
+// KUID is a user ID in the root user namespace.
+type KUID uint32
+
+// KGID is a group ID in the root user namespace.
+type KGID uint32
+
+const (
+ // NoID is uint32(-1). -1 is consistently used as a special value, in Linux
+ // and by extension in the auth package, to mean "no ID":
+ //
+ // - ID mapping returns -1 if the ID is not mapped.
+ //
+ // - Most set*id() syscalls accept -1 to mean "do not change this ID".
+ NoID = math.MaxUint32
+
+ // OverflowUID is the default value of /proc/sys/kernel/overflowuid. The
+ // "overflow UID" is usually [1] used when translating a user ID between
+ // namespaces fails because the ID is not mapped. (We don't implement this
+ // file, so the overflow UID is constant.)
+ //
+ // [1] "There is one notable case where unmapped user and group IDs are not
+ // converted to the corresponding overflow ID value. When viewing a uid_map
+ // or gid_map file in which there is no mapping for the second field, that
+ // field is displayed as 4294967295 (-1 as an unsigned integer);" -
+ // user_namespaces(7)
+ OverflowUID = UID(65534)
+ OverflowGID = GID(65534)
+
+ // NobodyKUID is the user ID usually reserved for the least privileged user
+ // "nobody".
+ NobodyKUID = KUID(65534)
+ NobodyKGID = KGID(65534)
+
+ // RootKUID is the user ID usually used for the most privileged user "root".
+ RootKUID = KUID(0)
+ RootKGID = KGID(0)
+ RootUID = UID(0)
+ RootGID = GID(0)
+)
+
+// Ok returns true if uid is not -1.
+func (uid UID) Ok() bool {
+ return uid != NoID
+}
+
+// Ok returns true if gid is not -1.
+func (gid GID) Ok() bool {
+ return gid != NoID
+}
+
+// Ok returns true if kuid is not -1.
+func (kuid KUID) Ok() bool {
+ return kuid != NoID
+}
+
+// Ok returns true if kgid is not -1.
+func (kgid KGID) Ok() bool {
+ return kgid != NoID
+}
+
+// OrOverflow returns uid if it is valid and the overflow UID otherwise.
+func (uid UID) OrOverflow() UID {
+ if uid.Ok() {
+ return uid
+ }
+ return OverflowUID
+}
+
+// OrOverflow returns gid if it is valid and the overflow GID otherwise.
+func (gid GID) OrOverflow() GID {
+ if gid.Ok() {
+ return gid
+ }
+ return OverflowGID
+}
+
+// In translates kuid into user namespace ns. If kuid is not mapped in ns, In
+// returns NoID.
+func (kuid KUID) In(ns *UserNamespace) UID {
+ return ns.MapFromKUID(kuid)
+}
+
+// In translates kgid into user namespace ns. If kgid is not mapped in ns, In
+// returns NoID.
+func (kgid KGID) In(ns *UserNamespace) GID {
+ return ns.MapFromKGID(kgid)
+}
diff --git a/pkg/sentry/kernel/auth/id_map.go b/pkg/sentry/kernel/auth/id_map.go
new file mode 100644
index 000000000..e5d6028d6
--- /dev/null
+++ b/pkg/sentry/kernel/auth/id_map.go
@@ -0,0 +1,285 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// MapFromKUID translates kuid, a UID in the root namespace, to a UID in ns.
+func (ns *UserNamespace) MapFromKUID(kuid KUID) UID {
+ if ns.parent == nil {
+ return UID(kuid)
+ }
+ return UID(ns.mapID(&ns.uidMapFromParent, uint32(ns.parent.MapFromKUID(kuid))))
+}
+
+// MapFromKGID translates kgid, a GID in the root namespace, to a GID in ns.
+func (ns *UserNamespace) MapFromKGID(kgid KGID) GID {
+ if ns.parent == nil {
+ return GID(kgid)
+ }
+ return GID(ns.mapID(&ns.gidMapFromParent, uint32(ns.parent.MapFromKGID(kgid))))
+}
+
+// MapToKUID translates uid, a UID in ns, to a UID in the root namespace.
+func (ns *UserNamespace) MapToKUID(uid UID) KUID {
+ if ns.parent == nil {
+ return KUID(uid)
+ }
+ return ns.parent.MapToKUID(UID(ns.mapID(&ns.uidMapToParent, uint32(uid))))
+}
+
+// MapToKGID translates gid, a GID in ns, to a GID in the root namespace.
+func (ns *UserNamespace) MapToKGID(gid GID) KGID {
+ if ns.parent == nil {
+ return KGID(gid)
+ }
+ return ns.parent.MapToKGID(GID(ns.mapID(&ns.gidMapToParent, uint32(gid))))
+}
+
+func (ns *UserNamespace) mapID(m *idMapSet, id uint32) uint32 {
+ if id == NoID {
+ return NoID
+ }
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ if it := m.FindSegment(id); it.Ok() {
+ return it.Value() + (id - it.Start())
+ }
+ return NoID
+}
+
+// allIDsMapped returns true if all IDs in the range [start, end) are mapped in
+// m.
+//
+// Preconditions: end >= start.
+func (ns *UserNamespace) allIDsMapped(m *idMapSet, start, end uint32) bool {
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ return m.SpanRange(idMapRange{start, end}) == end-start
+}
+
+// An IDMapEntry represents a mapping from a range of contiguous IDs in a user
+// namespace to an equally-sized range of contiguous IDs in the namespace's
+// parent.
+//
+// +stateify savable
+type IDMapEntry struct {
+ // FirstID is the first ID in the range in the namespace.
+ FirstID uint32
+
+ // FirstParentID is the first ID in the range in the parent namespace.
+ FirstParentID uint32
+
+ // Length is the number of IDs in the range.
+ Length uint32
+}
+
+// SetUIDMap instructs ns to translate UIDs as specified by entries.
+//
+// Note: SetUIDMap does not place an upper bound on the number of entries, but
+// Linux does. This restriction is implemented in SetUIDMap's caller, the
+// implementation of /proc/[pid]/uid_map.
+func (ns *UserNamespace) SetUIDMap(ctx context.Context, entries []IDMapEntry) error {
+ c := CredentialsFromContext(ctx)
+
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ // "After the creation of a new user namespace, the uid_map file of *one*
+ // of the processes in the namespace may be written to *once* to define the
+ // mapping of user IDs in the new user namespace. An attempt to write more
+ // than once to a uid_map file in a user namespace fails with the error
+ // EPERM. Similar rules apply for gid_map files." - user_namespaces(7)
+ if !ns.uidMapFromParent.IsEmpty() {
+ return syserror.EPERM
+ }
+ // "At least one line must be written to the file."
+ if len(entries) == 0 {
+ return syserror.EINVAL
+ }
+ // """
+ // In order for a process to write to the /proc/[pid]/uid_map
+ // (/proc/[pid]/gid_map) file, all of the following requirements must be
+ // met:
+ //
+ // 1. The writing process must have the CAP_SETUID (CAP_SETGID) capability
+ // in the user namespace of the process pid.
+ // """
+ if !c.HasCapabilityIn(linux.CAP_SETUID, ns) {
+ return syserror.EPERM
+ }
+ // "2. The writing process must either be in the user namespace of the process
+ // pid or be in the parent user namespace of the process pid."
+ if c.UserNamespace != ns && c.UserNamespace != ns.parent {
+ return syserror.EPERM
+ }
+ // """
+ // 3. (see trySetUIDMap)
+ //
+ // 4. One of the following two cases applies:
+ //
+ // * Either the writing process has the CAP_SETUID (CAP_SETGID) capability
+ // in the parent user namespace.
+ // """
+ if !c.HasCapabilityIn(linux.CAP_SETUID, ns.parent) {
+ // """
+ // * Or otherwise all of the following restrictions apply:
+ //
+ // + The data written to uid_map (gid_map) must consist of a single line
+ // that maps the writing process' effective user ID (group ID) in the
+ // parent user namespace to a user ID (group ID) in the user namespace.
+ // """
+ if len(entries) != 1 || ns.parent.MapToKUID(UID(entries[0].FirstParentID)) != c.EffectiveKUID || entries[0].Length != 1 {
+ return syserror.EPERM
+ }
+ // """
+ // + The writing process must have the same effective user ID as the
+ // process that created the user namespace.
+ // """
+ if c.EffectiveKUID != ns.owner {
+ return syserror.EPERM
+ }
+ }
+ // trySetUIDMap leaves data in maps if it fails.
+ if err := ns.trySetUIDMap(entries); err != nil {
+ ns.uidMapFromParent.RemoveAll()
+ ns.uidMapToParent.RemoveAll()
+ return err
+ }
+ return nil
+}
+
+func (ns *UserNamespace) trySetUIDMap(entries []IDMapEntry) error {
+ for _, e := range entries {
+ // Determine upper bounds and check for overflow. This implicitly
+ // checks for NoID.
+ lastID := e.FirstID + e.Length
+ if lastID <= e.FirstID {
+ return syserror.EINVAL
+ }
+ lastParentID := e.FirstParentID + e.Length
+ if lastParentID <= e.FirstParentID {
+ return syserror.EINVAL
+ }
+ // "3. The mapped user IDs (group IDs) must in turn have a mapping in
+ // the parent user namespace."
+ // Only the root namespace has a nil parent, and root is assigned
+ // mappings when it's created, so SetUIDMap would have returned EPERM
+ // without reaching this point if ns is root.
+ if !ns.parent.allIDsMapped(&ns.parent.uidMapToParent, e.FirstParentID, lastParentID) {
+ return syserror.EPERM
+ }
+ // If either of these Adds fail, we have an overlapping range.
+ if !ns.uidMapFromParent.Add(idMapRange{e.FirstParentID, lastParentID}, e.FirstID) {
+ return syserror.EINVAL
+ }
+ if !ns.uidMapToParent.Add(idMapRange{e.FirstID, lastID}, e.FirstParentID) {
+ return syserror.EINVAL
+ }
+ }
+ return nil
+}
+
+// SetGIDMap instructs ns to translate GIDs as specified by entries.
+func (ns *UserNamespace) SetGIDMap(ctx context.Context, entries []IDMapEntry) error {
+ c := CredentialsFromContext(ctx)
+
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ if !ns.gidMapFromParent.IsEmpty() {
+ return syserror.EPERM
+ }
+ if len(entries) == 0 {
+ return syserror.EINVAL
+ }
+ if !c.HasCapabilityIn(linux.CAP_SETGID, ns) {
+ return syserror.EPERM
+ }
+ if c.UserNamespace != ns && c.UserNamespace != ns.parent {
+ return syserror.EPERM
+ }
+ if !c.HasCapabilityIn(linux.CAP_SETGID, ns.parent) {
+ if len(entries) != 1 || ns.parent.MapToKGID(GID(entries[0].FirstParentID)) != c.EffectiveKGID || entries[0].Length != 1 {
+ return syserror.EPERM
+ }
+ // It's correct for this to still be UID.
+ if c.EffectiveKUID != ns.owner {
+ return syserror.EPERM
+ }
+ // "In the case of gid_map, use of the setgroups(2) system call must
+ // first be denied by writing "deny" to the /proc/[pid]/setgroups file
+ // (see below) before writing to gid_map." (This file isn't implemented
+ // in the version of Linux we're emulating; see comment in
+ // UserNamespace.)
+ }
+ if err := ns.trySetGIDMap(entries); err != nil {
+ ns.gidMapFromParent.RemoveAll()
+ ns.gidMapToParent.RemoveAll()
+ return err
+ }
+ return nil
+}
+
+func (ns *UserNamespace) trySetGIDMap(entries []IDMapEntry) error {
+ for _, e := range entries {
+ lastID := e.FirstID + e.Length
+ if lastID <= e.FirstID {
+ return syserror.EINVAL
+ }
+ lastParentID := e.FirstParentID + e.Length
+ if lastParentID <= e.FirstParentID {
+ return syserror.EINVAL
+ }
+ if !ns.parent.allIDsMapped(&ns.parent.gidMapToParent, e.FirstParentID, lastParentID) {
+ return syserror.EPERM
+ }
+ if !ns.gidMapFromParent.Add(idMapRange{e.FirstParentID, lastParentID}, e.FirstID) {
+ return syserror.EINVAL
+ }
+ if !ns.gidMapToParent.Add(idMapRange{e.FirstID, lastID}, e.FirstParentID) {
+ return syserror.EINVAL
+ }
+ }
+ return nil
+}
+
+// UIDMap returns the user ID mappings configured for ns. If no mappings
+// have been configured, UIDMap returns nil.
+func (ns *UserNamespace) UIDMap() []IDMapEntry {
+ return ns.getIDMap(&ns.uidMapToParent)
+}
+
+// GIDMap returns the group ID mappings configured for ns. If no mappings
+// have been configured, GIDMap returns nil.
+func (ns *UserNamespace) GIDMap() []IDMapEntry {
+ return ns.getIDMap(&ns.gidMapToParent)
+}
+
+func (ns *UserNamespace) getIDMap(m *idMapSet) []IDMapEntry {
+ ns.mu.Lock()
+ defer ns.mu.Unlock()
+ var entries []IDMapEntry
+ for it := m.FirstSegment(); it.Ok(); it = it.NextSegment() {
+ entries = append(entries, IDMapEntry{
+ FirstID: it.Start(),
+ FirstParentID: it.Value(),
+ Length: it.Range().Length(),
+ })
+ }
+ return entries
+}
diff --git a/pkg/sentry/kernel/auth/id_map_functions.go b/pkg/sentry/kernel/auth/id_map_functions.go
new file mode 100644
index 000000000..432dbfb6d
--- /dev/null
+++ b/pkg/sentry/kernel/auth/id_map_functions.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+// idMapFunctions "implements" generic interface segment.Functions for
+// idMapSet. An idMapSet maps non-overlapping ranges of contiguous IDs in one
+// user namespace to non-overlapping ranges of contiguous IDs in another user
+// namespace. Each such ID mapping is implemented as a range-to-value mapping
+// in the set such that [range.Start(), range.End()) => [value, value +
+// range.Length()).
+type idMapFunctions struct{}
+
+func (idMapFunctions) MinKey() uint32 {
+ return 0
+}
+
+func (idMapFunctions) MaxKey() uint32 {
+ return NoID
+}
+
+func (idMapFunctions) ClearValue(*uint32) {}
+
+func (idMapFunctions) Merge(r1 idMapRange, val1 uint32, r2 idMapRange, val2 uint32) (uint32, bool) {
+ // Mapped ranges have to be contiguous.
+ if val1+r1.Length() != val2 {
+ return 0, false
+ }
+ return val1, true
+}
+
+func (idMapFunctions) Split(r idMapRange, val uint32, split uint32) (uint32, uint32) {
+ return val, val + (split - r.Start)
+}
diff --git a/pkg/sentry/kernel/auth/id_map_range.go b/pkg/sentry/kernel/auth/id_map_range.go
new file mode 100755
index 000000000..833fa3518
--- /dev/null
+++ b/pkg/sentry/kernel/auth/id_map_range.go
@@ -0,0 +1,62 @@
+package auth
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type idMapRange struct {
+ // Start is the inclusive start of the range.
+ Start uint32
+
+ // End is the exclusive end of the range.
+ End uint32
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r idMapRange) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r idMapRange) Length() uint32 {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r idMapRange) Contains(x uint32) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r idMapRange) Overlaps(r2 idMapRange) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r idMapRange) IsSupersetOf(r2 idMapRange) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+func (r idMapRange) Intersect(r2 idMapRange) idMapRange {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+func (r idMapRange) CanSplitAt(x uint32) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/sentry/kernel/auth/id_map_set.go b/pkg/sentry/kernel/auth/id_map_set.go
new file mode 100755
index 000000000..f72c839c7
--- /dev/null
+++ b/pkg/sentry/kernel/auth/id_map_set.go
@@ -0,0 +1,1270 @@
+package auth
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ idMapminDegree = 3
+
+ idMapmaxDegree = 2 * idMapminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type idMapSet struct {
+ root idMapnode `state:".(*idMapSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *idMapSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *idMapSet) IsEmptyRange(r idMapRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *idMapSet) Span() uint32 {
+ var sz uint32
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *idMapSet) SpanRange(r idMapRange) uint32 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint32
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *idMapSet) FirstSegment() idMapIterator {
+ if s.root.nrSegments == 0 {
+ return idMapIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *idMapSet) LastSegment() idMapIterator {
+ if s.root.nrSegments == 0 {
+ return idMapIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *idMapSet) FirstGap() idMapGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return idMapGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *idMapSet) LastGap() idMapGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return idMapGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *idMapSet) Find(key uint32) (idMapIterator, idMapGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return idMapIterator{n, i}, idMapGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return idMapIterator{}, idMapGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *idMapSet) FindSegment(key uint32) idMapIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *idMapSet) LowerBoundSegment(min uint32) idMapIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *idMapSet) UpperBoundSegment(max uint32) idMapIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *idMapSet) FindGap(key uint32) idMapGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *idMapSet) LowerBoundGap(min uint32) idMapGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *idMapSet) UpperBoundGap(max uint32) idMapGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *idMapSet) Add(r idMapRange, val uint32) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *idMapSet) AddWithoutMerging(r idMapRange, val uint32) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *idMapSet) Insert(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (idMapFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (idMapFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (idMapFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *idMapSet) InsertWithoutMerging(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *idMapSet) InsertWithoutMergingUnchecked(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return idMapIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *idMapSet) Remove(seg idMapIterator) idMapGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ idMapFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(idMapGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *idMapSet) RemoveAll() {
+ s.root = idMapnode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *idMapSet) RemoveRange(r idMapRange) idMapGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *idMapSet) Merge(first, second idMapIterator) idMapIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *idMapSet) MergeUnchecked(first, second idMapIterator) idMapIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (idMapFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return idMapIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *idMapSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *idMapSet) MergeRange(r idMapRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *idMapSet) MergeAdjacent(r idMapRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *idMapSet) Split(seg idMapIterator, split uint32) (idMapIterator, idMapIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *idMapSet) SplitUnchecked(seg idMapIterator, split uint32) (idMapIterator, idMapIterator) {
+ val1, val2 := (idMapFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), idMapRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *idMapSet) SplitAt(split uint32) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *idMapSet) Isolate(seg idMapIterator, r idMapRange) idMapIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *idMapSet) ApplyContiguous(r idMapRange, fn func(seg idMapIterator)) idMapGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return idMapGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return idMapGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type idMapnode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *idMapnode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [idMapmaxDegree - 1]idMapRange
+ values [idMapmaxDegree - 1]uint32
+ children [idMapmaxDegree]*idMapnode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *idMapnode) firstSegment() idMapIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return idMapIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *idMapnode) lastSegment() idMapIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return idMapIterator{n, n.nrSegments - 1}
+}
+
+func (n *idMapnode) prevSibling() *idMapnode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *idMapnode) nextSibling() *idMapnode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *idMapnode) rebalanceBeforeInsert(gap idMapGapIterator) idMapGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < idMapmaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &idMapnode{
+ nrSegments: idMapminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &idMapnode{
+ nrSegments: idMapminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:idMapminDegree-1], n.keys[:idMapminDegree-1])
+ copy(left.values[:idMapminDegree-1], n.values[:idMapminDegree-1])
+ copy(right.keys[:idMapminDegree-1], n.keys[idMapminDegree:])
+ copy(right.values[:idMapminDegree-1], n.values[idMapminDegree:])
+ n.keys[0], n.values[0] = n.keys[idMapminDegree-1], n.values[idMapminDegree-1]
+ idMapzeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:idMapminDegree], n.children[:idMapminDegree])
+ copy(right.children[:idMapminDegree], n.children[idMapminDegree:])
+ idMapzeroNodeSlice(n.children[2:])
+ for i := 0; i < idMapminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < idMapminDegree {
+ return idMapGapIterator{left, gap.index}
+ }
+ return idMapGapIterator{right, gap.index - idMapminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[idMapminDegree-1], n.values[idMapminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &idMapnode{
+ nrSegments: idMapminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:idMapminDegree-1], n.keys[idMapminDegree:])
+ copy(sibling.values[:idMapminDegree-1], n.values[idMapminDegree:])
+ idMapzeroValueSlice(n.values[idMapminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:idMapminDegree], n.children[idMapminDegree:])
+ idMapzeroNodeSlice(n.children[idMapminDegree:])
+ for i := 0; i < idMapminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = idMapminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < idMapminDegree {
+ return gap
+ }
+ return idMapGapIterator{sibling, gap.index - idMapminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *idMapnode) rebalanceAfterRemove(gap idMapGapIterator) idMapGapIterator {
+ for {
+ if n.nrSegments >= idMapminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= idMapminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ idMapFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return idMapGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return idMapGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= idMapminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ idMapFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return idMapGapIterator{n, n.nrSegments}
+ }
+ return idMapGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return idMapGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return idMapGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *idMapnode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = idMapGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ idMapFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type idMapIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *idMapnode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg idMapIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg idMapIterator) Range() idMapRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg idMapIterator) Start() uint32 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg idMapIterator) End() uint32 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg idMapIterator) SetRangeUnchecked(r idMapRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg idMapIterator) SetRange(r idMapRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg idMapIterator) SetStartUnchecked(start uint32) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg idMapIterator) SetStart(start uint32) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg idMapIterator) SetEndUnchecked(end uint32) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg idMapIterator) SetEnd(end uint32) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg idMapIterator) Value() uint32 {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg idMapIterator) ValuePtr() *uint32 {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg idMapIterator) SetValue(val uint32) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg idMapIterator) PrevSegment() idMapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return idMapIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return idMapIterator{}
+ }
+ return idMapsegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg idMapIterator) NextSegment() idMapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return idMapIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return idMapIterator{}
+ }
+ return idMapsegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg idMapIterator) PrevGap() idMapGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return idMapGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg idMapIterator) NextGap() idMapGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return idMapGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg idMapIterator) PrevNonEmpty() (idMapIterator, idMapGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return idMapIterator{}, gap
+ }
+ return gap.PrevSegment(), idMapGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg idMapIterator) NextNonEmpty() (idMapIterator, idMapGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return idMapIterator{}, gap
+ }
+ return gap.NextSegment(), idMapGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type idMapGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *idMapnode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap idMapGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap idMapGapIterator) Range() idMapRange {
+ return idMapRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap idMapGapIterator) Start() uint32 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return idMapFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap idMapGapIterator) End() uint32 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return idMapFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap idMapGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap idMapGapIterator) PrevSegment() idMapIterator {
+ return idMapsegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap idMapGapIterator) NextSegment() idMapIterator {
+ return idMapsegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap idMapGapIterator) PrevGap() idMapGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return idMapGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap idMapGapIterator) NextGap() idMapGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return idMapGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func idMapsegmentBeforePosition(n *idMapnode, i int) idMapIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return idMapIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return idMapIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func idMapsegmentAfterPosition(n *idMapnode, i int) idMapIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return idMapIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return idMapIterator{n, i}
+}
+
+func idMapzeroValueSlice(slice []uint32) {
+
+ for i := range slice {
+ idMapFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func idMapzeroNodeSlice(slice []*idMapnode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *idMapSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *idMapnode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *idMapnode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type idMapSegmentDataSlices struct {
+ Start []uint32
+ End []uint32
+ Values []uint32
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *idMapSet) ExportSortedSlices() *idMapSegmentDataSlices {
+ var sds idMapSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *idMapSet) ImportSortedSlices(sds *idMapSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := idMapRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *idMapSet) saveRoot() *idMapSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *idMapSet) loadRoot(sds *idMapSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/kernel/auth/user_namespace.go b/pkg/sentry/kernel/auth/user_namespace.go
new file mode 100644
index 000000000..a40dd668f
--- /dev/null
+++ b/pkg/sentry/kernel/auth/user_namespace.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import (
+ "math"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// A UserNamespace represents a user namespace. See user_namespaces(7) for
+// details.
+//
+// +stateify savable
+type UserNamespace struct {
+ // parent is this namespace's parent. If this is the root namespace, parent
+ // is nil. The parent pointer is immutable.
+ parent *UserNamespace
+
+ // owner is the effective UID of the namespace's creator in the root
+ // namespace. owner is immutable.
+ owner KUID
+
+ // mu protects the following fields.
+ //
+ // If mu will be locked in multiple UserNamespaces, it must be locked in
+ // descendant namespaces before ancestors.
+ mu sync.Mutex `state:"nosave"`
+
+ // Mappings of user/group IDs between this namespace and its parent.
+ //
+ // All ID maps, once set, cannot be changed. This means that successful
+ // UID/GID translations cannot be racy.
+ uidMapFromParent idMapSet
+ uidMapToParent idMapSet
+ gidMapFromParent idMapSet
+ gidMapToParent idMapSet
+
+ // TODO(b/27454212): Support disabling setgroups(2).
+}
+
+// NewRootUserNamespace returns a UserNamespace that is appropriate for a
+// system's root user namespace.
+func NewRootUserNamespace() *UserNamespace {
+ var ns UserNamespace
+ // """
+ // The initial user namespace has no parent namespace, but, for
+ // consistency, the kernel provides dummy user and group ID mapping files
+ // for this namespace. Looking at the uid_map file (gid_map is the same)
+ // from a shell in the initial namespace shows:
+ //
+ // $ cat /proc/$$/uid_map
+ // 0 0 4294967295
+ // """ - user_namespaces(7)
+ for _, m := range []*idMapSet{
+ &ns.uidMapFromParent,
+ &ns.uidMapToParent,
+ &ns.gidMapFromParent,
+ &ns.gidMapToParent,
+ } {
+ if !m.Add(idMapRange{0, math.MaxUint32}, 0) {
+ panic("Failed to insert into empty ID map")
+ }
+ }
+ return &ns
+}
+
+// Root returns the root of the user namespace tree containing ns.
+func (ns *UserNamespace) Root() *UserNamespace {
+ for ns.parent != nil {
+ ns = ns.parent
+ }
+ return ns
+}
+
+// "The kernel imposes (since version 3.11) a limit of 32 nested levels of user
+// namespaces." - user_namespaces(7)
+const maxUserNamespaceDepth = 32
+
+func (ns *UserNamespace) depth() int {
+ var i int
+ for ns != nil {
+ i++
+ ns = ns.parent
+ }
+ return i
+}
+
+// NewChildUserNamespace returns a new user namespace created by a caller with
+// credentials c.
+func (c *Credentials) NewChildUserNamespace() (*UserNamespace, error) {
+ if c.UserNamespace.depth() >= maxUserNamespaceDepth {
+ // "... Calls to unshare(2) or clone(2) that would cause this limit to
+ // be exceeded fail with the error EUSERS." - user_namespaces(7)
+ return nil, syserror.EUSERS
+ }
+ // "EPERM: CLONE_NEWUSER was specified in flags, but either the effective
+ // user ID or the effective group ID of the caller does not have a mapping
+ // in the parent namespace (see user_namespaces(7))." - clone(2)
+ // "CLONE_NEWUSER requires that the user ID and group ID of the calling
+ // process are mapped to user IDs and group IDs in the user namespace of
+ // the calling process at the time of the call." - unshare(2)
+ if !c.EffectiveKUID.In(c.UserNamespace).Ok() {
+ return nil, syserror.EPERM
+ }
+ if !c.EffectiveKGID.In(c.UserNamespace).Ok() {
+ return nil, syserror.EPERM
+ }
+ return &UserNamespace{
+ parent: c.UserNamespace,
+ owner: c.EffectiveKUID,
+ // "When a user namespace is created, it starts without a mapping of
+ // user IDs (group IDs) to the parent user namespace." -
+ // user_namespaces(7)
+ }, nil
+}
diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go
new file mode 100644
index 000000000..a1a084eab
--- /dev/null
+++ b/pkg/sentry/kernel/context.go
@@ -0,0 +1,135 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// contextID is the kernel package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxCanTrace is a Context.Value key for a function with the same
+ // signature and semantics as kernel.Task.CanTrace.
+ CtxCanTrace contextID = iota
+
+ // CtxKernel is a Context.Value key for a Kernel.
+ CtxKernel
+
+ // CtxPIDNamespace is a Context.Value key for a PIDNamespace.
+ CtxPIDNamespace
+
+ // CtxTask is a Context.Value key for a Task.
+ CtxTask
+
+ // CtxUTSNamespace is a Context.Value key for a UTSNamespace.
+ CtxUTSNamespace
+
+ // CtxIPCNamespace is a Context.Value key for a IPCNamespace.
+ CtxIPCNamespace
+)
+
+// ContextCanTrace returns true if ctx is permitted to trace t, in the same sense
+// as kernel.Task.CanTrace.
+func ContextCanTrace(ctx context.Context, t *Task, attach bool) bool {
+ if v := ctx.Value(CtxCanTrace); v != nil {
+ return v.(func(*Task, bool) bool)(t, attach)
+ }
+ return false
+}
+
+// KernelFromContext returns the Kernel in which ctx is executing, or nil if
+// there is no such Kernel.
+func KernelFromContext(ctx context.Context) *Kernel {
+ if v := ctx.Value(CtxKernel); v != nil {
+ return v.(*Kernel)
+ }
+ return nil
+}
+
+// PIDNamespaceFromContext returns the PID namespace in which ctx is executing,
+// or nil if there is no such PID namespace.
+func PIDNamespaceFromContext(ctx context.Context) *PIDNamespace {
+ if v := ctx.Value(CtxPIDNamespace); v != nil {
+ return v.(*PIDNamespace)
+ }
+ return nil
+}
+
+// UTSNamespaceFromContext returns the UTS namespace in which ctx is executing,
+// or nil if there is no such UTS namespace.
+func UTSNamespaceFromContext(ctx context.Context) *UTSNamespace {
+ if v := ctx.Value(CtxUTSNamespace); v != nil {
+ return v.(*UTSNamespace)
+ }
+ return nil
+}
+
+// IPCNamespaceFromContext returns the IPC namespace in which ctx is executing,
+// or nil if there is no such IPC namespace.
+func IPCNamespaceFromContext(ctx context.Context) *IPCNamespace {
+ if v := ctx.Value(CtxIPCNamespace); v != nil {
+ return v.(*IPCNamespace)
+ }
+ return nil
+}
+
+// TaskFromContext returns the Task associated with ctx, or nil if there is no
+// such Task.
+func TaskFromContext(ctx context.Context) *Task {
+ if v := ctx.Value(CtxTask); v != nil {
+ return v.(*Task)
+ }
+ return nil
+}
+
+// AsyncContext returns a context.Context that may be used by goroutines that
+// do work on behalf of t and therefore share its contextual values, but are
+// not t's task goroutine (e.g. asynchronous I/O).
+func (t *Task) AsyncContext() context.Context {
+ return taskAsyncContext{t: t}
+}
+
+type taskAsyncContext struct {
+ context.NoopSleeper
+ t *Task
+}
+
+// Debugf implements log.Logger.Debugf.
+func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) {
+ ctx.t.Debugf(format, v...)
+}
+
+// Infof implements log.Logger.Infof.
+func (ctx taskAsyncContext) Infof(format string, v ...interface{}) {
+ ctx.t.Infof(format, v...)
+}
+
+// Warningf implements log.Logger.Warningf.
+func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) {
+ ctx.t.Warningf(format, v...)
+}
+
+// IsLogging implements log.Logger.IsLogging.
+func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
+ return ctx.t.IsLogging(level)
+}
+
+// Value implements context.Context.Value.
+func (ctx taskAsyncContext) Value(key interface{}) interface{} {
+ return ctx.t.Value(key)
+}
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
new file mode 100644
index 000000000..bbacba1f4
--- /dev/null
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -0,0 +1,473 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package epoll provides an implementation of Linux's IO event notification
+// facility. See epoll(7) for more details.
+package epoll
+
+import (
+ "fmt"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/anon"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Event describes the event mask that was observed and the user data to be
+// returned when one of the events occurs. It has this format to match the linux
+// format to avoid extra copying/allocation when writing events to userspace.
+type Event struct {
+ // Events is the event mask containing the set of events that have been
+ // observed on an entry.
+ Events uint32
+
+ // Data is an opaque 64-bit value provided by the caller when adding the
+ // entry, and returned to the caller when the entry reports an event.
+ Data [2]int32
+}
+
+// EntryFlags is a bitmask that holds an entry's flags.
+type EntryFlags int
+
+// Valid entry flags.
+const (
+ OneShot EntryFlags = 1 << iota
+ EdgeTriggered
+)
+
+// FileIdentifier identifies a file. We cannot use just the FD because it could
+// potentially be reassigned. We also cannot use just the file pointer because
+// it is possible to have multiple entries for the same file object as long as
+// they are created with different FDs (i.e., the FDs point to the same file).
+//
+// +stateify savable
+type FileIdentifier struct {
+ File *fs.File `state:"wait"`
+ Fd kdefs.FD
+}
+
+// pollEntry holds all the state associated with an event poll entry, that is,
+// a file being observed by an event poll object.
+//
+// +stateify savable
+type pollEntry struct {
+ pollEntryEntry
+ file *refs.WeakRef `state:"manual"`
+ id FileIdentifier `state:"wait"`
+ userData [2]int32
+ waiter waiter.Entry `state:"manual"`
+ mask waiter.EventMask
+ flags EntryFlags
+
+ epoll *EventPoll
+
+ // We cannot save the current list pointer as it points into EventPoll
+ // struct, while state framework currently does not support such
+ // in-struct pointers. Instead, EventPoll will properly set this field
+ // in its loading logic.
+ curList *pollEntryList `state:"nosave"`
+}
+
+// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
+// weakReferenceGone is called when the file in the weak reference is destroyed.
+// The poll entry is removed in response to this.
+func (p *pollEntry) WeakRefGone() {
+ p.epoll.RemoveEntry(p.id)
+}
+
+// EventPoll holds all the state associated with an event poll object, that is,
+// collection of files to observe and their current state.
+//
+// +stateify savable
+type EventPoll struct {
+ fsutil.FilePipeSeek `state:"zerovalue"`
+ fsutil.FileNotDirReaddir `state:"zerovalue"`
+ fsutil.FileNoFsync `state:"zerovalue"`
+ fsutil.FileNoopFlush `state:"zerovalue"`
+ fsutil.FileNoIoctl `state:"zerovalue"`
+ fsutil.FileNoMMap `state:"zerovalue"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // Wait queue is used to notify interested parties when the event poll
+ // object itself becomes readable or writable.
+ waiter.Queue `state:"zerovalue"`
+
+ // files is the map of all the files currently being observed, it is
+ // protected by mu.
+ mu sync.Mutex `state:"nosave"`
+ files map[FileIdentifier]*pollEntry
+
+ // listsMu protects manipulation of the lists below. It needs to be a
+ // different lock to avoid circular lock acquisition order involving
+ // the wait queue mutexes and mu. The full order is mu, observed file
+ // wait queue mutex, then listsMu; this allows listsMu to be acquired
+ // when readyCallback is called.
+ //
+ // An entry is always in one of the following lists:
+ // readyList -- when there's a chance that it's ready to have
+ // events delivered to epoll waiters. Given that being
+ // ready is a transient state, the Readiness() and
+ // readEvents() functions always call the entry's file
+ // Readiness() function to confirm it's ready.
+ // waitingList -- when there's no chance that the entry is ready,
+ // so it's waiting for the readyCallback to be called
+ // on it before it gets moved to the readyList.
+ // disabledList -- when the entry is disabled. This happens when
+ // a one-shot entry gets delivered via readEvents().
+ listsMu sync.Mutex `state:"nosave"`
+ readyList pollEntryList
+ waitingList pollEntryList
+ disabledList pollEntryList
+}
+
+// cycleMu is used to serialize all the cycle checks. This is only used when
+// an event poll file is added as an entry to another event poll. Such checks
+// are serialized to avoid lock acquisition order inversion: if a thread is
+// adding A to B, and another thread is adding B to A, each would acquire A's
+// and B's mutexes in reverse order, and could cause deadlocks. Having this
+// lock prevents this by allowing only one check at a time to happen.
+//
+// We do the cycle check to prevent callers from introducing potentially
+// infinite recursions. If a caller were to add A to B and then B to A, for
+// event poll A to know if it's readable, it would need to check event poll B,
+// which in turn would need event poll A and so on indefinitely.
+var cycleMu sync.Mutex
+
+// NewEventPoll allocates and initializes a new event poll object.
+func NewEventPoll(ctx context.Context) *fs.File {
+ // name matches fs/eventpoll.c:epoll_create1.
+ dirent := fs.NewDirent(anon.NewInode(ctx), fmt.Sprintf("anon_inode:[eventpoll]"))
+ return fs.NewFile(ctx, dirent, fs.FileFlags{}, &EventPoll{
+ files: make(map[FileIdentifier]*pollEntry),
+ })
+}
+
+// Release implements fs.FileOperations.Release.
+func (e *EventPoll) Release() {
+ // We need to take the lock now because files may be attempting to
+ // remove entries in parallel if they get destroyed.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Go through all entries and clean up.
+ for _, entry := range e.files {
+ entry.id.File.EventUnregister(&entry.waiter)
+ entry.file.Drop()
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (*EventPoll) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syscall.ENOSYS
+}
+
+// Write implements fs.FileOperations.Write.
+func (*EventPoll) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
+ return 0, syscall.ENOSYS
+}
+
+// eventsAvailable determines if 'e' has events available for delivery.
+func (e *EventPoll) eventsAvailable() bool {
+ e.listsMu.Lock()
+
+ for it := e.readyList.Front(); it != nil; {
+ entry := it
+ it = it.Next()
+
+ // If the entry is ready, we know 'e' has at least one entry
+ // ready for delivery.
+ ready := entry.id.File.Readiness(entry.mask)
+ if ready != 0 {
+ e.listsMu.Unlock()
+ return true
+ }
+
+ // Entry is not ready, so move it to waiting list.
+ e.readyList.Remove(entry)
+ e.waitingList.PushBack(entry)
+ entry.curList = &e.waitingList
+ }
+
+ e.listsMu.Unlock()
+
+ return false
+}
+
+// Readiness determines if the event poll object is currently readable (i.e.,
+// if there are pending events for delivery).
+func (e *EventPoll) Readiness(mask waiter.EventMask) waiter.EventMask {
+ ready := waiter.EventMask(0)
+
+ if (mask&waiter.EventIn) != 0 && e.eventsAvailable() {
+ ready |= waiter.EventIn
+ }
+
+ return ready
+}
+
+// ReadEvents returns up to max available events.
+func (e *EventPoll) ReadEvents(max int) []Event {
+ var local pollEntryList
+ var ret []Event
+
+ e.listsMu.Lock()
+
+ // Go through all entries we believe may be ready.
+ for it := e.readyList.Front(); it != nil && len(ret) < max; {
+ entry := it
+ it = it.Next()
+
+ // Check the entry's readiness. It it's not really ready, we
+ // just put it back in the waiting list and move on to the next
+ // entry.
+ ready := entry.id.File.Readiness(entry.mask) & entry.mask
+ if ready == 0 {
+ e.readyList.Remove(entry)
+ e.waitingList.PushBack(entry)
+ entry.curList = &e.waitingList
+
+ continue
+ }
+
+ // Add event to the array that will be returned to caller.
+ ret = append(ret, Event{
+ Events: uint32(ready),
+ Data: entry.userData,
+ })
+
+ // The entry is consumed, so we must move it to the disabled
+ // list in case it's one-shot, or back to the wait list if it's
+ // edge-triggered. If it's neither, we leave it in the ready
+ // list so that its readiness can be checked the next time
+ // around; however, we must move it to the end of the list so
+ // that other events can be delivered as well.
+ e.readyList.Remove(entry)
+ if entry.flags&OneShot != 0 {
+ e.disabledList.PushBack(entry)
+ entry.curList = &e.disabledList
+ } else if entry.flags&EdgeTriggered != 0 {
+ e.waitingList.PushBack(entry)
+ entry.curList = &e.waitingList
+ } else {
+ local.PushBack(entry)
+ }
+ }
+
+ e.readyList.PushBackList(&local)
+
+ e.listsMu.Unlock()
+
+ return ret
+}
+
+// readyCallback is called when one of the files we're polling becomes ready. It
+// moves said file to the readyList if it's currently in the waiting list.
+type readyCallback struct{}
+
+// Callback implements waiter.EntryCallback.Callback.
+func (*readyCallback) Callback(w *waiter.Entry) {
+ entry := w.Context.(*pollEntry)
+ e := entry.epoll
+
+ e.listsMu.Lock()
+
+ if entry.curList == &e.waitingList {
+ e.waitingList.Remove(entry)
+ e.readyList.PushBack(entry)
+ entry.curList = &e.readyList
+
+ e.Notify(waiter.EventIn)
+ }
+
+ e.listsMu.Unlock()
+}
+
+// initEntryReadiness initializes the entry's state with regards to its
+// readiness by placing it in the appropriate list and registering for
+// notifications.
+func (e *EventPoll) initEntryReadiness(entry *pollEntry) {
+ // A new entry starts off in the waiting list.
+ e.listsMu.Lock()
+ e.waitingList.PushBack(entry)
+ entry.curList = &e.waitingList
+ e.listsMu.Unlock()
+
+ // Register for event notifications.
+ f := entry.id.File
+ f.EventRegister(&entry.waiter, entry.mask)
+
+ // Check if the file happens to already be in a ready state.
+ ready := f.Readiness(entry.mask) & entry.mask
+ if ready != 0 {
+ (*readyCallback).Callback(nil, &entry.waiter)
+ }
+}
+
+// observes checks if event poll object e is directly or indirectly observing
+// event poll object ep. It uses a bounded recursive depth-first search.
+func (e *EventPoll) observes(ep *EventPoll, depthLeft int) bool {
+ // If we reached the maximum depth, we'll consider that we found it
+ // because we don't want to allow chains that are too long.
+ if depthLeft <= 0 {
+ return true
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Go through each observed file and check if it is or observes ep.
+ for id := range e.files {
+ f, ok := id.File.FileOperations.(*EventPoll)
+ if !ok {
+ continue
+ }
+
+ if f == ep || f.observes(ep, depthLeft-1) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// AddEntry adds a new file to the collection of files observed by e.
+func (e *EventPoll) AddEntry(id FileIdentifier, flags EntryFlags, mask waiter.EventMask, data [2]int32) error {
+ // Acquire cycle check lock if another event poll is being added.
+ ep, ok := id.File.FileOperations.(*EventPoll)
+ if ok {
+ cycleMu.Lock()
+ defer cycleMu.Unlock()
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Fail if the file already has an entry.
+ if _, ok := e.files[id]; ok {
+ return syscall.EEXIST
+ }
+
+ // Check if a cycle would be created. We use 4 as the limit because
+ // that's the value used by linux and we want to emulate it.
+ if ep != nil {
+ if e == ep {
+ return syscall.EINVAL
+ }
+
+ if ep.observes(e, 4) {
+ return syscall.ELOOP
+ }
+ }
+
+ // Create new entry and add it to map.
+ //
+ // N.B. Even though we are creating a weak reference here, we know it
+ // won't trigger a callback because we hold a reference to the file
+ // throughout the execution of this function.
+ entry := &pollEntry{
+ id: id,
+ userData: data,
+ epoll: e,
+ flags: flags,
+ waiter: waiter.Entry{Callback: &readyCallback{}},
+ mask: mask,
+ }
+ entry.waiter.Context = entry
+ e.files[id] = entry
+ entry.file = refs.NewWeakRef(id.File, entry)
+
+ // Initialize the readiness state of the new entry.
+ e.initEntryReadiness(entry)
+
+ return nil
+}
+
+// UpdateEntry updates the flags, mask and user data associated with a file that
+// is already part of the collection of observed files.
+func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter.EventMask, data [2]int32) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Fail if the file doesn't have an entry.
+ entry, ok := e.files[id]
+ if !ok {
+ return syscall.ENOENT
+ }
+
+ // Unregister the old mask and remove entry from the list it's in, so
+ // readyCallback is guaranteed to not be called on this entry anymore.
+ entry.id.File.EventUnregister(&entry.waiter)
+
+ // Remove entry from whatever list it's in. This ensure that no other
+ // threads have access to this entry as the only way left to find it
+ // is via e.files, but we hold e.mu, which prevents that.
+ e.listsMu.Lock()
+ entry.curList.Remove(entry)
+ e.listsMu.Unlock()
+
+ // Initialize new readiness state.
+ entry.flags = flags
+ entry.mask = mask
+ entry.userData = data
+ e.initEntryReadiness(entry)
+
+ return nil
+}
+
+// RemoveEntry a files from the collection of observed files.
+func (e *EventPoll) RemoveEntry(id FileIdentifier) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Fail if the file doesn't have an entry.
+ entry, ok := e.files[id]
+ if !ok {
+ return syscall.ENOENT
+ }
+
+ // Unregister from file first so that no concurrent attempts will be
+ // made to manipulate the file.
+ entry.id.File.EventUnregister(&entry.waiter)
+
+ // Remove from the current list.
+ e.listsMu.Lock()
+ entry.curList.Remove(entry)
+ entry.curList = nil
+ e.listsMu.Unlock()
+
+ // Remove file from map, and drop weak reference.
+ delete(e.files, id)
+ entry.file.Drop()
+
+ return nil
+}
+
+// UnregisterEpollWaiters removes the epoll waiter objects from the waiting
+// queues. This is different from Release() as the file is not dereferenced.
+func (e *EventPoll) UnregisterEpollWaiters() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for _, entry := range e.files {
+ entry.id.File.EventUnregister(&entry.waiter)
+ }
+}
diff --git a/pkg/sentry/kernel/epoll/epoll_list.go b/pkg/sentry/kernel/epoll/epoll_list.go
new file mode 100755
index 000000000..94d5c9e57
--- /dev/null
+++ b/pkg/sentry/kernel/epoll/epoll_list.go
@@ -0,0 +1,173 @@
+package epoll
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type pollEntryElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (pollEntryElementMapper) linkerFor(elem *pollEntry) *pollEntry { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type pollEntryList struct {
+ head *pollEntry
+ tail *pollEntry
+}
+
+// Reset resets list l to the empty state.
+func (l *pollEntryList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *pollEntryList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *pollEntryList) Front() *pollEntry {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *pollEntryList) Back() *pollEntry {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *pollEntryList) PushFront(e *pollEntry) {
+ pollEntryElementMapper{}.linkerFor(e).SetNext(l.head)
+ pollEntryElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ pollEntryElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *pollEntryList) PushBack(e *pollEntry) {
+ pollEntryElementMapper{}.linkerFor(e).SetNext(nil)
+ pollEntryElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ pollEntryElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *pollEntryList) PushBackList(m *pollEntryList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ pollEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ pollEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *pollEntryList) InsertAfter(b, e *pollEntry) {
+ a := pollEntryElementMapper{}.linkerFor(b).Next()
+ pollEntryElementMapper{}.linkerFor(e).SetNext(a)
+ pollEntryElementMapper{}.linkerFor(e).SetPrev(b)
+ pollEntryElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ pollEntryElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *pollEntryList) InsertBefore(a, e *pollEntry) {
+ b := pollEntryElementMapper{}.linkerFor(a).Prev()
+ pollEntryElementMapper{}.linkerFor(e).SetNext(a)
+ pollEntryElementMapper{}.linkerFor(e).SetPrev(b)
+ pollEntryElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ pollEntryElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *pollEntryList) Remove(e *pollEntry) {
+ prev := pollEntryElementMapper{}.linkerFor(e).Prev()
+ next := pollEntryElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ pollEntryElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ pollEntryElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type pollEntryEntry struct {
+ next *pollEntry
+ prev *pollEntry
+}
+
+// Next returns the entry that follows e in the list.
+func (e *pollEntryEntry) Next() *pollEntry {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *pollEntryEntry) Prev() *pollEntry {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *pollEntryEntry) SetNext(elem *pollEntry) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *pollEntryEntry) SetPrev(elem *pollEntry) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go
new file mode 100644
index 000000000..4c3c38f9e
--- /dev/null
+++ b/pkg/sentry/kernel/epoll/epoll_state.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epoll
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// afterLoad is invoked by stateify.
+func (p *pollEntry) afterLoad() {
+ p.waiter = waiter.Entry{Callback: &readyCallback{}}
+ p.waiter.Context = p
+ p.file = refs.NewWeakRef(p.id.File, p)
+ p.id.File.EventRegister(&p.waiter, p.mask)
+}
+
+// afterLoad is invoked by stateify.
+func (e *EventPoll) afterLoad() {
+ e.listsMu.Lock()
+ defer e.listsMu.Unlock()
+
+ for _, ls := range []*pollEntryList{&e.waitingList, &e.readyList, &e.disabledList} {
+ for it := ls.Front(); it != nil; it = it.Next() {
+ it.curList = ls
+ }
+ }
+
+ for it := e.waitingList.Front(); it != nil; it = it.Next() {
+ if it.id.File.Readiness(it.mask) != 0 {
+ e.waitingList.Remove(it)
+ e.readyList.PushBack(it)
+ it.curList = &e.readyList
+ e.Notify(waiter.EventIn)
+ }
+ }
+}
diff --git a/pkg/sentry/kernel/epoll/epoll_state_autogen.go b/pkg/sentry/kernel/epoll/epoll_state_autogen.go
new file mode 100755
index 000000000..a361ff37b
--- /dev/null
+++ b/pkg/sentry/kernel/epoll/epoll_state_autogen.go
@@ -0,0 +1,99 @@
+// automatically generated by stateify.
+
+package epoll
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *FileIdentifier) beforeSave() {}
+func (x *FileIdentifier) save(m state.Map) {
+ x.beforeSave()
+ m.Save("File", &x.File)
+ m.Save("Fd", &x.Fd)
+}
+
+func (x *FileIdentifier) afterLoad() {}
+func (x *FileIdentifier) load(m state.Map) {
+ m.LoadWait("File", &x.File)
+ m.Load("Fd", &x.Fd)
+}
+
+func (x *pollEntry) beforeSave() {}
+func (x *pollEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("pollEntryEntry", &x.pollEntryEntry)
+ m.Save("id", &x.id)
+ m.Save("userData", &x.userData)
+ m.Save("mask", &x.mask)
+ m.Save("flags", &x.flags)
+ m.Save("epoll", &x.epoll)
+}
+
+func (x *pollEntry) load(m state.Map) {
+ m.Load("pollEntryEntry", &x.pollEntryEntry)
+ m.LoadWait("id", &x.id)
+ m.Load("userData", &x.userData)
+ m.Load("mask", &x.mask)
+ m.Load("flags", &x.flags)
+ m.Load("epoll", &x.epoll)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *EventPoll) beforeSave() {}
+func (x *EventPoll) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.FilePipeSeek) { m.Failf("FilePipeSeek is %v, expected zero", x.FilePipeSeek) }
+ if !state.IsZeroValue(x.FileNotDirReaddir) { m.Failf("FileNotDirReaddir is %v, expected zero", x.FileNotDirReaddir) }
+ if !state.IsZeroValue(x.FileNoFsync) { m.Failf("FileNoFsync is %v, expected zero", x.FileNoFsync) }
+ if !state.IsZeroValue(x.FileNoopFlush) { m.Failf("FileNoopFlush is %v, expected zero", x.FileNoopFlush) }
+ if !state.IsZeroValue(x.FileNoIoctl) { m.Failf("FileNoIoctl is %v, expected zero", x.FileNoIoctl) }
+ if !state.IsZeroValue(x.FileNoMMap) { m.Failf("FileNoMMap is %v, expected zero", x.FileNoMMap) }
+ if !state.IsZeroValue(x.Queue) { m.Failf("Queue is %v, expected zero", x.Queue) }
+ m.Save("files", &x.files)
+ m.Save("readyList", &x.readyList)
+ m.Save("waitingList", &x.waitingList)
+ m.Save("disabledList", &x.disabledList)
+}
+
+func (x *EventPoll) load(m state.Map) {
+ m.Load("files", &x.files)
+ m.Load("readyList", &x.readyList)
+ m.Load("waitingList", &x.waitingList)
+ m.Load("disabledList", &x.disabledList)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *pollEntryList) beforeSave() {}
+func (x *pollEntryList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *pollEntryList) afterLoad() {}
+func (x *pollEntryList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *pollEntryEntry) beforeSave() {}
+func (x *pollEntryEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *pollEntryEntry) afterLoad() {}
+func (x *pollEntryEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("epoll.FileIdentifier", (*FileIdentifier)(nil), state.Fns{Save: (*FileIdentifier).save, Load: (*FileIdentifier).load})
+ state.Register("epoll.pollEntry", (*pollEntry)(nil), state.Fns{Save: (*pollEntry).save, Load: (*pollEntry).load})
+ state.Register("epoll.EventPoll", (*EventPoll)(nil), state.Fns{Save: (*EventPoll).save, Load: (*EventPoll).load})
+ state.Register("epoll.pollEntryList", (*pollEntryList)(nil), state.Fns{Save: (*pollEntryList).save, Load: (*pollEntryList).load})
+ state.Register("epoll.pollEntryEntry", (*pollEntryEntry)(nil), state.Fns{Save: (*pollEntryEntry).save, Load: (*pollEntryEntry).load})
+}
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
new file mode 100644
index 000000000..2f900be38
--- /dev/null
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -0,0 +1,283 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package eventfd provides an implementation of Linux's file-based event
+// notification.
+package eventfd
+
+import (
+ "math"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/fdnotifier"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/anon"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// EventOperations represents an event with the semantics of Linux's file-based event
+// notification (eventfd). Eventfds are usually internal to the Sentry but in certain
+// situations they may be converted into a host-backed eventfd.
+//
+// +stateify savable
+type EventOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // Mutex that protects accesses to the fields of this event.
+ mu sync.Mutex `state:"nosave"`
+
+ // Queue is used to notify interested parties when the event object
+ // becomes readable or writable.
+ wq waiter.Queue `state:"zerovalue"`
+
+ // val is the current value of the event counter.
+ val uint64
+
+ // semMode specifies whether the event is in "semaphore" mode.
+ semMode bool
+
+ // hostfd indicates whether this eventfd is passed through to the host.
+ hostfd int
+}
+
+// New creates a new event object with the supplied initial value and mode.
+func New(ctx context.Context, initVal uint64, semMode bool) *fs.File {
+ // name matches fs/eventfd.c:eventfd_file_create.
+ dirent := fs.NewDirent(anon.NewInode(ctx), "anon_inode:[eventfd]")
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &EventOperations{
+ val: initVal,
+ semMode: semMode,
+ hostfd: -1,
+ })
+}
+
+// HostFD returns the host eventfd associated with this event.
+func (e *EventOperations) HostFD() (int, error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.hostfd >= 0 {
+ return e.hostfd, nil
+ }
+
+ flags := linux.EFD_NONBLOCK
+ if e.semMode {
+ flags |= linux.EFD_SEMAPHORE
+ }
+
+ fd, _, err := syscall.Syscall(syscall.SYS_EVENTFD2, uintptr(e.val), uintptr(flags), 0)
+ if err != 0 {
+ return -1, err
+ }
+
+ if err := fdnotifier.AddFD(int32(fd), &e.wq); err != nil {
+ syscall.Close(int(fd))
+ return -1, err
+ }
+
+ e.hostfd = int(fd)
+ return e.hostfd, nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (e *EventOperations) Release() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.hostfd >= 0 {
+ fdnotifier.RemoveFD(int32(e.hostfd))
+ syscall.Close(e.hostfd)
+ e.hostfd = -1
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (e *EventOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() < 8 {
+ return 0, syscall.EINVAL
+ }
+ if err := e.read(ctx, dst); err != nil {
+ return 0, err
+ }
+ return 8, nil
+}
+
+// Write implements fs.FileOperations.Write.
+func (e *EventOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ if src.NumBytes() < 8 {
+ return 0, syscall.EINVAL
+ }
+ if err := e.write(ctx, src); err != nil {
+ return 0, err
+ }
+ return 8, nil
+}
+
+// Must be called with e.mu locked.
+func (e *EventOperations) hostRead(ctx context.Context, dst usermem.IOSequence) error {
+ var buf [8]byte
+
+ if _, err := syscall.Read(e.hostfd, buf[:]); err != nil {
+ if err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+ }
+
+ _, err := dst.CopyOut(ctx, buf[:])
+ return err
+}
+
+func (e *EventOperations) read(ctx context.Context, dst usermem.IOSequence) error {
+ e.mu.Lock()
+
+ if e.hostfd >= 0 {
+ defer e.mu.Unlock()
+ return e.hostRead(ctx, dst)
+ }
+
+ // We can't complete the read if the value is currently zero.
+ if e.val == 0 {
+ e.mu.Unlock()
+ return syserror.ErrWouldBlock
+ }
+
+ // Update the value based on the mode the event is operating in.
+ var val uint64
+ if e.semMode {
+ val = 1
+ // Consistent with Linux, this is done even if writing to memory fails.
+ e.val--
+ } else {
+ val = e.val
+ e.val = 0
+ }
+
+ e.mu.Unlock()
+
+ // Notify writers. We do this even if we were already writable because
+ // it is possible that a writer is waiting to write the maximum value
+ // to the event.
+ e.wq.Notify(waiter.EventOut)
+
+ var buf [8]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ _, err := dst.CopyOut(ctx, buf[:])
+ return err
+}
+
+// Must be called with e.mu locked.
+func (e *EventOperations) hostWrite(val uint64) error {
+ var buf [8]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ _, err := syscall.Write(e.hostfd, buf[:])
+ if err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+}
+
+func (e *EventOperations) write(ctx context.Context, src usermem.IOSequence) error {
+ var buf [8]byte
+ if _, err := src.CopyIn(ctx, buf[:]); err != nil {
+ return err
+ }
+ val := usermem.ByteOrder.Uint64(buf[:])
+
+ return e.Signal(val)
+}
+
+// Signal is an internal function to signal the event fd.
+func (e *EventOperations) Signal(val uint64) error {
+ if val == math.MaxUint64 {
+ return syscall.EINVAL
+ }
+
+ e.mu.Lock()
+
+ if e.hostfd >= 0 {
+ defer e.mu.Unlock()
+ return e.hostWrite(val)
+ }
+
+ // We only allow writes that won't cause the value to go over the max
+ // uint64 minus 1.
+ if val > math.MaxUint64-1-e.val {
+ e.mu.Unlock()
+ return syserror.ErrWouldBlock
+ }
+
+ e.val += val
+ e.mu.Unlock()
+
+ // Always trigger a notification.
+ e.wq.Notify(waiter.EventIn)
+
+ return nil
+}
+
+// Readiness returns the ready events for the event fd.
+func (e *EventOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.mu.Lock()
+ if e.hostfd >= 0 {
+ defer e.mu.Unlock()
+ return fdnotifier.NonBlockingPoll(int32(e.hostfd), mask)
+ }
+
+ ready := waiter.EventMask(0)
+ if e.val > 0 {
+ ready |= waiter.EventIn
+ }
+
+ if e.val < math.MaxUint64-1 {
+ ready |= waiter.EventOut
+ }
+ e.mu.Unlock()
+
+ return mask & ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (e *EventOperations) EventRegister(entry *waiter.Entry, mask waiter.EventMask) {
+ e.wq.EventRegister(entry, mask)
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.hostfd >= 0 {
+ fdnotifier.UpdateFD(int32(e.hostfd))
+ }
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (e *EventOperations) EventUnregister(entry *waiter.Entry) {
+ e.wq.EventUnregister(entry)
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.hostfd >= 0 {
+ fdnotifier.UpdateFD(int32(e.hostfd))
+ }
+}
diff --git a/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go b/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go
new file mode 100755
index 000000000..922ff1b73
--- /dev/null
+++ b/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go
@@ -0,0 +1,27 @@
+// automatically generated by stateify.
+
+package eventfd
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *EventOperations) beforeSave() {}
+func (x *EventOperations) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.wq) { m.Failf("wq is %v, expected zero", x.wq) }
+ m.Save("val", &x.val)
+ m.Save("semMode", &x.semMode)
+ m.Save("hostfd", &x.hostfd)
+}
+
+func (x *EventOperations) afterLoad() {}
+func (x *EventOperations) load(m state.Map) {
+ m.Load("val", &x.val)
+ m.Load("semMode", &x.semMode)
+ m.Load("hostfd", &x.hostfd)
+}
+
+func init() {
+ state.Register("eventfd.EventOperations", (*EventOperations)(nil), state.Fns{Save: (*EventOperations).save, Load: (*EventOperations).load})
+}
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
new file mode 100644
index 000000000..84cd08501
--- /dev/null
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -0,0 +1,148 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fasync provides FIOASYNC related functionality.
+package fasync
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// New creates a new FileAsync.
+func New() fs.FileAsync {
+ return &FileAsync{}
+}
+
+// FileAsync sends signals when the registered file is ready for IO.
+//
+// +stateify savable
+type FileAsync struct {
+ mu sync.Mutex `state:"nosave"`
+ e waiter.Entry
+ requester *auth.Credentials
+
+ // Only one of the following is allowed to be non-nil.
+ recipientPG *kernel.ProcessGroup
+ recipientTG *kernel.ThreadGroup
+ recipientT *kernel.Task
+}
+
+// Callback sends a signal.
+func (a *FileAsync) Callback(e *waiter.Entry) {
+ a.mu.Lock()
+ if a.e.Callback == nil {
+ a.mu.Unlock()
+ return
+ }
+ t := a.recipientT
+ tg := a.recipientTG
+ if a.recipientPG != nil {
+ tg = a.recipientPG.Originator()
+ }
+ if tg != nil {
+ t = tg.Leader()
+ }
+ if t == nil {
+ // No recipient has been registered.
+ a.mu.Unlock()
+ return
+ }
+ c := t.Credentials()
+ // Logic from sigio_perm in fs/fcntl.c.
+ if a.requester.EffectiveKUID == 0 ||
+ a.requester.EffectiveKUID == c.SavedKUID ||
+ a.requester.EffectiveKUID == c.RealKUID ||
+ a.requester.RealKUID == c.SavedKUID ||
+ a.requester.RealKUID == c.RealKUID {
+ t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO))
+ }
+ a.mu.Unlock()
+}
+
+// Register sets the file which will be monitored for IO events.
+//
+// The file must not be currently registered.
+func (a *FileAsync) Register(w waiter.Waitable) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if a.e.Callback != nil {
+ panic("registering already registered file")
+ }
+
+ a.e.Callback = a
+ w.EventRegister(&a.e, waiter.EventIn|waiter.EventOut|waiter.EventErr|waiter.EventHUp)
+}
+
+// Unregister stops monitoring a file.
+//
+// The file must be currently registered.
+func (a *FileAsync) Unregister(w waiter.Waitable) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if a.e.Callback == nil {
+ panic("unregistering unregistered file")
+ }
+
+ w.EventUnregister(&a.e)
+ a.e.Callback = nil
+}
+
+// Owner returns who is currently getting signals. All return values will be
+// nil if no one is set to receive signals.
+func (a *FileAsync) Owner() (*kernel.Task, *kernel.ThreadGroup, *kernel.ProcessGroup) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ return a.recipientT, a.recipientTG, a.recipientPG
+}
+
+// SetOwnerTask sets the owner (who will receive signals) to a specified task.
+// Only this owner will receive signals.
+func (a *FileAsync) SetOwnerTask(requester *kernel.Task, recipient *kernel.Task) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.requester = requester.Credentials()
+ a.recipientT = recipient
+ a.recipientTG = nil
+ a.recipientPG = nil
+}
+
+// SetOwnerThreadGroup sets the owner (who will receive signals) to a specified
+// thread group. Only this owner will receive signals.
+func (a *FileAsync) SetOwnerThreadGroup(requester *kernel.Task, recipient *kernel.ThreadGroup) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.requester = requester.Credentials()
+ a.recipientT = nil
+ a.recipientTG = recipient
+ a.recipientPG = nil
+}
+
+// SetOwnerProcessGroup sets the owner (who will receive signals) to a
+// specified process group. Only this owner will receive signals.
+func (a *FileAsync) SetOwnerProcessGroup(requester *kernel.Task, recipient *kernel.ProcessGroup) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.requester = requester.Credentials()
+ a.recipientT = nil
+ a.recipientTG = nil
+ a.recipientPG = recipient
+}
diff --git a/pkg/sentry/kernel/fasync/fasync_state_autogen.go b/pkg/sentry/kernel/fasync/fasync_state_autogen.go
new file mode 100755
index 000000000..e162e0033
--- /dev/null
+++ b/pkg/sentry/kernel/fasync/fasync_state_autogen.go
@@ -0,0 +1,30 @@
+// automatically generated by stateify.
+
+package fasync
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *FileAsync) beforeSave() {}
+func (x *FileAsync) save(m state.Map) {
+ x.beforeSave()
+ m.Save("e", &x.e)
+ m.Save("requester", &x.requester)
+ m.Save("recipientPG", &x.recipientPG)
+ m.Save("recipientTG", &x.recipientTG)
+ m.Save("recipientT", &x.recipientT)
+}
+
+func (x *FileAsync) afterLoad() {}
+func (x *FileAsync) load(m state.Map) {
+ m.Load("e", &x.e)
+ m.Load("requester", &x.requester)
+ m.Load("recipientPG", &x.recipientPG)
+ m.Load("recipientTG", &x.recipientTG)
+ m.Load("recipientT", &x.recipientT)
+}
+
+func init() {
+ state.Register("fasync.FileAsync", (*FileAsync)(nil), state.Fns{Save: (*FileAsync).save, Load: (*FileAsync).load})
+}
diff --git a/pkg/sentry/kernel/fd_map.go b/pkg/sentry/kernel/fd_map.go
new file mode 100644
index 000000000..c5636d233
--- /dev/null
+++ b/pkg/sentry/kernel/fd_map.go
@@ -0,0 +1,364 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/lock"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+)
+
+// FDs is an ordering of FD's that can be made stable.
+type FDs []kdefs.FD
+
+func (f FDs) Len() int {
+ return len(f)
+}
+
+func (f FDs) Swap(i, j int) {
+ f[i], f[j] = f[j], f[i]
+}
+
+func (f FDs) Less(i, j int) bool {
+ return f[i] < f[j]
+}
+
+// FDFlags define flags for an individual descriptor.
+//
+// +stateify savable
+type FDFlags struct {
+ // CloseOnExec indicates the descriptor should be closed on exec.
+ CloseOnExec bool
+}
+
+// ToLinuxFileFlags converts a kernel.FDFlags object to a Linux file flags
+// representation.
+func (f FDFlags) ToLinuxFileFlags() (mask uint) {
+ if f.CloseOnExec {
+ mask |= linux.O_CLOEXEC
+ }
+ return
+}
+
+// ToLinuxFDFlags converts a kernel.FDFlags object to a Linux descriptor flags
+// representation.
+func (f FDFlags) ToLinuxFDFlags() (mask uint) {
+ if f.CloseOnExec {
+ mask |= linux.FD_CLOEXEC
+ }
+ return
+}
+
+// descriptor holds the details about a file descriptor, namely a pointer the
+// file itself and the descriptor flags.
+//
+// +stateify savable
+type descriptor struct {
+ file *fs.File
+ flags FDFlags
+}
+
+// FDMap is used to manage File references and flags.
+//
+// +stateify savable
+type FDMap struct {
+ refs.AtomicRefCount
+ k *Kernel
+ files map[kdefs.FD]descriptor
+ mu sync.RWMutex `state:"nosave"`
+ uid uint64
+}
+
+// ID returns a unique identifier for this FDMap.
+func (f *FDMap) ID() uint64 {
+ return f.uid
+}
+
+// NewFDMap allocates a new FDMap that may be used by tasks in k.
+func (k *Kernel) NewFDMap() *FDMap {
+ return &FDMap{
+ k: k,
+ files: make(map[kdefs.FD]descriptor),
+ uid: atomic.AddUint64(&k.fdMapUids, 1),
+ }
+}
+
+// destroy removes all of the file descriptors from the map.
+func (f *FDMap) destroy() {
+ f.RemoveIf(func(*fs.File, FDFlags) bool {
+ return true
+ })
+}
+
+// DecRef implements RefCounter.DecRef with destructor f.destroy.
+func (f *FDMap) DecRef() {
+ f.DecRefWithDestructor(f.destroy)
+}
+
+// Size returns the number of file descriptor slots currently allocated.
+func (f *FDMap) Size() int {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ return len(f.files)
+}
+
+// String is a stringer for FDMap.
+func (f *FDMap) String() string {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ var b bytes.Buffer
+ for k, v := range f.files {
+ n, _ := v.file.Dirent.FullName(nil /* root */)
+ b.WriteString(fmt.Sprintf("\tfd:%d => name %s\n", k, n))
+ }
+ return b.String()
+}
+
+// NewFDFrom allocates a new FD guaranteed to be the lowest number available
+// greater than or equal to from. This property is important as Unix programs
+// tend to count on this allocation order.
+func (f *FDMap) NewFDFrom(fd kdefs.FD, file *fs.File, flags FDFlags, limitSet *limits.LimitSet) (kdefs.FD, error) {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return 0, syscall.EINVAL
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // Finds the lowest fd not in the handles map.
+ lim := limitSet.Get(limits.NumberOfFiles)
+ for i := fd; lim.Cur == limits.Infinity || i < kdefs.FD(lim.Cur); i++ {
+ if _, ok := f.files[i]; !ok {
+ file.IncRef()
+ f.files[i] = descriptor{file, flags}
+ return i, nil
+ }
+ }
+
+ return -1, syscall.EMFILE
+}
+
+// NewFDAt sets the file reference for the given FD. If there is an
+// active reference for that FD, the ref count for that existing reference
+// is decremented.
+func (f *FDMap) NewFDAt(fd kdefs.FD, file *fs.File, flags FDFlags, limitSet *limits.LimitSet) error {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return syscall.EBADF
+ }
+
+ // In this one case we do not do a defer of the Unlock. The
+ // reason is that we must have done all the work needed for
+ // discarding any old open file before we return to the
+ // caller. In other words, the DecRef(), below, must have
+ // completed by the time we return to the caller to ensure
+ // side effects are, in fact, effected. A classic example is
+ // dup2(fd1, fd2); if fd2 was already open, it must be closed,
+ // and we don't want to resume the caller until it is; we have
+ // to block on the DecRef(). Hence we can not just do a 'go
+ // oldfile.DecRef()', since there would be no guarantee that
+ // it would be done before we the caller resumed. Since we
+ // must wait for the DecRef() to finish, and that could take
+ // time, it's best to first call f.muUnlock beore so we are
+ // not blocking other uses of this FDMap on the DecRef() call.
+ f.mu.Lock()
+ oldDesc, oldExists := f.files[fd]
+ lim := limitSet.Get(limits.NumberOfFiles).Cur
+ // if we're closing one then the effective limit is one
+ // more than the actual limit.
+ if oldExists && lim != limits.Infinity {
+ lim++
+ }
+ if lim != limits.Infinity && fd >= kdefs.FD(lim) {
+ f.mu.Unlock()
+ return syscall.EMFILE
+ }
+
+ file.IncRef()
+ f.files[fd] = descriptor{file, flags}
+ f.mu.Unlock()
+
+ if oldExists {
+ oldDesc.file.DecRef()
+ }
+ return nil
+}
+
+// SetFlags sets the flags for the given file descriptor, if it is valid.
+func (f *FDMap) SetFlags(fd kdefs.FD, flags FDFlags) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ desc, ok := f.files[fd]
+ if !ok {
+ return
+ }
+
+ f.files[fd] = descriptor{desc.file, flags}
+}
+
+// GetDescriptor returns a reference to the file and the flags for the FD. It
+// bumps its reference count as well. It returns nil if there is no File
+// for the FD, i.e. if the FD is invalid. The caller must use DecRef
+// when they are done.
+func (f *FDMap) GetDescriptor(fd kdefs.FD) (*fs.File, FDFlags) {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ if desc, ok := f.files[fd]; ok {
+ desc.file.IncRef()
+ return desc.file, desc.flags
+ }
+ return nil, FDFlags{}
+}
+
+// GetFile returns a reference to the File for the FD and bumps
+// its reference count as well. It returns nil if there is no File
+// for the FD, i.e. if the FD is invalid. The caller must use DecRef
+// when they are done.
+func (f *FDMap) GetFile(fd kdefs.FD) *fs.File {
+ f.mu.RLock()
+ if desc, ok := f.files[fd]; ok {
+ desc.file.IncRef()
+ f.mu.RUnlock()
+ return desc.file
+ }
+ f.mu.RUnlock()
+ return nil
+}
+
+// fds returns an ordering of FDs.
+func (f *FDMap) fds() FDs {
+ fds := make(FDs, 0, len(f.files))
+ for fd := range f.files {
+ fds = append(fds, fd)
+ }
+ sort.Sort(fds)
+ return fds
+}
+
+// GetFDs returns a list of valid fds.
+func (f *FDMap) GetFDs() FDs {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.fds()
+}
+
+// GetRefs returns a stable slice of references to all files and bumps the
+// reference count on each. The caller must use DecRef on each reference when
+// they're done using the slice.
+func (f *FDMap) GetRefs() []*fs.File {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ fds := f.fds()
+ fs := make([]*fs.File, 0, len(fds))
+ for _, fd := range fds {
+ desc := f.files[fd]
+ desc.file.IncRef()
+ fs = append(fs, desc.file)
+ }
+ return fs
+}
+
+// Fork returns an independent FDMap pointing to the same descriptors.
+func (f *FDMap) Fork() *FDMap {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+
+ clone := f.k.NewFDMap()
+
+ // Grab a extra reference for every file.
+ for fd, desc := range f.files {
+ desc.file.IncRef()
+ clone.files[fd] = desc
+ }
+
+ // That's it!
+ return clone
+}
+
+// unlock releases all file locks held by this FDMap's uid. Must only be
+// called on a non-nil *fs.File.
+func (f *FDMap) unlock(file *fs.File) {
+ id := lock.UniqueID(f.ID())
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(id, lock.LockRange{0, lock.LockEOF})
+}
+
+// inotifyFileClose generates the appropriate inotify events for f being closed.
+func inotifyFileClose(f *fs.File) {
+ var ev uint32
+ d := f.Dirent
+
+ if fs.IsDir(d.Inode.StableAttr) {
+ ev |= linux.IN_ISDIR
+ }
+
+ if f.Flags().Write {
+ ev |= linux.IN_CLOSE_WRITE
+ } else {
+ ev |= linux.IN_CLOSE_NOWRITE
+ }
+
+ d.InotifyEvent(ev, 0)
+}
+
+// Remove removes an FD from the FDMap, and returns (File, true) if a File
+// one was found. Callers are expected to decrement the reference count on
+// the File. Otherwise returns (nil, false).
+func (f *FDMap) Remove(fd kdefs.FD) (*fs.File, bool) {
+ f.mu.Lock()
+ desc := f.files[fd]
+ delete(f.files, fd)
+ f.mu.Unlock()
+ if desc.file != nil {
+ f.unlock(desc.file)
+ inotifyFileClose(desc.file)
+ return desc.file, true
+ }
+ return nil, false
+}
+
+// RemoveIf removes all FDs where cond is true.
+func (f *FDMap) RemoveIf(cond func(*fs.File, FDFlags) bool) {
+ var removed []*fs.File
+ f.mu.Lock()
+ for fd, desc := range f.files {
+ if desc.file != nil && cond(desc.file, desc.flags) {
+ delete(f.files, fd)
+ removed = append(removed, desc.file)
+ }
+ }
+ f.mu.Unlock()
+
+ for _, file := range removed {
+ f.unlock(file)
+ inotifyFileClose(file)
+ file.DecRef()
+ }
+}
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
new file mode 100644
index 000000000..d8115f59a
--- /dev/null
+++ b/pkg/sentry/kernel/fs_context.go
@@ -0,0 +1,187 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// FSContext contains filesystem context.
+//
+// This includes umask and working directory.
+//
+// +stateify savable
+type FSContext struct {
+ refs.AtomicRefCount
+
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // root is the filesystem root. Will be nil iff the FSContext has been
+ // destroyed.
+ root *fs.Dirent
+
+ // cwd is the current working directory. Will be nil iff the FSContext
+ // has been destroyed.
+ cwd *fs.Dirent
+
+ // umask is the current file mode creation mask. When a thread using this
+ // context invokes a syscall that creates a file, bits set in umask are
+ // removed from the permissions that the file is created with.
+ umask uint
+}
+
+// newFSContext returns a new filesystem context.
+func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext {
+ root.IncRef()
+ cwd.IncRef()
+ return &FSContext{
+ root: root,
+ cwd: cwd,
+ umask: umask,
+ }
+}
+
+// destroy is the destructor for an FSContext.
+//
+// This will call DecRef on both root and cwd Dirents. If either call to
+// DecRef returns an error, then it will be propigated. If both calls to
+// DecRef return an error, then the one from root.DecRef will be propigated.
+//
+// Note that there may still be calls to WorkingDirectory() or RootDirectory()
+// (that return nil). This is because valid references may still be held via
+// proc files or other mechanisms.
+func (f *FSContext) destroy() {
+ // Hold f.mu so that we don't race with RootDirectory() and
+ // WorkingDirectory().
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.root.DecRef()
+ f.root = nil
+
+ f.cwd.DecRef()
+ f.cwd = nil
+}
+
+// DecRef implements RefCounter.DecRef with destructor f.destroy.
+func (f *FSContext) DecRef() {
+ f.DecRefWithDestructor(f.destroy)
+}
+
+// Fork forks this FSContext.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) Fork() *FSContext {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.cwd.IncRef()
+ f.root.IncRef()
+ return &FSContext{
+ cwd: f.cwd,
+ root: f.root,
+ umask: f.umask,
+ }
+}
+
+// WorkingDirectory returns the current working directory.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) WorkingDirectory() *fs.Dirent {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ if f.cwd != nil {
+ f.cwd.IncRef()
+ }
+ return f.cwd
+}
+
+// SetWorkingDirectory sets the current working directory.
+// This will take an extra reference on the Dirent.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) {
+ if d == nil {
+ panic("FSContext.SetWorkingDirectory called with nil dirent")
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ if f.cwd == nil {
+ panic(fmt.Sprintf("FSContext.SetWorkingDirectory(%v)) called after destroy", d))
+ }
+
+ old := f.cwd
+ f.cwd = d
+ d.IncRef()
+ old.DecRef()
+}
+
+// RootDirectory returns the current filesystem root.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) RootDirectory() *fs.Dirent {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ if f.root != nil {
+ f.root.IncRef()
+ }
+ return f.root
+}
+
+// SetRootDirectory sets the root directory.
+// This will take an extra reference on the Dirent.
+//
+// This is not a valid call after free.
+func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
+ if d == nil {
+ panic("FSContext.SetRootDirectory called with nil dirent")
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ if f.root == nil {
+ panic(fmt.Sprintf("FSContext.SetRootDirectory(%v)) called after destroy", d))
+ }
+
+ old := f.root
+ f.root = d
+ d.IncRef()
+ old.DecRef()
+}
+
+// Umask returns the current umask.
+func (f *FSContext) Umask() uint {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ return f.umask
+}
+
+// SwapUmask atomically sets the current umask and returns the old umask.
+func (f *FSContext) SwapUmask(mask uint) uint {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ old := f.umask
+ f.umask = mask
+ return old
+}
diff --git a/pkg/sentry/kernel/futex/atomicptr_bucket.go b/pkg/sentry/kernel/futex/atomicptr_bucket.go
new file mode 100755
index 000000000..2251a6e72
--- /dev/null
+++ b/pkg/sentry/kernel/futex/atomicptr_bucket.go
@@ -0,0 +1,27 @@
+package futex
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// An AtomicPtr is a pointer to a value of type Value that can be atomically
+// loaded and stored. The zero value of an AtomicPtr represents nil.
+//
+// Note that copying AtomicPtr by value performs a non-atomic read of the
+// stored pointer, which is unsafe if Store() can be called concurrently; in
+// this case, do `dst.Store(src.Load())` instead.
+type AtomicPtrBucket struct {
+ ptr unsafe.Pointer
+}
+
+// Load returns the value set by the most recent Store. It returns nil if there
+// has been no previous call to Store.
+func (p *AtomicPtrBucket) Load() *bucket {
+ return (*bucket)(atomic.LoadPointer(&p.ptr))
+}
+
+// Store sets the value returned by Load to x.
+func (p *AtomicPtrBucket) Store(x *bucket) {
+ atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x))
+}
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
new file mode 100644
index 000000000..bb38eb81e
--- /dev/null
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -0,0 +1,783 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package futex provides an implementation of the futex interface as found in
+// the Linux kernel. It allows one to easily transform Wait() calls into waits
+// on a channel, which is useful in a Go-based kernel, for example.
+package futex
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// KeyKind indicates the type of a Key.
+type KeyKind int
+
+const (
+ // KindPrivate indicates a private futex (a futex syscall with the
+ // FUTEX_PRIVATE_FLAG set).
+ KindPrivate KeyKind = iota
+
+ // KindSharedPrivate indicates a shared futex on a private memory mapping.
+ // Although KindPrivate and KindSharedPrivate futexes both use memory
+ // addresses to identify futexes, they do not interoperate (in Linux, the
+ // two are distinguished by the FUT_OFF_MMSHARED flag, which is used in key
+ // comparison).
+ KindSharedPrivate
+
+ // KindSharedMappable indicates a shared futex on a memory mapping other
+ // than a private anonymous memory mapping.
+ KindSharedMappable
+)
+
+// Key represents something that a futex waiter may wait on.
+type Key struct {
+ // Kind is the type of the Key.
+ Kind KeyKind
+
+ // Mappable is the memory-mapped object that is represented by the Key.
+ // Mappable is always nil if Kind is not KindSharedMappable, and may be nil
+ // even if it is.
+ Mappable memmap.Mappable
+
+ // MappingIdentity is the MappingIdentity associated with Mappable.
+ // MappingIdentity is always nil is Mappable is nil, and may be nil even if
+ // it isn't.
+ MappingIdentity memmap.MappingIdentity
+
+ // If Kind is KindPrivate or KindSharedPrivate, Offset is the represented
+ // memory address. Otherwise, Offset is the represented offset into
+ // Mappable.
+ Offset uint64
+}
+
+func (k *Key) release() {
+ if k.MappingIdentity != nil {
+ k.MappingIdentity.DecRef()
+ }
+ k.Mappable = nil
+ k.MappingIdentity = nil
+}
+
+func (k *Key) clone() Key {
+ if k.MappingIdentity != nil {
+ k.MappingIdentity.IncRef()
+ }
+ return *k
+}
+
+// Preconditions: k.Kind == KindPrivate or KindSharedPrivate.
+func (k *Key) addr() usermem.Addr {
+ return usermem.Addr(k.Offset)
+}
+
+// matches returns true if a wakeup on k2 should wake a waiter waiting on k.
+func (k *Key) matches(k2 *Key) bool {
+ // k.MappingIdentity is ignored; it's only used for reference counting.
+ return k.Kind == k2.Kind && k.Mappable == k2.Mappable && k.Offset == k2.Offset
+}
+
+// Target abstracts memory accesses and keys.
+type Target interface {
+ // SwapUint32 gives access to usermem.IO.SwapUint32.
+ SwapUint32(addr usermem.Addr, new uint32) (uint32, error)
+
+ // CompareAndSwap gives access to usermem.IO.CompareAndSwapUint32.
+ CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error)
+
+ // LoadUint32 gives access to usermem.IO.LoadUint32.
+ LoadUint32(addr usermem.Addr) (uint32, error)
+
+ // GetSharedKey returns a Key with kind KindSharedPrivate or
+ // KindSharedMappable corresponding to the memory mapped at address addr.
+ //
+ // If GetSharedKey returns a Key with a non-nil MappingIdentity, a
+ // reference is held on the MappingIdentity, which must be dropped by the
+ // caller when the Key is no longer in use.
+ GetSharedKey(addr usermem.Addr) (Key, error)
+}
+
+// check performs a basic equality check on the given address.
+func check(t Target, addr usermem.Addr, val uint32) error {
+ cur, err := t.LoadUint32(addr)
+ if err != nil {
+ return err
+ }
+ if cur != val {
+ return syserror.EAGAIN
+ }
+ return nil
+}
+
+// atomicOp performs a complex operation on the given address.
+func atomicOp(t Target, addr usermem.Addr, opIn uint32) (bool, error) {
+ opType := (opIn >> 28) & 0xf
+ cmp := (opIn >> 24) & 0xf
+ opArg := (opIn >> 12) & 0xfff
+ cmpArg := opIn & 0xfff
+
+ if opType&linux.FUTEX_OP_OPARG_SHIFT != 0 {
+ opArg = 1 << opArg
+ opType &^= linux.FUTEX_OP_OPARG_SHIFT // Clear flag.
+ }
+
+ var (
+ oldVal uint32
+ err error
+ )
+ if opType == linux.FUTEX_OP_SET {
+ oldVal, err = t.SwapUint32(addr, opArg)
+ if err != nil {
+ return false, err
+ }
+ } else {
+ for {
+ oldVal, err = t.LoadUint32(addr)
+ if err != nil {
+ return false, err
+ }
+ var newVal uint32
+ switch opType {
+ case linux.FUTEX_OP_ADD:
+ newVal = oldVal + opArg
+ case linux.FUTEX_OP_OR:
+ newVal = oldVal | opArg
+ case linux.FUTEX_OP_ANDN:
+ newVal = oldVal &^ opArg
+ case linux.FUTEX_OP_XOR:
+ newVal = oldVal ^ opArg
+ default:
+ return false, syserror.ENOSYS
+ }
+ prev, err := t.CompareAndSwapUint32(addr, oldVal, newVal)
+ if err != nil {
+ return false, err
+ }
+ if prev == oldVal {
+ break // Success.
+ }
+ }
+ }
+
+ switch cmp {
+ case linux.FUTEX_OP_CMP_EQ:
+ return oldVal == cmpArg, nil
+ case linux.FUTEX_OP_CMP_NE:
+ return oldVal != cmpArg, nil
+ case linux.FUTEX_OP_CMP_LT:
+ return oldVal < cmpArg, nil
+ case linux.FUTEX_OP_CMP_LE:
+ return oldVal <= cmpArg, nil
+ case linux.FUTEX_OP_CMP_GT:
+ return oldVal > cmpArg, nil
+ case linux.FUTEX_OP_CMP_GE:
+ return oldVal >= cmpArg, nil
+ default:
+ return false, syserror.ENOSYS
+ }
+}
+
+// Waiter is the struct which gets enqueued into buckets for wake up routines
+// and requeue routines to scan and notify. Once a Waiter has been enqueued by
+// WaitPrepare(), callers may listen on C for wake up events.
+type Waiter struct {
+ // Synchronization:
+ //
+ // - A Waiter that is not enqueued in a bucket is exclusively owned (no
+ // synchronization applies).
+ //
+ // - A Waiter is enqueued in a bucket by calling WaitPrepare(). After this,
+ // waiterEntry, bucket, and key are protected by the bucket.mu ("bucket
+ // lock") of the containing bucket, and bitmask is immutable. Note that
+ // since bucket is mutated using atomic memory operations, bucket.Load()
+ // may be called without holding the bucket lock, although it may change
+ // racily. See WaitComplete().
+ //
+ // - A Waiter is only guaranteed to be no longer queued after calling
+ // WaitComplete().
+
+ // waiterEntry links Waiter into bucket.waiters.
+ waiterEntry
+
+ // bucket is the bucket this waiter is queued in. If bucket is nil, the
+ // waiter is not waiting and is not in any bucket.
+ bucket AtomicPtrBucket
+
+ // C is sent to when the Waiter is woken.
+ C chan struct{}
+
+ // key is what this waiter is waiting on.
+ key Key
+
+ // The bitmask we're waiting on.
+ // This is used the case of a FUTEX_WAKE_BITSET.
+ bitmask uint32
+
+ // tid is the thread ID for the waiter in case this is a PI mutex.
+ tid uint32
+}
+
+// NewWaiter returns a new unqueued Waiter.
+func NewWaiter() *Waiter {
+ return &Waiter{
+ C: make(chan struct{}, 1),
+ }
+}
+
+// woken returns true if w has been woken since the last call to WaitPrepare.
+func (w *Waiter) woken() bool {
+ return len(w.C) != 0
+}
+
+// bucket holds a list of waiters for a given address hash.
+//
+// +stateify savable
+type bucket struct {
+ // mu protects waiters and contained Waiter state. See comment in Waiter.
+ mu sync.Mutex `state:"nosave"`
+
+ waiters waiterList `state:"zerovalue"`
+}
+
+// wakeLocked wakes up to n waiters matching the bitmask at the addr for this
+// bucket and returns the number of waiters woken.
+//
+// Preconditions: b.mu must be locked.
+func (b *bucket) wakeLocked(key *Key, bitmask uint32, n int) int {
+ done := 0
+ for w := b.waiters.Front(); done < n && w != nil; {
+ if !w.key.matches(key) || w.bitmask&bitmask == 0 {
+ // Not matching.
+ w = w.Next()
+ continue
+ }
+
+ // Remove from the bucket and wake the waiter.
+ woke := w
+ w = w.Next() // Next iteration.
+ b.wakeWaiterLocked(woke)
+ done++
+ }
+ return done
+}
+
+func (b *bucket) wakeWaiterLocked(w *Waiter) {
+ // Remove from the bucket and wake the waiter.
+ b.waiters.Remove(w)
+ w.C <- struct{}{}
+
+ // NOTE: The above channel write establishes a write barrier according
+ // to the memory model, so nothing may be ordered around it. Since
+ // we've dequeued w and will never touch it again, we can safely
+ // store nil to w.bucket here and allow the WaitComplete() to
+ // short-circuit grabbing the bucket lock. If they somehow miss the
+ // store, we are still holding the lock, so we can know that they won't
+ // dequeue w, assume it's free and have the below operation
+ // afterwards.
+ w.bucket.Store(nil)
+}
+
+// requeueLocked takes n waiters from the bucket and moves them to naddr on the
+// bucket "to".
+//
+// Preconditions: b and to must be locked.
+func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int {
+ done := 0
+ for w := b.waiters.Front(); done < n && w != nil; {
+ if !w.key.matches(key) {
+ // Not matching.
+ w = w.Next()
+ continue
+ }
+
+ requeued := w
+ w = w.Next() // Next iteration.
+ b.waiters.Remove(requeued)
+ requeued.key.release()
+ requeued.key = nkey.clone()
+ to.waiters.PushBack(requeued)
+ requeued.bucket.Store(to)
+ done++
+ }
+ return done
+}
+
+const (
+ // bucketCount is the number of buckets per Manager. By having many of
+ // these we reduce contention when concurrent yet unrelated calls are made.
+ bucketCount = 1 << bucketCountBits
+ bucketCountBits = 10
+)
+
+// getKey returns a Key representing address addr in c.
+func getKey(t Target, addr usermem.Addr, private bool) (Key, error) {
+ // Ensure the address is aligned.
+ // It must be a DWORD boundary.
+ if addr&0x3 != 0 {
+ return Key{}, syserror.EINVAL
+ }
+ if private {
+ return Key{Kind: KindPrivate, Offset: uint64(addr)}, nil
+ }
+ return t.GetSharedKey(addr)
+}
+
+// bucketIndexForAddr returns the index into Manager.buckets for addr.
+func bucketIndexForAddr(addr usermem.Addr) uintptr {
+ // - The bottom 2 bits of addr must be 0, per getKey.
+ //
+ // - On amd64, the top 16 bits of addr (bits 48-63) must be equal to bit 47
+ // for a canonical address, and (on all existing platforms) bit 47 must be
+ // 0 for an application address.
+ //
+ // Thus 19 bits of addr are "useless" for hashing, leaving only 45 "useful"
+ // bits. We choose one of the simplest possible hash functions that at
+ // least uses all 45 useful bits in the output, given that bucketCountBits
+ // == 10. This hash function also has the property that it will usually map
+ // adjacent addresses to adjacent buckets, slightly improving memory
+ // locality when an application synchronization structure uses multiple
+ // nearby futexes.
+ //
+ // Note that despite the large number of arithmetic operations in the
+ // function, many components can be computed in parallel, such that the
+ // critical path is 1 bit shift + 3 additions (2 in h1, then h1 + h2). This
+ // is also why h1 and h2 are grouped separately; for "(addr >> 2) + ... +
+ // (addr >> 42)" without any additional grouping, the compiler puts all 4
+ // additions in the critical path.
+ h1 := uintptr(addr>>2) + uintptr(addr>>12) + uintptr(addr>>22)
+ h2 := uintptr(addr>>32) + uintptr(addr>>42)
+ return (h1 + h2) % bucketCount
+}
+
+// Manager holds futex state for a single virtual address space.
+//
+// +stateify savable
+type Manager struct {
+ // privateBuckets holds buckets for KindPrivate and KindSharedPrivate
+ // futexes.
+ privateBuckets [bucketCount]bucket `state:"zerovalue"`
+
+ // sharedBucket is the bucket for KindSharedMappable futexes. sharedBucket
+ // may be shared by multiple Managers. The sharedBucket pointer is
+ // immutable.
+ sharedBucket *bucket
+}
+
+// NewManager returns an initialized futex manager.
+func NewManager() *Manager {
+ return &Manager{
+ sharedBucket: &bucket{},
+ }
+}
+
+// Fork returns a new Manager. Shared futex clients using the returned Manager
+// may interoperate with those using m.
+func (m *Manager) Fork() *Manager {
+ return &Manager{
+ sharedBucket: m.sharedBucket,
+ }
+}
+
+// lockBucket returns a locked bucket for the given key.
+func (m *Manager) lockBucket(k *Key) *bucket {
+ var b *bucket
+ if k.Kind == KindSharedMappable {
+ b = m.sharedBucket
+ } else {
+ b = &m.privateBuckets[bucketIndexForAddr(k.addr())]
+ }
+ b.mu.Lock()
+ return b
+}
+
+// lockBuckets returns locked buckets for the given keys.
+func (m *Manager) lockBuckets(k1, k2 *Key) (*bucket, *bucket) {
+ // Buckets must be consistently ordered to avoid circular lock
+ // dependencies. We order buckets in m.privateBuckets by index (lowest
+ // index first), and all buckets in m.privateBuckets precede
+ // m.sharedBucket.
+
+ // Handle the common case first:
+ if k1.Kind != KindSharedMappable && k2.Kind != KindSharedMappable {
+ i1 := bucketIndexForAddr(k1.addr())
+ i2 := bucketIndexForAddr(k2.addr())
+ b1 := &m.privateBuckets[i1]
+ b2 := &m.privateBuckets[i2]
+ switch {
+ case i1 < i2:
+ b1.mu.Lock()
+ b2.mu.Lock()
+ case i2 < i1:
+ b2.mu.Lock()
+ b1.mu.Lock()
+ default:
+ b1.mu.Lock()
+ }
+ return b1, b2
+ }
+
+ // At least one of b1 or b2 should be m.sharedBucket.
+ b1 := m.sharedBucket
+ b2 := m.sharedBucket
+ if k1.Kind != KindSharedMappable {
+ b1 = m.lockBucket(k1)
+ } else if k2.Kind != KindSharedMappable {
+ b2 = m.lockBucket(k2)
+ }
+ m.sharedBucket.mu.Lock()
+ return b1, b2
+}
+
+// Wake wakes up to n waiters matching the bitmask on the given addr.
+// The number of waiters woken is returned.
+func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32, n int) (int, error) {
+ // This function is very hot; avoid defer.
+ k, err := getKey(t, addr, private)
+ if err != nil {
+ return 0, err
+ }
+
+ b := m.lockBucket(&k)
+ r := b.wakeLocked(&k, bitmask, n)
+
+ b.mu.Unlock()
+ k.release()
+ return r, nil
+}
+
+func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, checkval bool, val uint32, nwake int, nreq int) (int, error) {
+ k1, err := getKey(t, addr, private)
+ if err != nil {
+ return 0, err
+ }
+ defer k1.release()
+ k2, err := getKey(t, naddr, private)
+ if err != nil {
+ return 0, err
+ }
+ defer k2.release()
+
+ b1, b2 := m.lockBuckets(&k1, &k2)
+ defer b1.mu.Unlock()
+ if b2 != b1 {
+ defer b2.mu.Unlock()
+ }
+
+ if checkval {
+ if err := check(t, addr, val); err != nil {
+ return 0, err
+ }
+ }
+
+ // Wake the number required.
+ done := b1.wakeLocked(&k1, ^uint32(0), nwake)
+
+ // Requeue the number required.
+ b1.requeueLocked(b2, &k1, &k2, nreq)
+
+ return done, nil
+}
+
+// Requeue wakes up to nwake waiters on the given addr, and unconditionally
+// requeues up to nreq waiters on naddr.
+func (m *Manager) Requeue(t Target, addr, naddr usermem.Addr, private bool, nwake int, nreq int) (int, error) {
+ return m.doRequeue(t, addr, naddr, private, false, 0, nwake, nreq)
+}
+
+// RequeueCmp atomically checks that the addr contains val (via the Target),
+// wakes up to nwake waiters on addr and then unconditionally requeues nreq
+// waiters on naddr.
+func (m *Manager) RequeueCmp(t Target, addr, naddr usermem.Addr, private bool, val uint32, nwake int, nreq int) (int, error) {
+ return m.doRequeue(t, addr, naddr, private, true, val, nwake, nreq)
+}
+
+// WakeOp atomically applies op to the memory address addr2, wakes up to nwake1
+// waiters unconditionally from addr1, and, based on the original value at addr2
+// and a comparison encoded in op, wakes up to nwake2 waiters from addr2.
+// It returns the total number of waiters woken.
+func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwake1 int, nwake2 int, op uint32) (int, error) {
+ k1, err := getKey(t, addr1, private)
+ if err != nil {
+ return 0, err
+ }
+ defer k1.release()
+ k2, err := getKey(t, addr2, private)
+ if err != nil {
+ return 0, err
+ }
+ defer k2.release()
+
+ b1, b2 := m.lockBuckets(&k1, &k2)
+ defer b1.mu.Unlock()
+ if b2 != b1 {
+ defer b2.mu.Unlock()
+ }
+
+ done := 0
+ cond, err := atomicOp(t, addr2, op)
+ if err != nil {
+ return 0, err
+ }
+
+ // Wake up up to nwake1 entries from the first bucket.
+ done = b1.wakeLocked(&k1, ^uint32(0), nwake1)
+
+ // Wake up up to nwake2 entries from the second bucket if the
+ // operation yielded true.
+ if cond {
+ done += b2.wakeLocked(&k2, ^uint32(0), nwake2)
+ }
+
+ return done, nil
+}
+
+// WaitPrepare atomically checks that addr contains val (via the Checker), then
+// enqueues w to be woken by a send to w.C. If WaitPrepare returns nil, the
+// Waiter must be subsequently removed by calling WaitComplete, whether or not
+// a wakeup is received on w.C.
+func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bool, val uint32, bitmask uint32) error {
+ k, err := getKey(t, addr, private)
+ if err != nil {
+ return err
+ }
+ // Ownership of k is transferred to w below.
+
+ // Prepare the Waiter before taking the bucket lock.
+ select {
+ case <-w.C:
+ default:
+ }
+ w.key = k
+ w.bitmask = bitmask
+
+ b := m.lockBucket(&k)
+ // This function is very hot; avoid defer.
+
+ // Perform our atomic check.
+ if err := check(t, addr, val); err != nil {
+ b.mu.Unlock()
+ w.key.release()
+ return err
+ }
+
+ // Add the waiter to the bucket.
+ b.waiters.PushBack(w)
+ w.bucket.Store(b)
+
+ b.mu.Unlock()
+ return nil
+}
+
+// WaitComplete must be called when a Waiter previously added by WaitPrepare is
+// no longer eligible to be woken.
+func (m *Manager) WaitComplete(w *Waiter) {
+ // Remove w from the bucket it's in.
+ for {
+ b := w.bucket.Load()
+
+ // If b is nil, the waiter isn't in any bucket anymore. This can't be
+ // racy because the waiter can't be concurrently re-queued in another
+ // bucket.
+ if b == nil {
+ break
+ }
+
+ // Take the bucket lock. Note that without holding the bucket lock, the
+ // waiter is not guaranteed to stay in that bucket, so after we take
+ // the bucket lock, we must ensure that the bucket hasn't changed: if
+ // it happens to have changed, we release the old bucket lock and try
+ // again with the new bucket; if it hasn't changed, we know it won't
+ // change now because we hold the lock.
+ b.mu.Lock()
+ if b != w.bucket.Load() {
+ b.mu.Unlock()
+ continue
+ }
+
+ // Remove waiter from bucket.
+ b.waiters.Remove(w)
+ w.bucket.Store(nil)
+ b.mu.Unlock()
+ break
+ }
+
+ // Release references held by the waiter.
+ w.key.release()
+}
+
+// LockPI attempts to lock the futex following the Priority-inheritance futex
+// rules. The lock is acquired only when 'addr' points to 0. The TID of the
+// calling task is set to 'addr' to indicate the futex is owned. It returns true
+// if the futex was successfully acquired.
+//
+// FUTEX_OWNER_DIED is only set by the Linux when robust lists are in use (see
+// exit_robust_list()). Given we don't support robust lists, although handled
+// below, it's never set.
+func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, private, try bool) (bool, error) {
+ k, err := getKey(t, addr, private)
+ if err != nil {
+ return false, err
+ }
+ // Ownership of k is transferred to w below.
+
+ // Prepare the Waiter before taking the bucket lock.
+ select {
+ case <-w.C:
+ default:
+ }
+ w.key = k
+ w.tid = tid
+
+ b := m.lockBucket(&k)
+ // Hot function: avoid defers.
+
+ success, err := m.lockPILocked(w, t, addr, tid, b, try)
+ if err != nil {
+ w.key.release()
+ b.mu.Unlock()
+ return false, err
+ }
+ if success || try {
+ // Release waiter if it's not going to be a wait.
+ w.key.release()
+ }
+ b.mu.Unlock()
+ return success, nil
+}
+
+func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint32, b *bucket, try bool) (bool, error) {
+ for {
+ cur, err := t.LoadUint32(addr)
+ if err != nil {
+ return false, err
+ }
+ if (cur & linux.FUTEX_TID_MASK) == tid {
+ return false, syserror.EDEADLK
+ }
+
+ if (cur & linux.FUTEX_TID_MASK) == 0 {
+ // No owner and no waiters, try to acquire the futex.
+
+ // Set TID and preserve owner died status.
+ val := tid
+ val |= cur & linux.FUTEX_OWNER_DIED
+ prev, err := t.CompareAndSwapUint32(addr, cur, val)
+ if err != nil {
+ return false, err
+ }
+ if prev != cur {
+ // CAS failed, retry...
+ // Linux reacquires the bucket lock on retries, which will re-lookup the
+ // mapping at the futex address. However, retrying while holding the
+ // lock is more efficient and reduces the chance of another conflict.
+ continue
+ }
+ // Futex acquired.
+ return true, nil
+ }
+
+ // Futex is already owned, prepare to wait.
+
+ if try {
+ // Caller doesn't want to wait.
+ return false, nil
+ }
+
+ // Set waiters bit if not set yet.
+ if cur&linux.FUTEX_WAITERS == 0 {
+ prev, err := t.CompareAndSwapUint32(addr, cur, cur|linux.FUTEX_WAITERS)
+ if err != nil {
+ return false, err
+ }
+ if prev != cur {
+ // CAS failed, retry...
+ continue
+ }
+ }
+
+ // Add the waiter to the bucket.
+ b.waiters.PushBack(w)
+ w.bucket.Store(b)
+ return false, nil
+ }
+}
+
+// UnlockPI unlock the futex following the Priority-inheritance futex
+// rules. The address provided must contain the caller's TID. If there are
+// waiters, TID of the next waiter (FIFO) is set to the given address, and the
+// waiter woken up. If there are no waiters, 0 is set to the address.
+func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error {
+ k, err := getKey(t, addr, private)
+ if err != nil {
+ return err
+ }
+ b := m.lockBucket(&k)
+
+ err = m.unlockPILocked(t, addr, tid, b)
+
+ k.release()
+ b.mu.Unlock()
+ return err
+}
+
+func (m *Manager) unlockPILocked(t Target, addr usermem.Addr, tid uint32, b *bucket) error {
+ cur, err := t.LoadUint32(addr)
+ if err != nil {
+ return err
+ }
+
+ if (cur & linux.FUTEX_TID_MASK) != tid {
+ return syserror.EPERM
+ }
+
+ if b.waiters.Empty() {
+ // It's safe to set 0 because there are no waiters, no new owner, and the
+ // executing task is the current owner (no owner died bit).
+ prev, err := t.CompareAndSwapUint32(addr, cur, 0)
+ if err != nil {
+ return err
+ }
+ if prev != cur {
+ // Let user mode handle CAS races. This is different than lock, which
+ // retries when CAS fails.
+ return syserror.EAGAIN
+ }
+ return nil
+ }
+
+ next := b.waiters.Front()
+
+ // Set next owner's TID, waiters if there are any. Resets owner died bit, if
+ // set, because the executing task takes over as the owner.
+ val := next.tid
+ if next.Next() != nil {
+ val |= linux.FUTEX_WAITERS
+ }
+
+ prev, err := t.CompareAndSwapUint32(addr, cur, val)
+ if err != nil {
+ return err
+ }
+ if prev != cur {
+ return syserror.EINVAL
+ }
+
+ b.wakeWaiterLocked(next)
+ return nil
+}
diff --git a/pkg/sentry/kernel/futex/futex_state_autogen.go b/pkg/sentry/kernel/futex/futex_state_autogen.go
new file mode 100755
index 000000000..b58e22b78
--- /dev/null
+++ b/pkg/sentry/kernel/futex/futex_state_autogen.go
@@ -0,0 +1,62 @@
+// automatically generated by stateify.
+
+package futex
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *bucket) beforeSave() {}
+func (x *bucket) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.waiters) { m.Failf("waiters is %v, expected zero", x.waiters) }
+}
+
+func (x *bucket) afterLoad() {}
+func (x *bucket) load(m state.Map) {
+}
+
+func (x *Manager) beforeSave() {}
+func (x *Manager) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.privateBuckets) { m.Failf("privateBuckets is %v, expected zero", x.privateBuckets) }
+ m.Save("sharedBucket", &x.sharedBucket)
+}
+
+func (x *Manager) afterLoad() {}
+func (x *Manager) load(m state.Map) {
+ m.Load("sharedBucket", &x.sharedBucket)
+}
+
+func (x *waiterList) beforeSave() {}
+func (x *waiterList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *waiterList) afterLoad() {}
+func (x *waiterList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *waiterEntry) beforeSave() {}
+func (x *waiterEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *waiterEntry) afterLoad() {}
+func (x *waiterEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("futex.bucket", (*bucket)(nil), state.Fns{Save: (*bucket).save, Load: (*bucket).load})
+ state.Register("futex.Manager", (*Manager)(nil), state.Fns{Save: (*Manager).save, Load: (*Manager).load})
+ state.Register("futex.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load})
+ state.Register("futex.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load})
+}
diff --git a/pkg/sentry/kernel/futex/waiter_list.go b/pkg/sentry/kernel/futex/waiter_list.go
new file mode 100755
index 000000000..cca5c4721
--- /dev/null
+++ b/pkg/sentry/kernel/futex/waiter_list.go
@@ -0,0 +1,173 @@
+package futex
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type waiterElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (waiterElementMapper) linkerFor(elem *Waiter) *Waiter { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type waiterList struct {
+ head *Waiter
+ tail *Waiter
+}
+
+// Reset resets list l to the empty state.
+func (l *waiterList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *waiterList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *waiterList) Front() *Waiter {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *waiterList) Back() *Waiter {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *waiterList) PushFront(e *Waiter) {
+ waiterElementMapper{}.linkerFor(e).SetNext(l.head)
+ waiterElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ waiterElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *waiterList) PushBack(e *Waiter) {
+ waiterElementMapper{}.linkerFor(e).SetNext(nil)
+ waiterElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ waiterElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *waiterList) PushBackList(m *waiterList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *waiterList) InsertAfter(b, e *Waiter) {
+ a := waiterElementMapper{}.linkerFor(b).Next()
+ waiterElementMapper{}.linkerFor(e).SetNext(a)
+ waiterElementMapper{}.linkerFor(e).SetPrev(b)
+ waiterElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ waiterElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *waiterList) InsertBefore(a, e *Waiter) {
+ b := waiterElementMapper{}.linkerFor(a).Prev()
+ waiterElementMapper{}.linkerFor(e).SetNext(a)
+ waiterElementMapper{}.linkerFor(e).SetPrev(b)
+ waiterElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ waiterElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *waiterList) Remove(e *Waiter) {
+ prev := waiterElementMapper{}.linkerFor(e).Prev()
+ next := waiterElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ waiterElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ waiterElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type waiterEntry struct {
+ next *Waiter
+ prev *Waiter
+}
+
+// Next returns the entry that follows e in the list.
+func (e *waiterEntry) Next() *Waiter {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *waiterEntry) Prev() *Waiter {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *waiterEntry) SetNext(elem *Waiter) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *waiterEntry) SetPrev(elem *Waiter) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go
new file mode 100644
index 000000000..ebe12812c
--- /dev/null
+++ b/pkg/sentry/kernel/ipc_namespace.go
@@ -0,0 +1,58 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/semaphore"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/shm"
+)
+
+// IPCNamespace represents an IPC namespace.
+//
+// +stateify savable
+type IPCNamespace struct {
+ // User namespace which owns this IPC namespace. Immutable.
+ userNS *auth.UserNamespace
+
+ semaphores *semaphore.Registry
+ shms *shm.Registry
+}
+
+// NewIPCNamespace creates a new IPC namespace.
+func NewIPCNamespace(userNS *auth.UserNamespace) *IPCNamespace {
+ return &IPCNamespace{
+ userNS: userNS,
+ semaphores: semaphore.NewRegistry(userNS),
+ shms: shm.NewRegistry(userNS),
+ }
+}
+
+// SemaphoreRegistry returns the semanphore set registry for this namespace.
+func (i *IPCNamespace) SemaphoreRegistry() *semaphore.Registry {
+ return i.semaphores
+}
+
+// ShmRegistry returns the shm segment registry for this namespace.
+func (i *IPCNamespace) ShmRegistry() *shm.Registry {
+ return i.shms
+}
+
+// IPCNamespace returns the task's IPC namespace.
+func (t *Task) IPCNamespace() *IPCNamespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.ipcns
+}
diff --git a/pkg/sentry/kernel/kdefs/kdefs.go b/pkg/sentry/kernel/kdefs/kdefs.go
new file mode 100644
index 000000000..304da2032
--- /dev/null
+++ b/pkg/sentry/kernel/kdefs/kdefs.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package kdefs defines common kernel definitions.
+//
+package kdefs
+
+// FD is a File Descriptor.
+type FD int32
diff --git a/pkg/sentry/kernel/kdefs/kdefs_state_autogen.go b/pkg/sentry/kernel/kdefs/kdefs_state_autogen.go
new file mode 100755
index 000000000..cef77125b
--- /dev/null
+++ b/pkg/sentry/kernel/kdefs/kdefs_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package kdefs
+
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
new file mode 100644
index 000000000..85d73ace2
--- /dev/null
+++ b/pkg/sentry/kernel/kernel.go
@@ -0,0 +1,1241 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package kernel provides an emulation of the Linux kernel.
+//
+// See README.md for a detailed overview.
+//
+// Lock order (outermost locks must be taken first):
+//
+// Kernel.extMu
+// ThreadGroup.timerMu
+// ktime.Timer.mu (for kernelCPUClockTicker and IntervalTimer)
+// TaskSet.mu
+// SignalHandlers.mu
+// Task.mu
+//
+// Locking SignalHandlers.mu in multiple SignalHandlers requires locking
+// TaskSet.mu exclusively first. Locking Task.mu in multiple Tasks at the same
+// time requires locking all of their signal mutexes first.
+package kernel
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "path/filepath"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/eventchannel"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/timerfd"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/hostcpu"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/epoll"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/sched"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/loader"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink/port"
+ sentrytime "gvisor.googlesource.com/gvisor/pkg/sentry/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
+ uspb "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid"
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// Kernel represents an emulated Linux kernel. It must be initialized by calling
+// Init() or LoadFrom().
+//
+// +stateify savable
+type Kernel struct {
+ // extMu serializes external changes to the Kernel with calls to
+ // Kernel.SaveTo. (Kernel.SaveTo requires that the state of the Kernel
+ // remains frozen for the duration of the call; it requires that the Kernel
+ // is paused as a precondition, which ensures that none of the tasks
+ // running within the Kernel can affect its state, but extMu is required to
+ // ensure that concurrent users of the Kernel *outside* the Kernel's
+ // control cannot affect its state by calling e.g.
+ // Kernel.SendExternalSignal.)
+ extMu sync.Mutex `state:"nosave"`
+
+ // started is true if Start has been called. Unless otherwise specified,
+ // all Kernel fields become immutable once started becomes true.
+ started bool `state:"nosave"`
+
+ // All of the following fields are immutable unless otherwise specified.
+
+ // Platform is the platform that is used to execute tasks in the created
+ // Kernel. See comment on pgalloc.MemoryFileProvider for why Platform is
+ // embedded anonymously (the same issue applies).
+ platform.Platform `state:"nosave"`
+
+ // mf provides application memory.
+ mf *pgalloc.MemoryFile `state:"nosave"`
+
+ // See InitKernelArgs for the meaning of these fields.
+ featureSet *cpuid.FeatureSet
+ timekeeper *Timekeeper
+ tasks *TaskSet
+ rootUserNamespace *auth.UserNamespace
+ networkStack inet.Stack `state:"nosave"`
+ applicationCores uint
+ useHostCores bool
+ extraAuxv []arch.AuxEntry
+ vdso *loader.VDSO
+ rootUTSNamespace *UTSNamespace
+ rootIPCNamespace *IPCNamespace
+ rootAbstractSocketNamespace *AbstractSocketNamespace
+
+ // mounts holds the state of the virtual filesystem. mounts is initially
+ // nil, and must be set by calling Kernel.SetRootMountNamespace before
+ // Kernel.CreateProcess can succeed.
+ mounts *fs.MountNamespace
+
+ // futexes is the "root" futex.Manager, from which all others are forked.
+ // This is necessary to ensure that shared futexes are coherent across all
+ // tasks, including those created by CreateProcess.
+ futexes *futex.Manager
+
+ // globalInit is the thread group whose leader has ID 1 in the root PID
+ // namespace. globalInit is stored separately so that it is accessible even
+ // after all tasks in the thread group have exited, such that ID 1 is no
+ // longer mapped.
+ //
+ // globalInit is mutable until it is assigned by the first successful call
+ // to CreateProcess, and is protected by extMu.
+ globalInit *ThreadGroup
+
+ // realtimeClock is a ktime.Clock based on timekeeper's Realtime.
+ realtimeClock *timekeeperClock
+
+ // monotonicClock is a ktime.Clock based on timekeeper's Monotonic.
+ monotonicClock *timekeeperClock
+
+ // syslog is the kernel log.
+ syslog syslog
+
+ // cpuClock is incremented every linux.ClockTick. cpuClock is used to
+ // measure task CPU usage, since sampling monotonicClock twice on every
+ // syscall turns out to be unreasonably expensive. This is similar to how
+ // Linux does task CPU accounting on x86 (CONFIG_IRQ_TIME_ACCOUNTING),
+ // although Linux also uses scheduler timing information to improve
+ // resolution (kernel/sched/cputime.c:cputime_adjust()), which we can't do
+ // since "preeemptive" scheduling is managed by the Go runtime, which
+ // doesn't provide this information.
+ //
+ // cpuClock is mutable, and is accessed using atomic memory operations.
+ cpuClock uint64
+
+ // cpuClockTicker increments cpuClock.
+ cpuClockTicker *ktime.Timer `state:"nosave"`
+
+ // fdMapUids is an ever-increasing counter for generating FDMap uids.
+ //
+ // fdMapUids is mutable, and is accessed using atomic memory operations.
+ fdMapUids uint64
+
+ // uniqueID is used to generate unique identifiers.
+ //
+ // uniqueID is mutable, and is accessed using atomic memory operations.
+ uniqueID uint64
+
+ // nextInotifyCookie is a monotonically increasing counter used for
+ // generating unique inotify event cookies.
+ //
+ // nextInotifyCookie is mutable, and is accessed using atomic memory
+ // operations.
+ nextInotifyCookie uint32
+
+ // netlinkPorts manages allocation of netlink socket port IDs.
+ netlinkPorts *port.Manager
+
+ // saveErr is the error causing the sandbox to exit during save, if
+ // any. It is protected by extMu.
+ saveErr error `state:"nosave"`
+
+ // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
+ danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
+
+ // socketTable is used to track all sockets on the system. Protected by
+ // extMu.
+ socketTable map[int]map[*refs.WeakRef]struct{}
+
+ // deviceRegistry is used to save/restore device.SimpleDevices.
+ deviceRegistry struct{} `state:".(*device.Registry)"`
+
+ // DirentCacheLimiter controls the number of total dirent entries can be in
+ // caches. Not all caches use it, only the caches that use host resources use
+ // the limiter. It may be nil if disabled.
+ DirentCacheLimiter *fs.DirentCacheLimiter
+}
+
+// InitKernelArgs holds arguments to Init.
+type InitKernelArgs struct {
+ // FeatureSet is the emulated CPU feature set.
+ FeatureSet *cpuid.FeatureSet
+
+ // Timekeeper manages time for all tasks in the system.
+ Timekeeper *Timekeeper
+
+ // RootUserNamespace is the root user namespace.
+ RootUserNamespace *auth.UserNamespace
+
+ // NetworkStack is the TCP/IP network stack. NetworkStack may be nil.
+ NetworkStack inet.Stack
+
+ // ApplicationCores is the number of logical CPUs visible to sandboxed
+ // applications. The set of logical CPU IDs is [0, ApplicationCores); thus
+ // ApplicationCores is analogous to Linux's nr_cpu_ids, the index of the
+ // most significant bit in cpu_possible_mask + 1.
+ ApplicationCores uint
+
+ // If UseHostCores is true, Task.CPU() returns the task goroutine's CPU
+ // instead of a virtualized CPU number, and Task.CopyToCPUMask() is a
+ // no-op. If ApplicationCores is less than hostcpu.MaxPossibleCPU(), it
+ // will be overridden.
+ UseHostCores bool
+
+ // ExtraAuxv contains additional auxiliary vector entries that are added to
+ // each process by the ELF loader.
+ ExtraAuxv []arch.AuxEntry
+
+ // Vdso holds the VDSO and its parameter page.
+ Vdso *loader.VDSO
+
+ // RootUTSNamespace is the root UTS namespace.
+ RootUTSNamespace *UTSNamespace
+
+ // RootIPCNamespace is the root IPC namespace.
+ RootIPCNamespace *IPCNamespace
+
+ // RootAbstractSocketNamespace is the root Abstract Socket namespace.
+ RootAbstractSocketNamespace *AbstractSocketNamespace
+}
+
+// Init initialize the Kernel with no tasks.
+//
+// Callers must manually set Kernel.Platform and call Kernel.SetMemoryFile
+// before calling Init.
+func (k *Kernel) Init(args InitKernelArgs) error {
+ if args.FeatureSet == nil {
+ return fmt.Errorf("FeatureSet is nil")
+ }
+ if args.Timekeeper == nil {
+ return fmt.Errorf("Timekeeper is nil")
+ }
+ if args.RootUserNamespace == nil {
+ return fmt.Errorf("RootUserNamespace is nil")
+ }
+ if args.ApplicationCores == 0 {
+ return fmt.Errorf("ApplicationCores is 0")
+ }
+
+ k.featureSet = args.FeatureSet
+ k.timekeeper = args.Timekeeper
+ k.tasks = newTaskSet()
+ k.rootUserNamespace = args.RootUserNamespace
+ k.rootUTSNamespace = args.RootUTSNamespace
+ k.rootIPCNamespace = args.RootIPCNamespace
+ k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace
+ k.networkStack = args.NetworkStack
+ k.applicationCores = args.ApplicationCores
+ if args.UseHostCores {
+ k.useHostCores = true
+ maxCPU, err := hostcpu.MaxPossibleCPU()
+ if err != nil {
+ return fmt.Errorf("Failed to get maximum CPU number: %v", err)
+ }
+ minAppCores := uint(maxCPU) + 1
+ if k.applicationCores < minAppCores {
+ log.Infof("UseHostCores enabled: increasing ApplicationCores from %d to %d", k.applicationCores, minAppCores)
+ k.applicationCores = minAppCores
+ }
+ }
+ k.extraAuxv = args.ExtraAuxv
+ k.vdso = args.Vdso
+ k.realtimeClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Realtime}
+ k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
+ k.futexes = futex.NewManager()
+ k.netlinkPorts = port.New()
+ k.socketTable = make(map[int]map[*refs.WeakRef]struct{})
+
+ return nil
+}
+
+// SaveTo saves the state of k to w.
+//
+// Preconditions: The kernel must be paused throughout the call to SaveTo.
+func (k *Kernel) SaveTo(w io.Writer) error {
+ saveStart := time.Now()
+ ctx := k.SupervisorContext()
+
+ // Do not allow other Kernel methods to affect it while it's being saved.
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+
+ // Stop time.
+ k.pauseTimeLocked()
+ defer k.resumeTimeLocked()
+
+ // Evict all evictable MemoryFile allocations.
+ k.mf.StartEvictions()
+ k.mf.WaitForEvictions()
+
+ // Flush write operations on open files so data reaches backing storage.
+ // This must come after MemoryFile eviction since eviction may cause file
+ // writes.
+ if err := k.tasks.flushWritesToFiles(ctx); err != nil {
+ return err
+ }
+
+ // Remove all epoll waiter objects from underlying wait queues.
+ // NOTE: for programs to resume execution in future snapshot scenarios,
+ // we will need to re-establish these waiter objects after saving.
+ k.tasks.unregisterEpollWaiters()
+
+ // Clear the dirent cache before saving because Dirents must be Loaded in a
+ // particular order (parents before children), and Loading dirents from a cache
+ // breaks that order.
+ if err := k.flushMountSourceRefs(); err != nil {
+ return err
+ }
+
+ // Ensure that all pending asynchronous work is complete:
+ // - inode and mount release
+ // - asynchronuous IO
+ fs.AsyncBarrier()
+
+ // Once all fs work has completed (flushed references have all been released),
+ // reset mount mappings. This allows individual mounts to save how inodes map
+ // to filesystem resources. Without this, fs.Inodes cannot be restored.
+ fs.SaveInodeMappings()
+
+ // Discard unsavable mappings, such as those for host file descriptors.
+ // This must be done after waiting for "asynchronous fs work", which
+ // includes async I/O that may touch application memory.
+ if err := k.invalidateUnsavableMappings(ctx); err != nil {
+ return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ }
+
+ // Save the CPUID FeatureSet before the rest of the kernel so we can
+ // verify its compatibility on restore before attempting to restore the
+ // entire kernel, which may fail on an incompatible machine.
+ //
+ // N.B. This will also be saved along with the full kernel save below.
+ cpuidStart := time.Now()
+ if err := state.Save(w, k.FeatureSet(), nil); err != nil {
+ return err
+ }
+ log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
+
+ // Save the kernel state.
+ kernelStart := time.Now()
+ var stats state.Stats
+ if err := state.Save(w, k, &stats); err != nil {
+ return err
+ }
+ log.Infof("Kernel save stats: %s", &stats)
+ log.Infof("Kernel save took [%s].", time.Since(kernelStart))
+
+ // Save the memory file's state.
+ memoryStart := time.Now()
+ if err := k.mf.SaveTo(w); err != nil {
+ return err
+ }
+ log.Infof("Memory save took [%s].", time.Since(memoryStart))
+
+ log.Infof("Overall save took [%s].", time.Since(saveStart))
+
+ return nil
+}
+
+// flushMountSourceRefs flushes the MountSources for all mounted filesystems
+// and open FDs.
+func (k *Kernel) flushMountSourceRefs() error {
+ // Flush all mount sources for currently mounted filesystems.
+ k.mounts.FlushMountSourceRefs()
+
+ // There may be some open FDs whose filesystems have been unmounted. We
+ // must flush those as well.
+ return k.tasks.forEachFDPaused(func(desc descriptor) error {
+ desc.file.Dirent.Inode.MountSource.FlushDirentRefs()
+ return nil
+ })
+}
+
+// forEachFDPaused applies the given function to each open file descriptor in each
+// task.
+//
+// Precondition: Must be called with the kernel paused.
+func (ts *TaskSet) forEachFDPaused(f func(descriptor) error) error {
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+ for t := range ts.Root.tids {
+ // We can skip locking Task.mu here since the kernel is paused.
+ if t.fds == nil {
+ continue
+ }
+ for _, desc := range t.fds.files {
+ if err := f(desc); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
+ return ts.forEachFDPaused(func(desc descriptor) error {
+ if flags := desc.file.Flags(); !flags.Write {
+ return nil
+ }
+ if sattr := desc.file.Dirent.Inode.StableAttr; !fs.IsFile(sattr) && !fs.IsDir(sattr) {
+ return nil
+ }
+ // Here we need all metadata synced.
+ syncErr := desc.file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncAll)
+ if err := fs.SaveFileFsyncError(syncErr); err != nil {
+ name, _ := desc.file.Dirent.FullName(nil /* root */)
+ // Wrap this error in ErrSaveRejection
+ // so that it will trigger a save
+ // error, rather than a panic. This
+ // also allows us to distinguish Fsync
+ // errors from state file errors in
+ // state.Save.
+ return fs.ErrSaveRejection{
+ Err: fmt.Errorf("%q was not sufficiently synced: %v", name, err),
+ }
+ }
+ return nil
+ })
+}
+
+// Preconditions: The kernel must be paused.
+func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
+ invalidated := make(map[*mm.MemoryManager]struct{})
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+ for t := range k.tasks.Root.tids {
+ // We can skip locking Task.mu here since the kernel is paused.
+ if mm := t.tc.MemoryManager; mm != nil {
+ if _, ok := invalidated[mm]; !ok {
+ if err := mm.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ invalidated[mm] = struct{}{}
+ }
+ }
+ // I really wish we just had a sync.Map of all MMs...
+ if r, ok := t.runState.(*runSyscallAfterExecStop); ok {
+ if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func (ts *TaskSet) unregisterEpollWaiters() {
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+ for t := range ts.Root.tids {
+ // We can skip locking Task.mu here since the kernel is paused.
+ if fdmap := t.fds; fdmap != nil {
+ for _, desc := range fdmap.files {
+ if desc.file != nil {
+ if e, ok := desc.file.FileOperations.(*epoll.EventPoll); ok {
+ e.UnregisterEpollWaiters()
+ }
+ }
+ }
+ }
+ }
+}
+
+// LoadFrom returns a new Kernel loaded from args.
+func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error {
+ loadStart := time.Now()
+
+ k.networkStack = net
+
+ initAppCores := k.applicationCores
+
+ // Load the pre-saved CPUID FeatureSet.
+ //
+ // N.B. This was also saved along with the full kernel below, so we
+ // don't need to explicitly install it in the Kernel.
+ cpuidStart := time.Now()
+ var features cpuid.FeatureSet
+ if err := state.Load(r, &features, nil); err != nil {
+ return err
+ }
+ log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
+
+ // Verify that the FeatureSet is usable on this host. We do this before
+ // Kernel load so that the explicit CPUID mismatch error has priority
+ // over floating point state restore errors that may occur on load on
+ // an incompatible machine.
+ if err := features.CheckHostCompatible(); err != nil {
+ return err
+ }
+
+ // Load the kernel state.
+ kernelStart := time.Now()
+ var stats state.Stats
+ if err := state.Load(r, k, &stats); err != nil {
+ return err
+ }
+ log.Infof("Kernel load stats: %s", &stats)
+ log.Infof("Kernel load took [%s].", time.Since(kernelStart))
+
+ // Load the memory file's state.
+ memoryStart := time.Now()
+ if err := k.mf.LoadFrom(r); err != nil {
+ return err
+ }
+ log.Infof("Memory load took [%s].", time.Since(memoryStart))
+
+ // Ensure that all pending asynchronous work is complete:
+ // - namedpipe opening
+ // - inode file opening
+ if err := fs.AsyncErrorBarrier(); err != nil {
+ return err
+ }
+
+ tcpip.AsyncLoading.Wait()
+
+ log.Infof("Overall load took [%s]", time.Since(loadStart))
+
+ // Applications may size per-cpu structures based on k.applicationCores, so
+ // it can't change across save/restore. When we are virtualizing CPU
+ // numbers, this isn't a problem. However, when we are exposing host CPU
+ // assignments, we can't tolerate an increase in the number of host CPUs,
+ // which could result in getcpu(2) returning CPUs that applications expect
+ // not to exist.
+ if k.useHostCores && initAppCores > k.applicationCores {
+ return fmt.Errorf("UseHostCores enabled: can't increase ApplicationCores from %d to %d after restore", k.applicationCores, initAppCores)
+ }
+
+ return nil
+}
+
+// Destroy releases resources owned by k.
+//
+// Preconditions: There must be no task goroutines running in k.
+func (k *Kernel) Destroy() {
+ if k.mounts != nil {
+ k.mounts.DecRef()
+ k.mounts = nil
+ }
+}
+
+// UniqueID returns a unique identifier.
+func (k *Kernel) UniqueID() uint64 {
+ id := atomic.AddUint64(&k.uniqueID, 1)
+ if id == 0 {
+ panic("unique identifier generator wrapped around")
+ }
+ return id
+}
+
+// CreateProcessArgs holds arguments to kernel.CreateProcess.
+type CreateProcessArgs struct {
+ // Filename is the filename to load.
+ //
+ // If this is provided as "", then the file will be guessed via Argv[0].
+ Filename string
+
+ // Argvv is a list of arguments.
+ Argv []string
+
+ // Envv is a list of environment variables.
+ Envv []string
+
+ // WorkingDirectory is the initial working directory.
+ //
+ // This defaults to the root if empty.
+ WorkingDirectory string
+
+ // Credentials is the initial credentials.
+ Credentials *auth.Credentials
+
+ // FDMap is the initial set of file descriptors. If CreateProcess succeeds,
+ // it takes a reference on FDMap.
+ FDMap *FDMap
+
+ // Umask is the initial umask.
+ Umask uint
+
+ // Limits is the initial resource limits.
+ Limits *limits.LimitSet
+
+ // MaxSymlinkTraversals is the maximum number of symlinks to follow
+ // during resolution.
+ MaxSymlinkTraversals uint
+
+ // UTSNamespace is the initial UTS namespace.
+ UTSNamespace *UTSNamespace
+
+ // IPCNamespace is the initial IPC namespace.
+ IPCNamespace *IPCNamespace
+
+ // AbstractSocketNamespace is the initial Abstract Socket namespace.
+ AbstractSocketNamespace *AbstractSocketNamespace
+
+ // Root optionally contains the dirent that serves as the root for the
+ // process. If nil, the mount namespace's root is used as the process'
+ // root.
+ //
+ // Anyone setting Root must donate a reference (i.e. increment it) to
+ // keep it alive until it is decremented by CreateProcess.
+ Root *fs.Dirent
+
+ // ContainerID is the container that the process belongs to.
+ ContainerID string
+}
+
+// NewContext returns a context.Context that represents the task that will be
+// created by args.NewContext(k).
+func (args *CreateProcessArgs) NewContext(k *Kernel) *createProcessContext {
+ return &createProcessContext{
+ Logger: log.Log(),
+ k: k,
+ args: args,
+ }
+}
+
+// createProcessContext is a context.Context that represents the context
+// associated with a task that is being created.
+type createProcessContext struct {
+ context.NoopSleeper
+ log.Logger
+ k *Kernel
+ args *CreateProcessArgs
+}
+
+// Value implements context.Context.Value.
+func (ctx *createProcessContext) Value(key interface{}) interface{} {
+ switch key {
+ case CtxKernel:
+ return ctx.k
+ case CtxPIDNamespace:
+ // "The new task ... is in the root PID namespace." -
+ // Kernel.CreateProcess
+ return ctx.k.tasks.Root
+ case CtxUTSNamespace:
+ return ctx.args.UTSNamespace
+ case CtxIPCNamespace:
+ return ctx.args.IPCNamespace
+ case auth.CtxCredentials:
+ return ctx.args.Credentials
+ case fs.CtxRoot:
+ if ctx.args.Root != nil {
+ // Take a refernce on the root dirent that will be
+ // given to the caller.
+ ctx.args.Root.IncRef()
+ return ctx.args.Root
+ }
+ if ctx.k.mounts != nil {
+ // MountNamespace.Root() will take a reference on the
+ // root dirent for us.
+ return ctx.k.mounts.Root()
+ }
+ return nil
+ case fs.CtxDirentCacheLimiter:
+ return ctx.k.DirentCacheLimiter
+ case ktime.CtxRealtimeClock:
+ return ctx.k.RealtimeClock()
+ case limits.CtxLimits:
+ return ctx.args.Limits
+ case pgalloc.CtxMemoryFile:
+ return ctx.k.mf
+ case pgalloc.CtxMemoryFileProvider:
+ return ctx.k
+ case platform.CtxPlatform:
+ return ctx.k
+ case uniqueid.CtxGlobalUniqueID:
+ return ctx.k.UniqueID()
+ case uniqueid.CtxGlobalUniqueIDProvider:
+ return ctx.k
+ case uniqueid.CtxInotifyCookie:
+ return ctx.k.GenerateInotifyCookie()
+ case unimpl.CtxEvents:
+ return ctx.k
+ default:
+ return nil
+ }
+}
+
+// CreateProcess creates a new task in a new thread group with the given
+// options. The new task has no parent and is in the root PID namespace.
+//
+// If k.Start() has already been called, then the created process must be
+// started by calling kernel.StartProcess(tg).
+//
+// If k.Start() has not yet been called, then the created task will begin
+// running when k.Start() is called.
+//
+// CreateProcess has no analogue in Linux; it is used to create the initial
+// application task, as well as processes started by the control server.
+func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, error) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ log.Infof("EXEC: %v", args.Argv)
+
+ if k.mounts == nil {
+ return nil, 0, fmt.Errorf("no kernel MountNamespace")
+ }
+
+ tg := k.newThreadGroup(k.tasks.Root, NewSignalHandlers(), linux.SIGCHLD, args.Limits, k.monotonicClock)
+ ctx := args.NewContext(k)
+
+ // Grab the root directory.
+ root := args.Root
+ if root == nil {
+ root = fs.RootFromContext(ctx)
+ // Is the root STILL nil?
+ if root == nil {
+ return nil, 0, fmt.Errorf("CreateProcessArgs.Root was not provided, and failed to get root from context")
+ }
+ }
+ defer root.DecRef()
+ args.Root = nil
+
+ // Grab the working directory.
+ remainingTraversals := uint(args.MaxSymlinkTraversals)
+ wd := root // Default.
+ if args.WorkingDirectory != "" {
+ var err error
+ wd, err = k.mounts.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ }
+ defer wd.DecRef()
+ }
+
+ if args.Filename == "" {
+ // Was anything provided?
+ if len(args.Argv) == 0 {
+ return nil, 0, fmt.Errorf("no filename or command provided")
+ }
+ if !filepath.IsAbs(args.Argv[0]) {
+ return nil, 0, fmt.Errorf("'%s' is not an absolute path", args.Argv[0])
+ }
+ args.Filename = args.Argv[0]
+ }
+
+ // Create a fresh task context.
+ remainingTraversals = uint(args.MaxSymlinkTraversals)
+ tc, se := k.LoadTaskImage(ctx, k.mounts, root, wd, &remainingTraversals, args.Filename, args.Argv, args.Envv, k.featureSet)
+ if se != nil {
+ return nil, 0, errors.New(se.String())
+ }
+
+ // Take a reference on the FDMap, which will be transferred to
+ // TaskSet.NewTask().
+ args.FDMap.IncRef()
+
+ // Create the task.
+ config := &TaskConfig{
+ Kernel: k,
+ ThreadGroup: tg,
+ TaskContext: tc,
+ FSContext: newFSContext(root, wd, args.Umask),
+ FDMap: args.FDMap,
+ Credentials: args.Credentials,
+ AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores),
+ UTSNamespace: args.UTSNamespace,
+ IPCNamespace: args.IPCNamespace,
+ AbstractSocketNamespace: args.AbstractSocketNamespace,
+ ContainerID: args.ContainerID,
+ }
+ if _, err := k.tasks.NewTask(config); err != nil {
+ return nil, 0, err
+ }
+
+ // Success.
+ tgid := k.tasks.Root.IDOfThreadGroup(tg)
+ if k.globalInit == nil {
+ k.globalInit = tg
+ }
+ return tg, tgid, nil
+}
+
+// StartProcess starts running a process that was created with CreateProcess.
+func (k *Kernel) StartProcess(tg *ThreadGroup) {
+ t := tg.Leader()
+ tid := k.tasks.Root.IDOfTask(t)
+ t.Start(tid)
+}
+
+// Start starts execution of all tasks in k.
+//
+// Preconditions: Start may be called exactly once.
+func (k *Kernel) Start() error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+
+ if k.globalInit == nil {
+ return fmt.Errorf("kernel contains no tasks")
+ }
+ if k.started {
+ return fmt.Errorf("kernel already started")
+ }
+
+ k.started = true
+ k.cpuClockTicker = ktime.NewTimer(k.monotonicClock, newKernelCPUClockTicker(k))
+ k.cpuClockTicker.Swap(ktime.Setting{
+ Enabled: true,
+ Period: linux.ClockTick,
+ })
+ // If k was created by LoadKernelFrom, timers were stopped during
+ // Kernel.SaveTo and need to be resumed. If k was created by NewKernel,
+ // this is a no-op.
+ k.resumeTimeLocked()
+ // Start task goroutines.
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+ for t, tid := range k.tasks.Root.tids {
+ t.Start(tid)
+ }
+ return nil
+}
+
+// pauseTimeLocked pauses all Timers and Timekeeper updates.
+//
+// Preconditions: Any task goroutines running in k must be stopped. k.extMu
+// must be locked.
+func (k *Kernel) pauseTimeLocked() {
+ // k.cpuClockTicker may be nil since Kernel.SaveTo() may be called before
+ // Kernel.Start().
+ if k.cpuClockTicker != nil {
+ k.cpuClockTicker.Pause()
+ }
+
+ // By precondition, nothing else can be interacting with PIDNamespace.tids
+ // or FDMap.files, so we can iterate them without synchronization. (We
+ // can't hold the TaskSet mutex when pausing thread group timers because
+ // thread group timers call ThreadGroup.SendSignal, which takes the TaskSet
+ // mutex, while holding the Timer mutex.)
+ for t := range k.tasks.Root.tids {
+ if t == t.tg.leader {
+ t.tg.itimerRealTimer.Pause()
+ for _, it := range t.tg.timers {
+ it.PauseTimer()
+ }
+ }
+ // This means we'll iterate FDMaps shared by multiple tasks repeatedly,
+ // but ktime.Timer.Pause is idempotent so this is harmless.
+ if fdm := t.fds; fdm != nil {
+ for _, desc := range fdm.files {
+ if tfd, ok := desc.file.FileOperations.(*timerfd.TimerOperations); ok {
+ tfd.PauseTimer()
+ }
+ }
+ }
+ }
+ k.timekeeper.PauseUpdates()
+}
+
+// resumeTimeLocked resumes all Timers and Timekeeper updates. If
+// pauseTimeLocked has not been previously called, resumeTimeLocked has no
+// effect.
+//
+// Preconditions: Any task goroutines running in k must be stopped. k.extMu
+// must be locked.
+func (k *Kernel) resumeTimeLocked() {
+ if k.cpuClockTicker != nil {
+ k.cpuClockTicker.Resume()
+ }
+
+ k.timekeeper.ResumeUpdates()
+ for t := range k.tasks.Root.tids {
+ if t == t.tg.leader {
+ t.tg.itimerRealTimer.Resume()
+ for _, it := range t.tg.timers {
+ it.ResumeTimer()
+ }
+ }
+ if fdm := t.fds; fdm != nil {
+ for _, desc := range fdm.files {
+ if tfd, ok := desc.file.FileOperations.(*timerfd.TimerOperations); ok {
+ tfd.ResumeTimer()
+ }
+ }
+ }
+ }
+}
+
+// WaitExited blocks until all tasks in k have exited.
+func (k *Kernel) WaitExited() {
+ k.tasks.liveGoroutines.Wait()
+}
+
+// Kill requests that all tasks in k immediately exit as if group exiting with
+// status es. Kill does not wait for tasks to exit.
+func (k *Kernel) Kill(es ExitStatus) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.Kill(es)
+}
+
+// Pause requests that all tasks in k temporarily stop executing, and blocks
+// until all tasks in k have stopped. Multiple calls to Pause nest and require
+// an equal number of calls to Unpause to resume execution.
+func (k *Kernel) Pause() {
+ k.extMu.Lock()
+ k.tasks.BeginExternalStop()
+ k.extMu.Unlock()
+ k.tasks.runningGoroutines.Wait()
+}
+
+// Unpause ends the effect of a previous call to Pause. If Unpause is called
+// without a matching preceding call to Pause, Unpause may panic.
+func (k *Kernel) Unpause() {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.EndExternalStop()
+}
+
+// SendExternalSignal injects a signal into the kernel.
+//
+// context is used only for debugging to describe how the signal was received.
+//
+// Preconditions: Kernel must have an init process.
+func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.sendExternalSignal(info, context)
+}
+
+// SendContainerSignal sends the given signal to all processes inside the
+// namespace that match the given container ID.
+func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+
+ var lastErr error
+ for tg := range k.tasks.Root.tgids {
+ if tg.leader.ContainerID() == cid {
+ tg.signalHandlers.mu.Lock()
+ infoCopy := *info
+ if err := tg.leader.sendSignalLocked(&infoCopy, true /*group*/); err != nil {
+ lastErr = err
+ }
+ tg.signalHandlers.mu.Unlock()
+ }
+ }
+ return lastErr
+}
+
+// FeatureSet returns the FeatureSet.
+func (k *Kernel) FeatureSet() *cpuid.FeatureSet {
+ return k.featureSet
+}
+
+// Timekeeper returns the Timekeeper.
+func (k *Kernel) Timekeeper() *Timekeeper {
+ return k.timekeeper
+}
+
+// TaskSet returns the TaskSet.
+func (k *Kernel) TaskSet() *TaskSet {
+ return k.tasks
+}
+
+// RootUserNamespace returns the root UserNamespace.
+func (k *Kernel) RootUserNamespace() *auth.UserNamespace {
+ return k.rootUserNamespace
+}
+
+// RootUTSNamespace returns the root UTSNamespace.
+func (k *Kernel) RootUTSNamespace() *UTSNamespace {
+ return k.rootUTSNamespace
+}
+
+// RootIPCNamespace returns the root IPCNamespace.
+func (k *Kernel) RootIPCNamespace() *IPCNamespace {
+ return k.rootIPCNamespace
+}
+
+// RootAbstractSocketNamespace returns the root AbstractSocketNamespace.
+func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace {
+ return k.rootAbstractSocketNamespace
+}
+
+// RootMountNamespace returns the MountNamespace.
+func (k *Kernel) RootMountNamespace() *fs.MountNamespace {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ return k.mounts
+}
+
+// SetRootMountNamespace sets the MountNamespace.
+func (k *Kernel) SetRootMountNamespace(mounts *fs.MountNamespace) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.mounts = mounts
+}
+
+// NetworkStack returns the network stack. NetworkStack may return nil if no
+// network stack is available.
+func (k *Kernel) NetworkStack() inet.Stack {
+ return k.networkStack
+}
+
+// GlobalInit returns the thread group with ID 1 in the root PID namespace, or
+// nil if no such thread group exists. GlobalInit may return a thread group
+// containing no tasks if the thread group has already exited.
+func (k *Kernel) GlobalInit() *ThreadGroup {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ return k.globalInit
+}
+
+// ApplicationCores returns the number of CPUs visible to sandboxed
+// applications.
+func (k *Kernel) ApplicationCores() uint {
+ return k.applicationCores
+}
+
+// RealtimeClock returns the application CLOCK_REALTIME clock.
+func (k *Kernel) RealtimeClock() ktime.Clock {
+ return k.realtimeClock
+}
+
+// MonotonicClock returns the application CLOCK_MONOTONIC clock.
+func (k *Kernel) MonotonicClock() ktime.Clock {
+ return k.monotonicClock
+}
+
+// CPUClockNow returns the current value of k.cpuClock.
+func (k *Kernel) CPUClockNow() uint64 {
+ return atomic.LoadUint64(&k.cpuClock)
+}
+
+// Syslog returns the syslog.
+func (k *Kernel) Syslog() *syslog {
+ return &k.syslog
+}
+
+// GenerateInotifyCookie generates a unique inotify event cookie.
+//
+// Returned values may overlap with previously returned values if the value
+// space is exhausted. 0 is not a valid cookie value, all other values
+// representable in a uint32 are allowed.
+func (k *Kernel) GenerateInotifyCookie() uint32 {
+ id := atomic.AddUint32(&k.nextInotifyCookie, 1)
+ // Wrap-around is explicitly allowed for inotify event cookies.
+ if id == 0 {
+ id = atomic.AddUint32(&k.nextInotifyCookie, 1)
+ }
+ return id
+}
+
+// NetlinkPorts returns the netlink port manager.
+func (k *Kernel) NetlinkPorts() *port.Manager {
+ return k.netlinkPorts
+}
+
+// SaveError returns the sandbox error that caused the kernel to exit during
+// save.
+func (k *Kernel) SaveError() error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ return k.saveErr
+}
+
+// SetSaveError sets the sandbox error that caused the kernel to exit during
+// save, if one is not already set.
+func (k *Kernel) SetSaveError(err error) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ if k.saveErr == nil {
+ k.saveErr = err
+ }
+}
+
+var _ tcpip.Clock = (*Kernel)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (k *Kernel) NowNanoseconds() int64 {
+ now, err := k.timekeeper.GetTime(sentrytime.Realtime)
+ if err != nil {
+ panic("Kernel.NowNanoseconds: " + err.Error())
+ }
+ return now
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (k *Kernel) NowMonotonic() int64 {
+ now, err := k.timekeeper.GetTime(sentrytime.Monotonic)
+ if err != nil {
+ panic("Kernel.NowMonotonic: " + err.Error())
+ }
+ return now
+}
+
+// SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or
+// LoadFrom.
+func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) {
+ k.mf = mf
+}
+
+// MemoryFile implements pgalloc.MemoryFileProvider.MemoryFile.
+func (k *Kernel) MemoryFile() *pgalloc.MemoryFile {
+ return k.mf
+}
+
+// SupervisorContext returns a Context with maximum privileges in k. It should
+// only be used by goroutines outside the control of the emulated kernel
+// defined by e.
+//
+// Callers are responsible for ensuring that the returned Context is not used
+// concurrently with changes to the Kernel.
+func (k *Kernel) SupervisorContext() context.Context {
+ return supervisorContext{
+ Logger: log.Log(),
+ k: k,
+ }
+}
+
+// EmitUnimplementedEvent emits an UnimplementedSyscall event via the event
+// channel.
+func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
+ t := TaskFromContext(ctx)
+ eventchannel.Emit(&uspb.UnimplementedSyscall{
+ Tid: int32(t.ThreadID()),
+ Registers: t.Arch().StateData().Proto(),
+ })
+}
+
+// socketEntry represents a socket recorded in Kernel.socketTable. It implements
+// refs.WeakRefUser for sockets stored in the socket table.
+//
+// +stateify savable
+type socketEntry struct {
+ k *Kernel
+ sock *refs.WeakRef
+ family int
+}
+
+// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
+func (s *socketEntry) WeakRefGone() {
+ s.k.extMu.Lock()
+ // k.socketTable is guaranteed to point to a valid socket table for s.family
+ // at this point, since we made sure of the fact when we created this
+ // socketEntry, and we never delete socket tables.
+ delete(s.k.socketTable[s.family], s.sock)
+ s.k.extMu.Unlock()
+}
+
+// RecordSocket adds a socket to the system-wide socket table for tracking.
+//
+// Precondition: Caller must hold a reference to sock.
+func (k *Kernel) RecordSocket(sock *fs.File, family int) {
+ k.extMu.Lock()
+ table, ok := k.socketTable[family]
+ if !ok {
+ table = make(map[*refs.WeakRef]struct{})
+ k.socketTable[family] = table
+ }
+ se := socketEntry{k: k, family: family}
+ se.sock = refs.NewWeakRef(sock, &se)
+ table[se.sock] = struct{}{}
+ k.extMu.Unlock()
+}
+
+// ListSockets returns a snapshot of all sockets of a given family.
+func (k *Kernel) ListSockets(family int) []*refs.WeakRef {
+ k.extMu.Lock()
+ socks := []*refs.WeakRef{}
+ if table, ok := k.socketTable[family]; ok {
+ socks = make([]*refs.WeakRef, 0, len(table))
+ for s := range table {
+ socks = append(socks, s)
+ }
+ }
+ k.extMu.Unlock()
+ return socks
+}
+
+type supervisorContext struct {
+ context.NoopSleeper
+ log.Logger
+ k *Kernel
+}
+
+// Value implements context.Context.
+func (ctx supervisorContext) Value(key interface{}) interface{} {
+ switch key {
+ case CtxCanTrace:
+ // The supervisor context can trace anything. (None of
+ // supervisorContext's users are expected to invoke ptrace, but ptrace
+ // permissions are required for certain file accesses.)
+ return func(*Task, bool) bool { return true }
+ case CtxKernel:
+ return ctx.k
+ case CtxPIDNamespace:
+ return ctx.k.tasks.Root
+ case CtxUTSNamespace:
+ return ctx.k.rootUTSNamespace
+ case CtxIPCNamespace:
+ return ctx.k.rootIPCNamespace
+ case auth.CtxCredentials:
+ // The supervisor context is global root.
+ return auth.NewRootCredentials(ctx.k.rootUserNamespace)
+ case fs.CtxRoot:
+ return ctx.k.mounts.Root()
+ case fs.CtxDirentCacheLimiter:
+ return ctx.k.DirentCacheLimiter
+ case ktime.CtxRealtimeClock:
+ return ctx.k.RealtimeClock()
+ case limits.CtxLimits:
+ // No limits apply.
+ return limits.NewLimitSet()
+ case pgalloc.CtxMemoryFile:
+ return ctx.k.mf
+ case pgalloc.CtxMemoryFileProvider:
+ return ctx.k
+ case platform.CtxPlatform:
+ return ctx.k
+ case uniqueid.CtxGlobalUniqueID:
+ return ctx.k.UniqueID()
+ case uniqueid.CtxGlobalUniqueIDProvider:
+ return ctx.k
+ case uniqueid.CtxInotifyCookie:
+ return ctx.k.GenerateInotifyCookie()
+ case unimpl.CtxEvents:
+ return ctx.k
+ default:
+ return nil
+ }
+}
diff --git a/pkg/sentry/kernel/kernel_state.go b/pkg/sentry/kernel/kernel_state.go
new file mode 100644
index 000000000..48c3ff5a9
--- /dev/null
+++ b/pkg/sentry/kernel/kernel_state.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// saveDanglingEndpoints is invoked by stateify.
+func (k *Kernel) saveDanglingEndpoints() []tcpip.Endpoint {
+ return tcpip.GetDanglingEndpoints()
+}
+
+// loadDanglingEndpoints is invoked by stateify.
+func (k *Kernel) loadDanglingEndpoints(es []tcpip.Endpoint) {
+ for _, e := range es {
+ tcpip.AddDanglingEndpoint(e)
+ }
+}
+
+// saveDeviceRegistry is invoked by stateify.
+func (k *Kernel) saveDeviceRegistry() *device.Registry {
+ return device.SimpleDevices
+}
+
+// loadDeviceRegistry is invoked by stateify.
+func (k *Kernel) loadDeviceRegistry(r *device.Registry) {
+ device.SimpleDevices.LoadFrom(r)
+}
diff --git a/pkg/sentry/kernel/kernel_state_autogen.go b/pkg/sentry/kernel/kernel_state_autogen.go
new file mode 100755
index 000000000..82fd0abfd
--- /dev/null
+++ b/pkg/sentry/kernel/kernel_state_autogen.go
@@ -0,0 +1,1147 @@
+// automatically generated by stateify.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/bpf"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+func (x *abstractEndpoint) beforeSave() {}
+func (x *abstractEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ep", &x.ep)
+ m.Save("wr", &x.wr)
+ m.Save("name", &x.name)
+ m.Save("ns", &x.ns)
+}
+
+func (x *abstractEndpoint) afterLoad() {}
+func (x *abstractEndpoint) load(m state.Map) {
+ m.Load("ep", &x.ep)
+ m.Load("wr", &x.wr)
+ m.Load("name", &x.name)
+ m.Load("ns", &x.ns)
+}
+
+func (x *AbstractSocketNamespace) beforeSave() {}
+func (x *AbstractSocketNamespace) save(m state.Map) {
+ x.beforeSave()
+ m.Save("endpoints", &x.endpoints)
+}
+
+func (x *AbstractSocketNamespace) afterLoad() {}
+func (x *AbstractSocketNamespace) load(m state.Map) {
+ m.Load("endpoints", &x.endpoints)
+}
+
+func (x *FDFlags) beforeSave() {}
+func (x *FDFlags) save(m state.Map) {
+ x.beforeSave()
+ m.Save("CloseOnExec", &x.CloseOnExec)
+}
+
+func (x *FDFlags) afterLoad() {}
+func (x *FDFlags) load(m state.Map) {
+ m.Load("CloseOnExec", &x.CloseOnExec)
+}
+
+func (x *descriptor) beforeSave() {}
+func (x *descriptor) save(m state.Map) {
+ x.beforeSave()
+ m.Save("file", &x.file)
+ m.Save("flags", &x.flags)
+}
+
+func (x *descriptor) afterLoad() {}
+func (x *descriptor) load(m state.Map) {
+ m.Load("file", &x.file)
+ m.Load("flags", &x.flags)
+}
+
+func (x *FDMap) beforeSave() {}
+func (x *FDMap) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("k", &x.k)
+ m.Save("files", &x.files)
+ m.Save("uid", &x.uid)
+}
+
+func (x *FDMap) afterLoad() {}
+func (x *FDMap) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("k", &x.k)
+ m.Load("files", &x.files)
+ m.Load("uid", &x.uid)
+}
+
+func (x *FSContext) beforeSave() {}
+func (x *FSContext) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("root", &x.root)
+ m.Save("cwd", &x.cwd)
+ m.Save("umask", &x.umask)
+}
+
+func (x *FSContext) afterLoad() {}
+func (x *FSContext) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("root", &x.root)
+ m.Load("cwd", &x.cwd)
+ m.Load("umask", &x.umask)
+}
+
+func (x *IPCNamespace) beforeSave() {}
+func (x *IPCNamespace) save(m state.Map) {
+ x.beforeSave()
+ m.Save("userNS", &x.userNS)
+ m.Save("semaphores", &x.semaphores)
+ m.Save("shms", &x.shms)
+}
+
+func (x *IPCNamespace) afterLoad() {}
+func (x *IPCNamespace) load(m state.Map) {
+ m.Load("userNS", &x.userNS)
+ m.Load("semaphores", &x.semaphores)
+ m.Load("shms", &x.shms)
+}
+
+func (x *Kernel) beforeSave() {}
+func (x *Kernel) save(m state.Map) {
+ x.beforeSave()
+ var danglingEndpoints []tcpip.Endpoint = x.saveDanglingEndpoints()
+ m.SaveValue("danglingEndpoints", danglingEndpoints)
+ var deviceRegistry *device.Registry = x.saveDeviceRegistry()
+ m.SaveValue("deviceRegistry", deviceRegistry)
+ m.Save("featureSet", &x.featureSet)
+ m.Save("timekeeper", &x.timekeeper)
+ m.Save("tasks", &x.tasks)
+ m.Save("rootUserNamespace", &x.rootUserNamespace)
+ m.Save("applicationCores", &x.applicationCores)
+ m.Save("useHostCores", &x.useHostCores)
+ m.Save("extraAuxv", &x.extraAuxv)
+ m.Save("vdso", &x.vdso)
+ m.Save("rootUTSNamespace", &x.rootUTSNamespace)
+ m.Save("rootIPCNamespace", &x.rootIPCNamespace)
+ m.Save("rootAbstractSocketNamespace", &x.rootAbstractSocketNamespace)
+ m.Save("mounts", &x.mounts)
+ m.Save("futexes", &x.futexes)
+ m.Save("globalInit", &x.globalInit)
+ m.Save("realtimeClock", &x.realtimeClock)
+ m.Save("monotonicClock", &x.monotonicClock)
+ m.Save("syslog", &x.syslog)
+ m.Save("cpuClock", &x.cpuClock)
+ m.Save("fdMapUids", &x.fdMapUids)
+ m.Save("uniqueID", &x.uniqueID)
+ m.Save("nextInotifyCookie", &x.nextInotifyCookie)
+ m.Save("netlinkPorts", &x.netlinkPorts)
+ m.Save("socketTable", &x.socketTable)
+ m.Save("DirentCacheLimiter", &x.DirentCacheLimiter)
+}
+
+func (x *Kernel) afterLoad() {}
+func (x *Kernel) load(m state.Map) {
+ m.Load("featureSet", &x.featureSet)
+ m.Load("timekeeper", &x.timekeeper)
+ m.Load("tasks", &x.tasks)
+ m.Load("rootUserNamespace", &x.rootUserNamespace)
+ m.Load("applicationCores", &x.applicationCores)
+ m.Load("useHostCores", &x.useHostCores)
+ m.Load("extraAuxv", &x.extraAuxv)
+ m.Load("vdso", &x.vdso)
+ m.Load("rootUTSNamespace", &x.rootUTSNamespace)
+ m.Load("rootIPCNamespace", &x.rootIPCNamespace)
+ m.Load("rootAbstractSocketNamespace", &x.rootAbstractSocketNamespace)
+ m.Load("mounts", &x.mounts)
+ m.Load("futexes", &x.futexes)
+ m.Load("globalInit", &x.globalInit)
+ m.Load("realtimeClock", &x.realtimeClock)
+ m.Load("monotonicClock", &x.monotonicClock)
+ m.Load("syslog", &x.syslog)
+ m.Load("cpuClock", &x.cpuClock)
+ m.Load("fdMapUids", &x.fdMapUids)
+ m.Load("uniqueID", &x.uniqueID)
+ m.Load("nextInotifyCookie", &x.nextInotifyCookie)
+ m.Load("netlinkPorts", &x.netlinkPorts)
+ m.Load("socketTable", &x.socketTable)
+ m.Load("DirentCacheLimiter", &x.DirentCacheLimiter)
+ m.LoadValue("danglingEndpoints", new([]tcpip.Endpoint), func(y interface{}) { x.loadDanglingEndpoints(y.([]tcpip.Endpoint)) })
+ m.LoadValue("deviceRegistry", new(*device.Registry), func(y interface{}) { x.loadDeviceRegistry(y.(*device.Registry)) })
+}
+
+func (x *socketEntry) beforeSave() {}
+func (x *socketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("k", &x.k)
+ m.Save("sock", &x.sock)
+ m.Save("family", &x.family)
+}
+
+func (x *socketEntry) afterLoad() {}
+func (x *socketEntry) load(m state.Map) {
+ m.Load("k", &x.k)
+ m.Load("sock", &x.sock)
+ m.Load("family", &x.family)
+}
+
+func (x *pendingSignals) beforeSave() {}
+func (x *pendingSignals) save(m state.Map) {
+ x.beforeSave()
+ var signals []savedPendingSignal = x.saveSignals()
+ m.SaveValue("signals", signals)
+}
+
+func (x *pendingSignals) afterLoad() {}
+func (x *pendingSignals) load(m state.Map) {
+ m.LoadValue("signals", new([]savedPendingSignal), func(y interface{}) { x.loadSignals(y.([]savedPendingSignal)) })
+}
+
+func (x *pendingSignalQueue) beforeSave() {}
+func (x *pendingSignalQueue) save(m state.Map) {
+ x.beforeSave()
+ m.Save("pendingSignalList", &x.pendingSignalList)
+ m.Save("length", &x.length)
+}
+
+func (x *pendingSignalQueue) afterLoad() {}
+func (x *pendingSignalQueue) load(m state.Map) {
+ m.Load("pendingSignalList", &x.pendingSignalList)
+ m.Load("length", &x.length)
+}
+
+func (x *pendingSignal) beforeSave() {}
+func (x *pendingSignal) save(m state.Map) {
+ x.beforeSave()
+ m.Save("pendingSignalEntry", &x.pendingSignalEntry)
+ m.Save("SignalInfo", &x.SignalInfo)
+ m.Save("timer", &x.timer)
+}
+
+func (x *pendingSignal) afterLoad() {}
+func (x *pendingSignal) load(m state.Map) {
+ m.Load("pendingSignalEntry", &x.pendingSignalEntry)
+ m.Load("SignalInfo", &x.SignalInfo)
+ m.Load("timer", &x.timer)
+}
+
+func (x *pendingSignalList) beforeSave() {}
+func (x *pendingSignalList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *pendingSignalList) afterLoad() {}
+func (x *pendingSignalList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *pendingSignalEntry) beforeSave() {}
+func (x *pendingSignalEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *pendingSignalEntry) afterLoad() {}
+func (x *pendingSignalEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *savedPendingSignal) beforeSave() {}
+func (x *savedPendingSignal) save(m state.Map) {
+ x.beforeSave()
+ m.Save("si", &x.si)
+ m.Save("timer", &x.timer)
+}
+
+func (x *savedPendingSignal) afterLoad() {}
+func (x *savedPendingSignal) load(m state.Map) {
+ m.Load("si", &x.si)
+ m.Load("timer", &x.timer)
+}
+
+func (x *IntervalTimer) beforeSave() {}
+func (x *IntervalTimer) save(m state.Map) {
+ x.beforeSave()
+ m.Save("timer", &x.timer)
+ m.Save("target", &x.target)
+ m.Save("signo", &x.signo)
+ m.Save("id", &x.id)
+ m.Save("sigval", &x.sigval)
+ m.Save("group", &x.group)
+ m.Save("sigpending", &x.sigpending)
+ m.Save("sigorphan", &x.sigorphan)
+ m.Save("overrunCur", &x.overrunCur)
+ m.Save("overrunLast", &x.overrunLast)
+}
+
+func (x *IntervalTimer) afterLoad() {}
+func (x *IntervalTimer) load(m state.Map) {
+ m.Load("timer", &x.timer)
+ m.Load("target", &x.target)
+ m.Load("signo", &x.signo)
+ m.Load("id", &x.id)
+ m.Load("sigval", &x.sigval)
+ m.Load("group", &x.group)
+ m.Load("sigpending", &x.sigpending)
+ m.Load("sigorphan", &x.sigorphan)
+ m.Load("overrunCur", &x.overrunCur)
+ m.Load("overrunLast", &x.overrunLast)
+}
+
+func (x *processGroupList) beforeSave() {}
+func (x *processGroupList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *processGroupList) afterLoad() {}
+func (x *processGroupList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *processGroupEntry) beforeSave() {}
+func (x *processGroupEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *processGroupEntry) afterLoad() {}
+func (x *processGroupEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *ptraceOptions) beforeSave() {}
+func (x *ptraceOptions) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ExitKill", &x.ExitKill)
+ m.Save("SysGood", &x.SysGood)
+ m.Save("TraceClone", &x.TraceClone)
+ m.Save("TraceExec", &x.TraceExec)
+ m.Save("TraceExit", &x.TraceExit)
+ m.Save("TraceFork", &x.TraceFork)
+ m.Save("TraceSeccomp", &x.TraceSeccomp)
+ m.Save("TraceVfork", &x.TraceVfork)
+ m.Save("TraceVforkDone", &x.TraceVforkDone)
+}
+
+func (x *ptraceOptions) afterLoad() {}
+func (x *ptraceOptions) load(m state.Map) {
+ m.Load("ExitKill", &x.ExitKill)
+ m.Load("SysGood", &x.SysGood)
+ m.Load("TraceClone", &x.TraceClone)
+ m.Load("TraceExec", &x.TraceExec)
+ m.Load("TraceExit", &x.TraceExit)
+ m.Load("TraceFork", &x.TraceFork)
+ m.Load("TraceSeccomp", &x.TraceSeccomp)
+ m.Load("TraceVfork", &x.TraceVfork)
+ m.Load("TraceVforkDone", &x.TraceVforkDone)
+}
+
+func (x *ptraceStop) beforeSave() {}
+func (x *ptraceStop) save(m state.Map) {
+ x.beforeSave()
+ m.Save("frozen", &x.frozen)
+ m.Save("listen", &x.listen)
+}
+
+func (x *ptraceStop) afterLoad() {}
+func (x *ptraceStop) load(m state.Map) {
+ m.Load("frozen", &x.frozen)
+ m.Load("listen", &x.listen)
+}
+
+func (x *RSEQCriticalRegion) beforeSave() {}
+func (x *RSEQCriticalRegion) save(m state.Map) {
+ x.beforeSave()
+ m.Save("CriticalSection", &x.CriticalSection)
+ m.Save("Restart", &x.Restart)
+}
+
+func (x *RSEQCriticalRegion) afterLoad() {}
+func (x *RSEQCriticalRegion) load(m state.Map) {
+ m.Load("CriticalSection", &x.CriticalSection)
+ m.Load("Restart", &x.Restart)
+}
+
+func (x *sessionList) beforeSave() {}
+func (x *sessionList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *sessionList) afterLoad() {}
+func (x *sessionList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *sessionEntry) beforeSave() {}
+func (x *sessionEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *sessionEntry) afterLoad() {}
+func (x *sessionEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *Session) beforeSave() {}
+func (x *Session) save(m state.Map) {
+ x.beforeSave()
+ m.Save("refs", &x.refs)
+ m.Save("leader", &x.leader)
+ m.Save("id", &x.id)
+ m.Save("processGroups", &x.processGroups)
+ m.Save("sessionEntry", &x.sessionEntry)
+}
+
+func (x *Session) afterLoad() {}
+func (x *Session) load(m state.Map) {
+ m.Load("refs", &x.refs)
+ m.Load("leader", &x.leader)
+ m.Load("id", &x.id)
+ m.Load("processGroups", &x.processGroups)
+ m.Load("sessionEntry", &x.sessionEntry)
+}
+
+func (x *ProcessGroup) beforeSave() {}
+func (x *ProcessGroup) save(m state.Map) {
+ x.beforeSave()
+ m.Save("refs", &x.refs)
+ m.Save("originator", &x.originator)
+ m.Save("id", &x.id)
+ m.Save("session", &x.session)
+ m.Save("ancestors", &x.ancestors)
+ m.Save("processGroupEntry", &x.processGroupEntry)
+}
+
+func (x *ProcessGroup) afterLoad() {}
+func (x *ProcessGroup) load(m state.Map) {
+ m.Load("refs", &x.refs)
+ m.Load("originator", &x.originator)
+ m.Load("id", &x.id)
+ m.Load("session", &x.session)
+ m.Load("ancestors", &x.ancestors)
+ m.Load("processGroupEntry", &x.processGroupEntry)
+}
+
+func (x *SignalHandlers) beforeSave() {}
+func (x *SignalHandlers) save(m state.Map) {
+ x.beforeSave()
+ m.Save("actions", &x.actions)
+}
+
+func (x *SignalHandlers) afterLoad() {}
+func (x *SignalHandlers) load(m state.Map) {
+ m.Load("actions", &x.actions)
+}
+
+func (x *SyscallTable) beforeSave() {}
+func (x *SyscallTable) save(m state.Map) {
+ x.beforeSave()
+ m.Save("OS", &x.OS)
+ m.Save("Arch", &x.Arch)
+}
+
+func (x *SyscallTable) load(m state.Map) {
+ m.LoadWait("OS", &x.OS)
+ m.LoadWait("Arch", &x.Arch)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *syslog) beforeSave() {}
+func (x *syslog) save(m state.Map) {
+ x.beforeSave()
+ m.Save("msg", &x.msg)
+}
+
+func (x *syslog) afterLoad() {}
+func (x *syslog) load(m state.Map) {
+ m.Load("msg", &x.msg)
+}
+
+func (x *Task) beforeSave() {}
+func (x *Task) save(m state.Map) {
+ x.beforeSave()
+ var ptraceTracer *Task = x.savePtraceTracer()
+ m.SaveValue("ptraceTracer", ptraceTracer)
+ var logPrefix string = x.saveLogPrefix()
+ m.SaveValue("logPrefix", logPrefix)
+ var syscallFilters []bpf.Program = x.saveSyscallFilters()
+ m.SaveValue("syscallFilters", syscallFilters)
+ m.Save("taskNode", &x.taskNode)
+ m.Save("runState", &x.runState)
+ m.Save("haveSyscallReturn", &x.haveSyscallReturn)
+ m.Save("gosched", &x.gosched)
+ m.Save("yieldCount", &x.yieldCount)
+ m.Save("pendingSignals", &x.pendingSignals)
+ m.Save("signalMask", &x.signalMask)
+ m.Save("realSignalMask", &x.realSignalMask)
+ m.Save("haveSavedSignalMask", &x.haveSavedSignalMask)
+ m.Save("savedSignalMask", &x.savedSignalMask)
+ m.Save("signalStack", &x.signalStack)
+ m.Save("groupStopPending", &x.groupStopPending)
+ m.Save("groupStopAcknowledged", &x.groupStopAcknowledged)
+ m.Save("trapStopPending", &x.trapStopPending)
+ m.Save("trapNotifyPending", &x.trapNotifyPending)
+ m.Save("stop", &x.stop)
+ m.Save("exitStatus", &x.exitStatus)
+ m.Save("syscallRestartBlock", &x.syscallRestartBlock)
+ m.Save("k", &x.k)
+ m.Save("containerID", &x.containerID)
+ m.Save("tc", &x.tc)
+ m.Save("fsc", &x.fsc)
+ m.Save("fds", &x.fds)
+ m.Save("vforkParent", &x.vforkParent)
+ m.Save("exitState", &x.exitState)
+ m.Save("exitTracerNotified", &x.exitTracerNotified)
+ m.Save("exitTracerAcked", &x.exitTracerAcked)
+ m.Save("exitParentNotified", &x.exitParentNotified)
+ m.Save("exitParentAcked", &x.exitParentAcked)
+ m.Save("ptraceTracees", &x.ptraceTracees)
+ m.Save("ptraceSeized", &x.ptraceSeized)
+ m.Save("ptraceOpts", &x.ptraceOpts)
+ m.Save("ptraceSyscallMode", &x.ptraceSyscallMode)
+ m.Save("ptraceSinglestep", &x.ptraceSinglestep)
+ m.Save("ptraceCode", &x.ptraceCode)
+ m.Save("ptraceSiginfo", &x.ptraceSiginfo)
+ m.Save("ptraceEventMsg", &x.ptraceEventMsg)
+ m.Save("ioUsage", &x.ioUsage)
+ m.Save("creds", &x.creds)
+ m.Save("utsns", &x.utsns)
+ m.Save("ipcns", &x.ipcns)
+ m.Save("abstractSockets", &x.abstractSockets)
+ m.Save("parentDeathSignal", &x.parentDeathSignal)
+ m.Save("cleartid", &x.cleartid)
+ m.Save("allowedCPUMask", &x.allowedCPUMask)
+ m.Save("cpu", &x.cpu)
+ m.Save("niceness", &x.niceness)
+ m.Save("numaPolicy", &x.numaPolicy)
+ m.Save("numaNodeMask", &x.numaNodeMask)
+ m.Save("netns", &x.netns)
+ m.Save("rseqCPUAddr", &x.rseqCPUAddr)
+ m.Save("rseqCPU", &x.rseqCPU)
+ m.Save("startTime", &x.startTime)
+}
+
+func (x *Task) load(m state.Map) {
+ m.Load("taskNode", &x.taskNode)
+ m.Load("runState", &x.runState)
+ m.Load("haveSyscallReturn", &x.haveSyscallReturn)
+ m.Load("gosched", &x.gosched)
+ m.Load("yieldCount", &x.yieldCount)
+ m.Load("pendingSignals", &x.pendingSignals)
+ m.Load("signalMask", &x.signalMask)
+ m.Load("realSignalMask", &x.realSignalMask)
+ m.Load("haveSavedSignalMask", &x.haveSavedSignalMask)
+ m.Load("savedSignalMask", &x.savedSignalMask)
+ m.Load("signalStack", &x.signalStack)
+ m.Load("groupStopPending", &x.groupStopPending)
+ m.Load("groupStopAcknowledged", &x.groupStopAcknowledged)
+ m.Load("trapStopPending", &x.trapStopPending)
+ m.Load("trapNotifyPending", &x.trapNotifyPending)
+ m.Load("stop", &x.stop)
+ m.Load("exitStatus", &x.exitStatus)
+ m.Load("syscallRestartBlock", &x.syscallRestartBlock)
+ m.Load("k", &x.k)
+ m.Load("containerID", &x.containerID)
+ m.Load("tc", &x.tc)
+ m.Load("fsc", &x.fsc)
+ m.Load("fds", &x.fds)
+ m.Load("vforkParent", &x.vforkParent)
+ m.Load("exitState", &x.exitState)
+ m.Load("exitTracerNotified", &x.exitTracerNotified)
+ m.Load("exitTracerAcked", &x.exitTracerAcked)
+ m.Load("exitParentNotified", &x.exitParentNotified)
+ m.Load("exitParentAcked", &x.exitParentAcked)
+ m.Load("ptraceTracees", &x.ptraceTracees)
+ m.Load("ptraceSeized", &x.ptraceSeized)
+ m.Load("ptraceOpts", &x.ptraceOpts)
+ m.Load("ptraceSyscallMode", &x.ptraceSyscallMode)
+ m.Load("ptraceSinglestep", &x.ptraceSinglestep)
+ m.Load("ptraceCode", &x.ptraceCode)
+ m.Load("ptraceSiginfo", &x.ptraceSiginfo)
+ m.Load("ptraceEventMsg", &x.ptraceEventMsg)
+ m.Load("ioUsage", &x.ioUsage)
+ m.Load("creds", &x.creds)
+ m.Load("utsns", &x.utsns)
+ m.Load("ipcns", &x.ipcns)
+ m.Load("abstractSockets", &x.abstractSockets)
+ m.Load("parentDeathSignal", &x.parentDeathSignal)
+ m.Load("cleartid", &x.cleartid)
+ m.Load("allowedCPUMask", &x.allowedCPUMask)
+ m.Load("cpu", &x.cpu)
+ m.Load("niceness", &x.niceness)
+ m.Load("numaPolicy", &x.numaPolicy)
+ m.Load("numaNodeMask", &x.numaNodeMask)
+ m.Load("netns", &x.netns)
+ m.Load("rseqCPUAddr", &x.rseqCPUAddr)
+ m.Load("rseqCPU", &x.rseqCPU)
+ m.Load("startTime", &x.startTime)
+ m.LoadValue("ptraceTracer", new(*Task), func(y interface{}) { x.loadPtraceTracer(y.(*Task)) })
+ m.LoadValue("logPrefix", new(string), func(y interface{}) { x.loadLogPrefix(y.(string)) })
+ m.LoadValue("syscallFilters", new([]bpf.Program), func(y interface{}) { x.loadSyscallFilters(y.([]bpf.Program)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *runSyscallAfterPtraceEventClone) beforeSave() {}
+func (x *runSyscallAfterPtraceEventClone) save(m state.Map) {
+ x.beforeSave()
+ m.Save("vforkChild", &x.vforkChild)
+ m.Save("vforkChildTID", &x.vforkChildTID)
+}
+
+func (x *runSyscallAfterPtraceEventClone) afterLoad() {}
+func (x *runSyscallAfterPtraceEventClone) load(m state.Map) {
+ m.Load("vforkChild", &x.vforkChild)
+ m.Load("vforkChildTID", &x.vforkChildTID)
+}
+
+func (x *runSyscallAfterVforkStop) beforeSave() {}
+func (x *runSyscallAfterVforkStop) save(m state.Map) {
+ x.beforeSave()
+ m.Save("childTID", &x.childTID)
+}
+
+func (x *runSyscallAfterVforkStop) afterLoad() {}
+func (x *runSyscallAfterVforkStop) load(m state.Map) {
+ m.Load("childTID", &x.childTID)
+}
+
+func (x *vforkStop) beforeSave() {}
+func (x *vforkStop) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *vforkStop) afterLoad() {}
+func (x *vforkStop) load(m state.Map) {
+}
+
+func (x *TaskContext) beforeSave() {}
+func (x *TaskContext) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Name", &x.Name)
+ m.Save("Arch", &x.Arch)
+ m.Save("MemoryManager", &x.MemoryManager)
+ m.Save("fu", &x.fu)
+ m.Save("st", &x.st)
+}
+
+func (x *TaskContext) afterLoad() {}
+func (x *TaskContext) load(m state.Map) {
+ m.Load("Name", &x.Name)
+ m.Load("Arch", &x.Arch)
+ m.Load("MemoryManager", &x.MemoryManager)
+ m.Load("fu", &x.fu)
+ m.Load("st", &x.st)
+}
+
+func (x *execStop) beforeSave() {}
+func (x *execStop) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *execStop) afterLoad() {}
+func (x *execStop) load(m state.Map) {
+}
+
+func (x *runSyscallAfterExecStop) beforeSave() {}
+func (x *runSyscallAfterExecStop) save(m state.Map) {
+ x.beforeSave()
+ m.Save("tc", &x.tc)
+}
+
+func (x *runSyscallAfterExecStop) afterLoad() {}
+func (x *runSyscallAfterExecStop) load(m state.Map) {
+ m.Load("tc", &x.tc)
+}
+
+func (x *ExitStatus) beforeSave() {}
+func (x *ExitStatus) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Code", &x.Code)
+ m.Save("Signo", &x.Signo)
+}
+
+func (x *ExitStatus) afterLoad() {}
+func (x *ExitStatus) load(m state.Map) {
+ m.Load("Code", &x.Code)
+ m.Load("Signo", &x.Signo)
+}
+
+func (x *runExit) beforeSave() {}
+func (x *runExit) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runExit) afterLoad() {}
+func (x *runExit) load(m state.Map) {
+}
+
+func (x *runExitMain) beforeSave() {}
+func (x *runExitMain) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runExitMain) afterLoad() {}
+func (x *runExitMain) load(m state.Map) {
+}
+
+func (x *runExitNotify) beforeSave() {}
+func (x *runExitNotify) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runExitNotify) afterLoad() {}
+func (x *runExitNotify) load(m state.Map) {
+}
+
+func (x *taskList) beforeSave() {}
+func (x *taskList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *taskList) afterLoad() {}
+func (x *taskList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *taskEntry) beforeSave() {}
+func (x *taskEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *taskEntry) afterLoad() {}
+func (x *taskEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *runApp) beforeSave() {}
+func (x *runApp) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runApp) afterLoad() {}
+func (x *runApp) load(m state.Map) {
+}
+
+func (x *TaskGoroutineSchedInfo) beforeSave() {}
+func (x *TaskGoroutineSchedInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Timestamp", &x.Timestamp)
+ m.Save("State", &x.State)
+ m.Save("UserTicks", &x.UserTicks)
+ m.Save("SysTicks", &x.SysTicks)
+}
+
+func (x *TaskGoroutineSchedInfo) afterLoad() {}
+func (x *TaskGoroutineSchedInfo) load(m state.Map) {
+ m.Load("Timestamp", &x.Timestamp)
+ m.Load("State", &x.State)
+ m.Load("UserTicks", &x.UserTicks)
+ m.Load("SysTicks", &x.SysTicks)
+}
+
+func (x *taskClock) beforeSave() {}
+func (x *taskClock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+ m.Save("includeSys", &x.includeSys)
+}
+
+func (x *taskClock) afterLoad() {}
+func (x *taskClock) load(m state.Map) {
+ m.Load("t", &x.t)
+ m.Load("includeSys", &x.includeSys)
+}
+
+func (x *tgClock) beforeSave() {}
+func (x *tgClock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("tg", &x.tg)
+ m.Save("includeSys", &x.includeSys)
+}
+
+func (x *tgClock) afterLoad() {}
+func (x *tgClock) load(m state.Map) {
+ m.Load("tg", &x.tg)
+ m.Load("includeSys", &x.includeSys)
+}
+
+func (x *groupStop) beforeSave() {}
+func (x *groupStop) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *groupStop) afterLoad() {}
+func (x *groupStop) load(m state.Map) {
+}
+
+func (x *runInterrupt) beforeSave() {}
+func (x *runInterrupt) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runInterrupt) afterLoad() {}
+func (x *runInterrupt) load(m state.Map) {
+}
+
+func (x *runInterruptAfterSignalDeliveryStop) beforeSave() {}
+func (x *runInterruptAfterSignalDeliveryStop) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runInterruptAfterSignalDeliveryStop) afterLoad() {}
+func (x *runInterruptAfterSignalDeliveryStop) load(m state.Map) {
+}
+
+func (x *runSyscallAfterSyscallEnterStop) beforeSave() {}
+func (x *runSyscallAfterSyscallEnterStop) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runSyscallAfterSyscallEnterStop) afterLoad() {}
+func (x *runSyscallAfterSyscallEnterStop) load(m state.Map) {
+}
+
+func (x *runSyscallAfterSysemuStop) beforeSave() {}
+func (x *runSyscallAfterSysemuStop) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runSyscallAfterSysemuStop) afterLoad() {}
+func (x *runSyscallAfterSysemuStop) load(m state.Map) {
+}
+
+func (x *runSyscallReinvoke) beforeSave() {}
+func (x *runSyscallReinvoke) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runSyscallReinvoke) afterLoad() {}
+func (x *runSyscallReinvoke) load(m state.Map) {
+}
+
+func (x *runSyscallExit) beforeSave() {}
+func (x *runSyscallExit) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *runSyscallExit) afterLoad() {}
+func (x *runSyscallExit) load(m state.Map) {
+}
+
+func (x *ThreadGroup) beforeSave() {}
+func (x *ThreadGroup) save(m state.Map) {
+ x.beforeSave()
+ var rscr *RSEQCriticalRegion = x.saveRscr()
+ m.SaveValue("rscr", rscr)
+ m.Save("threadGroupNode", &x.threadGroupNode)
+ m.Save("signalHandlers", &x.signalHandlers)
+ m.Save("pendingSignals", &x.pendingSignals)
+ m.Save("groupStopDequeued", &x.groupStopDequeued)
+ m.Save("groupStopSignal", &x.groupStopSignal)
+ m.Save("groupStopPendingCount", &x.groupStopPendingCount)
+ m.Save("groupStopComplete", &x.groupStopComplete)
+ m.Save("groupStopWaitable", &x.groupStopWaitable)
+ m.Save("groupContNotify", &x.groupContNotify)
+ m.Save("groupContInterrupted", &x.groupContInterrupted)
+ m.Save("groupContWaitable", &x.groupContWaitable)
+ m.Save("exiting", &x.exiting)
+ m.Save("exitStatus", &x.exitStatus)
+ m.Save("terminationSignal", &x.terminationSignal)
+ m.Save("itimerRealTimer", &x.itimerRealTimer)
+ m.Save("itimerVirtSetting", &x.itimerVirtSetting)
+ m.Save("itimerProfSetting", &x.itimerProfSetting)
+ m.Save("rlimitCPUSoftSetting", &x.rlimitCPUSoftSetting)
+ m.Save("cpuTimersEnabled", &x.cpuTimersEnabled)
+ m.Save("timers", &x.timers)
+ m.Save("nextTimerID", &x.nextTimerID)
+ m.Save("exitedCPUStats", &x.exitedCPUStats)
+ m.Save("childCPUStats", &x.childCPUStats)
+ m.Save("ioUsage", &x.ioUsage)
+ m.Save("maxRSS", &x.maxRSS)
+ m.Save("childMaxRSS", &x.childMaxRSS)
+ m.Save("limits", &x.limits)
+ m.Save("processGroup", &x.processGroup)
+ m.Save("execed", &x.execed)
+}
+
+func (x *ThreadGroup) afterLoad() {}
+func (x *ThreadGroup) load(m state.Map) {
+ m.Load("threadGroupNode", &x.threadGroupNode)
+ m.Load("signalHandlers", &x.signalHandlers)
+ m.Load("pendingSignals", &x.pendingSignals)
+ m.Load("groupStopDequeued", &x.groupStopDequeued)
+ m.Load("groupStopSignal", &x.groupStopSignal)
+ m.Load("groupStopPendingCount", &x.groupStopPendingCount)
+ m.Load("groupStopComplete", &x.groupStopComplete)
+ m.Load("groupStopWaitable", &x.groupStopWaitable)
+ m.Load("groupContNotify", &x.groupContNotify)
+ m.Load("groupContInterrupted", &x.groupContInterrupted)
+ m.Load("groupContWaitable", &x.groupContWaitable)
+ m.Load("exiting", &x.exiting)
+ m.Load("exitStatus", &x.exitStatus)
+ m.Load("terminationSignal", &x.terminationSignal)
+ m.Load("itimerRealTimer", &x.itimerRealTimer)
+ m.Load("itimerVirtSetting", &x.itimerVirtSetting)
+ m.Load("itimerProfSetting", &x.itimerProfSetting)
+ m.Load("rlimitCPUSoftSetting", &x.rlimitCPUSoftSetting)
+ m.Load("cpuTimersEnabled", &x.cpuTimersEnabled)
+ m.Load("timers", &x.timers)
+ m.Load("nextTimerID", &x.nextTimerID)
+ m.Load("exitedCPUStats", &x.exitedCPUStats)
+ m.Load("childCPUStats", &x.childCPUStats)
+ m.Load("ioUsage", &x.ioUsage)
+ m.Load("maxRSS", &x.maxRSS)
+ m.Load("childMaxRSS", &x.childMaxRSS)
+ m.Load("limits", &x.limits)
+ m.Load("processGroup", &x.processGroup)
+ m.Load("execed", &x.execed)
+ m.LoadValue("rscr", new(*RSEQCriticalRegion), func(y interface{}) { x.loadRscr(y.(*RSEQCriticalRegion)) })
+}
+
+func (x *itimerRealListener) beforeSave() {}
+func (x *itimerRealListener) save(m state.Map) {
+ x.beforeSave()
+ m.Save("tg", &x.tg)
+}
+
+func (x *itimerRealListener) afterLoad() {}
+func (x *itimerRealListener) load(m state.Map) {
+ m.Load("tg", &x.tg)
+}
+
+func (x *TaskSet) beforeSave() {}
+func (x *TaskSet) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Root", &x.Root)
+ m.Save("sessions", &x.sessions)
+}
+
+func (x *TaskSet) afterLoad() {}
+func (x *TaskSet) load(m state.Map) {
+ m.Load("Root", &x.Root)
+ m.Load("sessions", &x.sessions)
+}
+
+func (x *PIDNamespace) beforeSave() {}
+func (x *PIDNamespace) save(m state.Map) {
+ x.beforeSave()
+ m.Save("owner", &x.owner)
+ m.Save("parent", &x.parent)
+ m.Save("userns", &x.userns)
+ m.Save("last", &x.last)
+ m.Save("tasks", &x.tasks)
+ m.Save("tids", &x.tids)
+ m.Save("tgids", &x.tgids)
+ m.Save("sessions", &x.sessions)
+ m.Save("sids", &x.sids)
+ m.Save("processGroups", &x.processGroups)
+ m.Save("pgids", &x.pgids)
+ m.Save("exiting", &x.exiting)
+}
+
+func (x *PIDNamespace) afterLoad() {}
+func (x *PIDNamespace) load(m state.Map) {
+ m.Load("owner", &x.owner)
+ m.Load("parent", &x.parent)
+ m.Load("userns", &x.userns)
+ m.Load("last", &x.last)
+ m.Load("tasks", &x.tasks)
+ m.Load("tids", &x.tids)
+ m.Load("tgids", &x.tgids)
+ m.Load("sessions", &x.sessions)
+ m.Load("sids", &x.sids)
+ m.Load("processGroups", &x.processGroups)
+ m.Load("pgids", &x.pgids)
+ m.Load("exiting", &x.exiting)
+}
+
+func (x *threadGroupNode) beforeSave() {}
+func (x *threadGroupNode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("pidns", &x.pidns)
+ m.Save("leader", &x.leader)
+ m.Save("execing", &x.execing)
+ m.Save("tasks", &x.tasks)
+ m.Save("tasksCount", &x.tasksCount)
+ m.Save("liveTasks", &x.liveTasks)
+ m.Save("activeTasks", &x.activeTasks)
+}
+
+func (x *threadGroupNode) afterLoad() {}
+func (x *threadGroupNode) load(m state.Map) {
+ m.Load("pidns", &x.pidns)
+ m.Load("leader", &x.leader)
+ m.Load("execing", &x.execing)
+ m.Load("tasks", &x.tasks)
+ m.Load("tasksCount", &x.tasksCount)
+ m.Load("liveTasks", &x.liveTasks)
+ m.Load("activeTasks", &x.activeTasks)
+}
+
+func (x *taskNode) beforeSave() {}
+func (x *taskNode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("tg", &x.tg)
+ m.Save("taskEntry", &x.taskEntry)
+ m.Save("parent", &x.parent)
+ m.Save("children", &x.children)
+ m.Save("childPIDNamespace", &x.childPIDNamespace)
+}
+
+func (x *taskNode) afterLoad() {}
+func (x *taskNode) load(m state.Map) {
+ m.LoadWait("tg", &x.tg)
+ m.Load("taskEntry", &x.taskEntry)
+ m.Load("parent", &x.parent)
+ m.Load("children", &x.children)
+ m.Load("childPIDNamespace", &x.childPIDNamespace)
+}
+
+func (x *Timekeeper) save(m state.Map) {
+ x.beforeSave()
+ m.Save("bootTime", &x.bootTime)
+ m.Save("saveMonotonic", &x.saveMonotonic)
+ m.Save("saveRealtime", &x.saveRealtime)
+ m.Save("params", &x.params)
+}
+
+func (x *Timekeeper) load(m state.Map) {
+ m.Load("bootTime", &x.bootTime)
+ m.Load("saveMonotonic", &x.saveMonotonic)
+ m.Load("saveRealtime", &x.saveRealtime)
+ m.Load("params", &x.params)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *timekeeperClock) beforeSave() {}
+func (x *timekeeperClock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("tk", &x.tk)
+ m.Save("c", &x.c)
+}
+
+func (x *timekeeperClock) afterLoad() {}
+func (x *timekeeperClock) load(m state.Map) {
+ m.Load("tk", &x.tk)
+ m.Load("c", &x.c)
+}
+
+func (x *UTSNamespace) beforeSave() {}
+func (x *UTSNamespace) save(m state.Map) {
+ x.beforeSave()
+ m.Save("hostName", &x.hostName)
+ m.Save("domainName", &x.domainName)
+ m.Save("userns", &x.userns)
+}
+
+func (x *UTSNamespace) afterLoad() {}
+func (x *UTSNamespace) load(m state.Map) {
+ m.Load("hostName", &x.hostName)
+ m.Load("domainName", &x.domainName)
+ m.Load("userns", &x.userns)
+}
+
+func (x *VDSOParamPage) beforeSave() {}
+func (x *VDSOParamPage) save(m state.Map) {
+ x.beforeSave()
+ m.Save("mfp", &x.mfp)
+ m.Save("fr", &x.fr)
+ m.Save("seq", &x.seq)
+}
+
+func (x *VDSOParamPage) afterLoad() {}
+func (x *VDSOParamPage) load(m state.Map) {
+ m.Load("mfp", &x.mfp)
+ m.Load("fr", &x.fr)
+ m.Load("seq", &x.seq)
+}
+
+func init() {
+ state.Register("kernel.abstractEndpoint", (*abstractEndpoint)(nil), state.Fns{Save: (*abstractEndpoint).save, Load: (*abstractEndpoint).load})
+ state.Register("kernel.AbstractSocketNamespace", (*AbstractSocketNamespace)(nil), state.Fns{Save: (*AbstractSocketNamespace).save, Load: (*AbstractSocketNamespace).load})
+ state.Register("kernel.FDFlags", (*FDFlags)(nil), state.Fns{Save: (*FDFlags).save, Load: (*FDFlags).load})
+ state.Register("kernel.descriptor", (*descriptor)(nil), state.Fns{Save: (*descriptor).save, Load: (*descriptor).load})
+ state.Register("kernel.FDMap", (*FDMap)(nil), state.Fns{Save: (*FDMap).save, Load: (*FDMap).load})
+ state.Register("kernel.FSContext", (*FSContext)(nil), state.Fns{Save: (*FSContext).save, Load: (*FSContext).load})
+ state.Register("kernel.IPCNamespace", (*IPCNamespace)(nil), state.Fns{Save: (*IPCNamespace).save, Load: (*IPCNamespace).load})
+ state.Register("kernel.Kernel", (*Kernel)(nil), state.Fns{Save: (*Kernel).save, Load: (*Kernel).load})
+ state.Register("kernel.socketEntry", (*socketEntry)(nil), state.Fns{Save: (*socketEntry).save, Load: (*socketEntry).load})
+ state.Register("kernel.pendingSignals", (*pendingSignals)(nil), state.Fns{Save: (*pendingSignals).save, Load: (*pendingSignals).load})
+ state.Register("kernel.pendingSignalQueue", (*pendingSignalQueue)(nil), state.Fns{Save: (*pendingSignalQueue).save, Load: (*pendingSignalQueue).load})
+ state.Register("kernel.pendingSignal", (*pendingSignal)(nil), state.Fns{Save: (*pendingSignal).save, Load: (*pendingSignal).load})
+ state.Register("kernel.pendingSignalList", (*pendingSignalList)(nil), state.Fns{Save: (*pendingSignalList).save, Load: (*pendingSignalList).load})
+ state.Register("kernel.pendingSignalEntry", (*pendingSignalEntry)(nil), state.Fns{Save: (*pendingSignalEntry).save, Load: (*pendingSignalEntry).load})
+ state.Register("kernel.savedPendingSignal", (*savedPendingSignal)(nil), state.Fns{Save: (*savedPendingSignal).save, Load: (*savedPendingSignal).load})
+ state.Register("kernel.IntervalTimer", (*IntervalTimer)(nil), state.Fns{Save: (*IntervalTimer).save, Load: (*IntervalTimer).load})
+ state.Register("kernel.processGroupList", (*processGroupList)(nil), state.Fns{Save: (*processGroupList).save, Load: (*processGroupList).load})
+ state.Register("kernel.processGroupEntry", (*processGroupEntry)(nil), state.Fns{Save: (*processGroupEntry).save, Load: (*processGroupEntry).load})
+ state.Register("kernel.ptraceOptions", (*ptraceOptions)(nil), state.Fns{Save: (*ptraceOptions).save, Load: (*ptraceOptions).load})
+ state.Register("kernel.ptraceStop", (*ptraceStop)(nil), state.Fns{Save: (*ptraceStop).save, Load: (*ptraceStop).load})
+ state.Register("kernel.RSEQCriticalRegion", (*RSEQCriticalRegion)(nil), state.Fns{Save: (*RSEQCriticalRegion).save, Load: (*RSEQCriticalRegion).load})
+ state.Register("kernel.sessionList", (*sessionList)(nil), state.Fns{Save: (*sessionList).save, Load: (*sessionList).load})
+ state.Register("kernel.sessionEntry", (*sessionEntry)(nil), state.Fns{Save: (*sessionEntry).save, Load: (*sessionEntry).load})
+ state.Register("kernel.Session", (*Session)(nil), state.Fns{Save: (*Session).save, Load: (*Session).load})
+ state.Register("kernel.ProcessGroup", (*ProcessGroup)(nil), state.Fns{Save: (*ProcessGroup).save, Load: (*ProcessGroup).load})
+ state.Register("kernel.SignalHandlers", (*SignalHandlers)(nil), state.Fns{Save: (*SignalHandlers).save, Load: (*SignalHandlers).load})
+ state.Register("kernel.SyscallTable", (*SyscallTable)(nil), state.Fns{Save: (*SyscallTable).save, Load: (*SyscallTable).load})
+ state.Register("kernel.syslog", (*syslog)(nil), state.Fns{Save: (*syslog).save, Load: (*syslog).load})
+ state.Register("kernel.Task", (*Task)(nil), state.Fns{Save: (*Task).save, Load: (*Task).load})
+ state.Register("kernel.runSyscallAfterPtraceEventClone", (*runSyscallAfterPtraceEventClone)(nil), state.Fns{Save: (*runSyscallAfterPtraceEventClone).save, Load: (*runSyscallAfterPtraceEventClone).load})
+ state.Register("kernel.runSyscallAfterVforkStop", (*runSyscallAfterVforkStop)(nil), state.Fns{Save: (*runSyscallAfterVforkStop).save, Load: (*runSyscallAfterVforkStop).load})
+ state.Register("kernel.vforkStop", (*vforkStop)(nil), state.Fns{Save: (*vforkStop).save, Load: (*vforkStop).load})
+ state.Register("kernel.TaskContext", (*TaskContext)(nil), state.Fns{Save: (*TaskContext).save, Load: (*TaskContext).load})
+ state.Register("kernel.execStop", (*execStop)(nil), state.Fns{Save: (*execStop).save, Load: (*execStop).load})
+ state.Register("kernel.runSyscallAfterExecStop", (*runSyscallAfterExecStop)(nil), state.Fns{Save: (*runSyscallAfterExecStop).save, Load: (*runSyscallAfterExecStop).load})
+ state.Register("kernel.ExitStatus", (*ExitStatus)(nil), state.Fns{Save: (*ExitStatus).save, Load: (*ExitStatus).load})
+ state.Register("kernel.runExit", (*runExit)(nil), state.Fns{Save: (*runExit).save, Load: (*runExit).load})
+ state.Register("kernel.runExitMain", (*runExitMain)(nil), state.Fns{Save: (*runExitMain).save, Load: (*runExitMain).load})
+ state.Register("kernel.runExitNotify", (*runExitNotify)(nil), state.Fns{Save: (*runExitNotify).save, Load: (*runExitNotify).load})
+ state.Register("kernel.taskList", (*taskList)(nil), state.Fns{Save: (*taskList).save, Load: (*taskList).load})
+ state.Register("kernel.taskEntry", (*taskEntry)(nil), state.Fns{Save: (*taskEntry).save, Load: (*taskEntry).load})
+ state.Register("kernel.runApp", (*runApp)(nil), state.Fns{Save: (*runApp).save, Load: (*runApp).load})
+ state.Register("kernel.TaskGoroutineSchedInfo", (*TaskGoroutineSchedInfo)(nil), state.Fns{Save: (*TaskGoroutineSchedInfo).save, Load: (*TaskGoroutineSchedInfo).load})
+ state.Register("kernel.taskClock", (*taskClock)(nil), state.Fns{Save: (*taskClock).save, Load: (*taskClock).load})
+ state.Register("kernel.tgClock", (*tgClock)(nil), state.Fns{Save: (*tgClock).save, Load: (*tgClock).load})
+ state.Register("kernel.groupStop", (*groupStop)(nil), state.Fns{Save: (*groupStop).save, Load: (*groupStop).load})
+ state.Register("kernel.runInterrupt", (*runInterrupt)(nil), state.Fns{Save: (*runInterrupt).save, Load: (*runInterrupt).load})
+ state.Register("kernel.runInterruptAfterSignalDeliveryStop", (*runInterruptAfterSignalDeliveryStop)(nil), state.Fns{Save: (*runInterruptAfterSignalDeliveryStop).save, Load: (*runInterruptAfterSignalDeliveryStop).load})
+ state.Register("kernel.runSyscallAfterSyscallEnterStop", (*runSyscallAfterSyscallEnterStop)(nil), state.Fns{Save: (*runSyscallAfterSyscallEnterStop).save, Load: (*runSyscallAfterSyscallEnterStop).load})
+ state.Register("kernel.runSyscallAfterSysemuStop", (*runSyscallAfterSysemuStop)(nil), state.Fns{Save: (*runSyscallAfterSysemuStop).save, Load: (*runSyscallAfterSysemuStop).load})
+ state.Register("kernel.runSyscallReinvoke", (*runSyscallReinvoke)(nil), state.Fns{Save: (*runSyscallReinvoke).save, Load: (*runSyscallReinvoke).load})
+ state.Register("kernel.runSyscallExit", (*runSyscallExit)(nil), state.Fns{Save: (*runSyscallExit).save, Load: (*runSyscallExit).load})
+ state.Register("kernel.ThreadGroup", (*ThreadGroup)(nil), state.Fns{Save: (*ThreadGroup).save, Load: (*ThreadGroup).load})
+ state.Register("kernel.itimerRealListener", (*itimerRealListener)(nil), state.Fns{Save: (*itimerRealListener).save, Load: (*itimerRealListener).load})
+ state.Register("kernel.TaskSet", (*TaskSet)(nil), state.Fns{Save: (*TaskSet).save, Load: (*TaskSet).load})
+ state.Register("kernel.PIDNamespace", (*PIDNamespace)(nil), state.Fns{Save: (*PIDNamespace).save, Load: (*PIDNamespace).load})
+ state.Register("kernel.threadGroupNode", (*threadGroupNode)(nil), state.Fns{Save: (*threadGroupNode).save, Load: (*threadGroupNode).load})
+ state.Register("kernel.taskNode", (*taskNode)(nil), state.Fns{Save: (*taskNode).save, Load: (*taskNode).load})
+ state.Register("kernel.Timekeeper", (*Timekeeper)(nil), state.Fns{Save: (*Timekeeper).save, Load: (*Timekeeper).load})
+ state.Register("kernel.timekeeperClock", (*timekeeperClock)(nil), state.Fns{Save: (*timekeeperClock).save, Load: (*timekeeperClock).load})
+ state.Register("kernel.UTSNamespace", (*UTSNamespace)(nil), state.Fns{Save: (*UTSNamespace).save, Load: (*UTSNamespace).load})
+ state.Register("kernel.VDSOParamPage", (*VDSOParamPage)(nil), state.Fns{Save: (*VDSOParamPage).save, Load: (*VDSOParamPage).load})
+}
diff --git a/pkg/sentry/kernel/pending_signals.go b/pkg/sentry/kernel/pending_signals.go
new file mode 100644
index 000000000..c93f6598a
--- /dev/null
+++ b/pkg/sentry/kernel/pending_signals.go
@@ -0,0 +1,142 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/bits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+)
+
+const (
+ // stdSignalCap is the maximum number of instances of a given standard
+ // signal that may be pending. ("[If] multiple instances of a standard
+ // signal are delivered while that signal is currently blocked, then only
+ // one instance is queued.") - signal(7)
+ stdSignalCap = 1
+
+ // rtSignalCap is the maximum number of instances of a given realtime
+ // signal that may be pending.
+ //
+ // TODO(igudger): In Linux, the minimum signal queue size is
+ // RLIMIT_SIGPENDING, which is by default max_threads/2.
+ rtSignalCap = 32
+)
+
+// pendingSignals holds a collection of pending signals. The zero value of
+// pendingSignals is a valid empty collection. pendingSignals is thread-unsafe;
+// users must provide synchronization.
+//
+// +stateify savable
+type pendingSignals struct {
+ // signals contains all pending signals.
+ //
+ // Note that signals is zero-indexed, but signal 1 is the first valid
+ // signal, so signals[0] contains signals with signo 1 etc. This offset is
+ // usually handled by using Signal.index().
+ signals [linux.SignalMaximum]pendingSignalQueue `state:".([]savedPendingSignal)"`
+
+ // Bit i of pendingSet is set iff there is at least one signal with signo
+ // i+1 pending.
+ pendingSet linux.SignalSet `state:"manual"`
+}
+
+// pendingSignalQueue holds a pendingSignalList for a single signal number.
+//
+// +stateify savable
+type pendingSignalQueue struct {
+ pendingSignalList
+ length int
+}
+
+// +stateify savable
+type pendingSignal struct {
+ // pendingSignalEntry links into a pendingSignalList.
+ pendingSignalEntry
+ *arch.SignalInfo
+
+ // If timer is not nil, it is the IntervalTimer which sent this signal.
+ timer *IntervalTimer
+}
+
+// enqueue enqueues the given signal. enqueue returns true on success and false
+// on failure (if the given signal's queue is full).
+//
+// Preconditions: info represents a valid signal.
+func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bool {
+ sig := linux.Signal(info.Signo)
+ q := &p.signals[sig.Index()]
+ if sig.IsStandard() {
+ if q.length >= stdSignalCap {
+ return false
+ }
+ } else if q.length >= rtSignalCap {
+ return false
+ }
+ q.pendingSignalList.PushBack(&pendingSignal{SignalInfo: info, timer: timer})
+ q.length++
+ p.pendingSet |= linux.SignalSetOf(sig)
+ return true
+}
+
+// dequeue dequeues and returns any pending signal not masked by mask. If no
+// unmasked signals are pending, dequeue returns nil.
+func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo {
+ // "Real-time signals are delivered in a guaranteed order. Multiple
+ // real-time signals of the same type are delivered in the order they were
+ // sent. If different real-time signals are sent to a process, they are
+ // delivered starting with the lowest-numbered signal. (I.e., low-numbered
+ // signals have highest priority.) By contrast, if multiple standard
+ // signals are pending for a process, the order in which they are delivered
+ // is unspecified. If both standard and real-time signals are pending for a
+ // process, POSIX leaves it unspecified which is delivered first. Linux,
+ // like many other implementations, gives priority to standard signals in
+ // this case." - signal(7)
+ lowestPendingUnblockedBit := bits.TrailingZeros64(uint64(p.pendingSet &^ mask))
+ if lowestPendingUnblockedBit >= linux.SignalMaximum {
+ return nil
+ }
+ return p.dequeueSpecific(linux.Signal(lowestPendingUnblockedBit + 1))
+}
+
+func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *arch.SignalInfo {
+ q := &p.signals[sig.Index()]
+ ps := q.pendingSignalList.Front()
+ if ps == nil {
+ return nil
+ }
+ q.pendingSignalList.Remove(ps)
+ q.length--
+ if q.length == 0 {
+ p.pendingSet &^= linux.SignalSetOf(sig)
+ }
+ if ps.timer != nil {
+ ps.timer.updateDequeuedSignalLocked(ps.SignalInfo)
+ }
+ return ps.SignalInfo
+}
+
+// discardSpecific causes all pending signals with number sig to be discarded.
+func (p *pendingSignals) discardSpecific(sig linux.Signal) {
+ q := &p.signals[sig.Index()]
+ for ps := q.pendingSignalList.Front(); ps != nil; ps = ps.Next() {
+ if ps.timer != nil {
+ ps.timer.signalRejectedLocked()
+ }
+ }
+ q.pendingSignalList.Reset()
+ q.length = 0
+ p.pendingSet &^= linux.SignalSetOf(sig)
+}
diff --git a/pkg/sentry/kernel/pending_signals_list.go b/pkg/sentry/kernel/pending_signals_list.go
new file mode 100755
index 000000000..a3499371a
--- /dev/null
+++ b/pkg/sentry/kernel/pending_signals_list.go
@@ -0,0 +1,173 @@
+package kernel
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type pendingSignalElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (pendingSignalElementMapper) linkerFor(elem *pendingSignal) *pendingSignal { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type pendingSignalList struct {
+ head *pendingSignal
+ tail *pendingSignal
+}
+
+// Reset resets list l to the empty state.
+func (l *pendingSignalList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *pendingSignalList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *pendingSignalList) Front() *pendingSignal {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *pendingSignalList) Back() *pendingSignal {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *pendingSignalList) PushFront(e *pendingSignal) {
+ pendingSignalElementMapper{}.linkerFor(e).SetNext(l.head)
+ pendingSignalElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ pendingSignalElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *pendingSignalList) PushBack(e *pendingSignal) {
+ pendingSignalElementMapper{}.linkerFor(e).SetNext(nil)
+ pendingSignalElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ pendingSignalElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *pendingSignalList) PushBackList(m *pendingSignalList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ pendingSignalElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ pendingSignalElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *pendingSignalList) InsertAfter(b, e *pendingSignal) {
+ a := pendingSignalElementMapper{}.linkerFor(b).Next()
+ pendingSignalElementMapper{}.linkerFor(e).SetNext(a)
+ pendingSignalElementMapper{}.linkerFor(e).SetPrev(b)
+ pendingSignalElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ pendingSignalElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *pendingSignalList) InsertBefore(a, e *pendingSignal) {
+ b := pendingSignalElementMapper{}.linkerFor(a).Prev()
+ pendingSignalElementMapper{}.linkerFor(e).SetNext(a)
+ pendingSignalElementMapper{}.linkerFor(e).SetPrev(b)
+ pendingSignalElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ pendingSignalElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *pendingSignalList) Remove(e *pendingSignal) {
+ prev := pendingSignalElementMapper{}.linkerFor(e).Prev()
+ next := pendingSignalElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ pendingSignalElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ pendingSignalElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type pendingSignalEntry struct {
+ next *pendingSignal
+ prev *pendingSignal
+}
+
+// Next returns the entry that follows e in the list.
+func (e *pendingSignalEntry) Next() *pendingSignal {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *pendingSignalEntry) Prev() *pendingSignal {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *pendingSignalEntry) SetNext(elem *pendingSignal) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *pendingSignalEntry) SetPrev(elem *pendingSignal) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/kernel/pending_signals_state.go b/pkg/sentry/kernel/pending_signals_state.go
new file mode 100644
index 000000000..2c902c7e3
--- /dev/null
+++ b/pkg/sentry/kernel/pending_signals_state.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+)
+
+// +stateify savable
+type savedPendingSignal struct {
+ si *arch.SignalInfo
+ timer *IntervalTimer
+}
+
+// saveSignals is invoked by stateify.
+func (p *pendingSignals) saveSignals() []savedPendingSignal {
+ var pending []savedPendingSignal
+ for _, q := range p.signals {
+ for ps := q.pendingSignalList.Front(); ps != nil; ps = ps.Next() {
+ pending = append(pending, savedPendingSignal{
+ si: ps.SignalInfo,
+ timer: ps.timer,
+ })
+ }
+ }
+ return pending
+}
+
+// loadSignals is invoked by stateify.
+func (p *pendingSignals) loadSignals(pending []savedPendingSignal) {
+ for _, sps := range pending {
+ p.enqueue(sps.si, sps.timer)
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
new file mode 100644
index 000000000..4360dc44f
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/buffer.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+)
+
+// buffer encapsulates a queueable byte buffer.
+//
+// Note that the total size is slightly less than two pages. This
+// is done intentionally to ensure that the buffer object aligns
+// with runtime internals. We have no hard size or alignment
+// requirements. This two page size will effectively minimize
+// internal fragmentation, but still have a large enough chunk
+// to limit excessive segmentation.
+//
+// +stateify savable
+type buffer struct {
+ data [8144]byte
+ read int
+ write int
+ bufferEntry
+}
+
+// Reset resets internal data.
+//
+// This must be called before use.
+func (b *buffer) Reset() {
+ b.read = 0
+ b.write = 0
+}
+
+// Empty indicates the buffer is empty.
+//
+// This indicates there is no data left to read.
+func (b *buffer) Empty() bool {
+ return b.read == b.write
+}
+
+// Full indicates the buffer is full.
+//
+// This indicates there is no capacity left to write.
+func (b *buffer) Full() bool {
+ return b.write == len(b.data)
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.write:]))
+ n, err := safemem.CopySeq(dst, srcs)
+ b.write += int(n)
+ return n, err
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write]))
+ n, err := safemem.CopySeq(dsts, src)
+ b.read += int(n)
+ return n, err
+}
+
+// bufferPool is a pool for buffers.
+var bufferPool = sync.Pool{
+ New: func() interface{} {
+ return new(buffer)
+ },
+}
+
+// newBuffer grabs a new buffer from the pool.
+func newBuffer() *buffer {
+ b := bufferPool.Get().(*buffer)
+ b.Reset()
+ return b
+}
diff --git a/pkg/sentry/kernel/pipe/buffer_list.go b/pkg/sentry/kernel/pipe/buffer_list.go
new file mode 100755
index 000000000..42ec78788
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/buffer_list.go
@@ -0,0 +1,173 @@
+package pipe
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type bufferElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (bufferElementMapper) linkerFor(elem *buffer) *buffer { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type bufferList struct {
+ head *buffer
+ tail *buffer
+}
+
+// Reset resets list l to the empty state.
+func (l *bufferList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *bufferList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *bufferList) Front() *buffer {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *bufferList) Back() *buffer {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *bufferList) PushFront(e *buffer) {
+ bufferElementMapper{}.linkerFor(e).SetNext(l.head)
+ bufferElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ bufferElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *bufferList) PushBack(e *buffer) {
+ bufferElementMapper{}.linkerFor(e).SetNext(nil)
+ bufferElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ bufferElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *bufferList) PushBackList(m *bufferList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ bufferElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ bufferElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *bufferList) InsertAfter(b, e *buffer) {
+ a := bufferElementMapper{}.linkerFor(b).Next()
+ bufferElementMapper{}.linkerFor(e).SetNext(a)
+ bufferElementMapper{}.linkerFor(e).SetPrev(b)
+ bufferElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ bufferElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *bufferList) InsertBefore(a, e *buffer) {
+ b := bufferElementMapper{}.linkerFor(a).Prev()
+ bufferElementMapper{}.linkerFor(e).SetNext(a)
+ bufferElementMapper{}.linkerFor(e).SetPrev(b)
+ bufferElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ bufferElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *bufferList) Remove(e *buffer) {
+ prev := bufferElementMapper{}.linkerFor(e).Prev()
+ next := bufferElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ bufferElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ bufferElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type bufferEntry struct {
+ next *buffer
+ prev *buffer
+}
+
+// Next returns the entry that follows e in the list.
+func (e *bufferEntry) Next() *buffer {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *bufferEntry) Prev() *buffer {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *bufferEntry) SetNext(elem *buffer) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *bufferEntry) SetPrev(elem *buffer) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/kernel/pipe/device.go b/pkg/sentry/kernel/pipe/device.go
new file mode 100644
index 000000000..eb59e15a1
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// pipeDevice is used for all pipe files.
+var pipeDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
new file mode 100644
index 000000000..926c4c623
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -0,0 +1,196 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/amutex"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// inodeOperations implements fs.InodeOperations for pipes.
+//
+// +stateify savable
+type inodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeNotVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // p is the underlying Pipe object representing this fifo.
+ p *Pipe
+
+ // Channels for synchronizing the creation of new readers and writers of
+ // this fifo. See waitFor and newHandleLocked.
+ //
+ // These are not saved/restored because all waiters are unblocked on save,
+ // and either automatically restart (via ERESTARTSYS) or return EINTR on
+ // resume. On restarts via ERESTARTSYS, the appropriate channel will be
+ // recreated.
+ rWakeup chan struct{} `state:"nosave"`
+ wWakeup chan struct{} `state:"nosave"`
+}
+
+var _ fs.InodeOperations = (*inodeOperations)(nil)
+
+// NewInodeOperations returns a new fs.InodeOperations for a given pipe.
+func NewInodeOperations(ctx context.Context, perms fs.FilePermissions, p *Pipe) *inodeOperations {
+ return &inodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.FileOwnerFromContext(ctx), perms, linux.PIPEFS_MAGIC),
+ p: p,
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile. Named pipes have special blocking
+// semantics during open:
+//
+// "Normally, opening the FIFO blocks until the other end is opened also. A
+// process can open a FIFO in nonblocking mode. In this case, opening for
+// read-only will succeed even if no-one has opened on the write side yet,
+// opening for write-only will fail with ENXIO (no such device or address)
+// unless the other end has already been opened. Under Linux, opening a FIFO
+// for read and write will succeed both in blocking and nonblocking mode. POSIX
+// leaves this behavior undefined. This can be used to open a FIFO for writing
+// while there are no readers available." - fifo(7)
+func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ switch {
+ case flags.Read && !flags.Write: // O_RDONLY.
+ r := i.p.Open(ctx, flags)
+ i.newHandleLocked(&i.rWakeup)
+
+ if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() {
+ if !i.waitFor(&i.wWakeup, ctx) {
+ r.DecRef()
+ return nil, syserror.ErrInterrupted
+ }
+ }
+
+ // By now, either we're doing a nonblocking open or we have a writer. On
+ // a nonblocking read-only open, the open succeeds even if no-one has
+ // opened the write side yet.
+ return r, nil
+
+ case flags.Write && !flags.Read: // O_WRONLY.
+ w := i.p.Open(ctx, flags)
+ i.newHandleLocked(&i.wWakeup)
+
+ if i.p.isNamed && !i.p.HasReaders() {
+ // On a nonblocking, write-only open, the open fails with ENXIO if the
+ // read side isn't open yet.
+ if flags.NonBlocking {
+ w.DecRef()
+ return nil, syserror.ENXIO
+ }
+
+ if !i.waitFor(&i.rWakeup, ctx) {
+ w.DecRef()
+ return nil, syserror.ErrInterrupted
+ }
+ }
+ return w, nil
+
+ case flags.Read && flags.Write: // O_RDWR.
+ // Pipes opened for read-write always succeeds without blocking.
+ rw := i.p.Open(ctx, flags)
+ i.newHandleLocked(&i.rWakeup)
+ i.newHandleLocked(&i.wWakeup)
+ return rw, nil
+
+ default:
+ return nil, syserror.EINVAL
+ }
+}
+
+// waitFor blocks until the underlying pipe has at least one reader/writer is
+// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this
+// function will block for either readers or writers, depending on where
+// 'wakeupChan' points.
+//
+// f.mu must be held by the caller. waitFor returns with f.mu held, but it will
+// drop f.mu before blocking for any reader/writers.
+func (i *inodeOperations) waitFor(wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
+ // Ideally this function would simply use a condition variable. However, the
+ // wait needs to be interruptible via 'sleeper', so we must sychronize via a
+ // channel. The synchronization below relies on the fact that closing a
+ // channel unblocks all receives on the channel.
+
+ // Does an appropriate wakeup channel already exist? If not, create a new
+ // one. This is all done under f.mu to avoid races.
+ if *wakeupChan == nil {
+ *wakeupChan = make(chan struct{})
+ }
+
+ // Grab a local reference to the wakeup channel since it may disappear as
+ // soon as we drop f.mu.
+ wakeup := *wakeupChan
+
+ // Drop the lock and prepare to sleep.
+ i.mu.Unlock()
+ cancel := sleeper.SleepStart()
+
+ // Wait for either a new reader/write to be signalled via 'wakeup', or
+ // for the sleep to be cancelled.
+ select {
+ case <-wakeup:
+ sleeper.SleepFinish(true)
+ case <-cancel:
+ sleeper.SleepFinish(false)
+ }
+
+ // Take the lock and check if we were woken. If we were woken and
+ // interrupted, the former takes priority.
+ i.mu.Lock()
+ select {
+ case <-wakeup:
+ return true
+ default:
+ return false
+ }
+}
+
+// newHandleLocked signals a new pipe reader or writer depending on where
+// 'wakeupChan' points. This unblocks any corresponding reader or writer
+// waiting for the other end of the channel to be opened, see Fifo.waitFor.
+//
+// i.mu must be held.
+func (*inodeOperations) newHandleLocked(wakeupChan *chan struct{}) {
+ if *wakeupChan != nil {
+ close(*wakeupChan)
+ *wakeupChan = nil
+ }
+}
+
+func (*inodeOperations) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
+ return syserror.EPIPE
+}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
new file mode 100644
index 000000000..b65204492
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -0,0 +1,429 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipe provides a pipe implementation.
+package pipe
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // MinimumPipeSize is a hard limit of the minimum size of a pipe.
+ MinimumPipeSize = 64 << 10
+
+ // DefaultPipeSize is the system-wide default size of a pipe in bytes.
+ DefaultPipeSize = MinimumPipeSize
+
+ // MaximumPipeSize is a hard limit on the maximum size of a pipe.
+ MaximumPipeSize = 8 << 20
+)
+
+// Sizer is an interface for setting and getting the size of a pipe.
+//
+// It is implemented by Pipe and, through embedding, all other types.
+type Sizer interface {
+ // PipeSize returns the pipe capacity in bytes.
+ PipeSize() int64
+
+ // SetPipeSize sets the new pipe capacity in bytes.
+ //
+ // The new size is returned (which may be capped).
+ SetPipeSize(int64) (int64, error)
+}
+
+// Pipe is an encapsulation of a platform-independent pipe.
+// It manages a buffered byte queue shared between a reader/writer
+// pair.
+//
+// +stateify savable
+type Pipe struct {
+ waiter.Queue `state:"nosave"`
+
+ // isNamed indicates whether this is a named pipe.
+ //
+ // This value is immutable.
+ isNamed bool
+
+ // atomicIOBytes is the maximum number of bytes that the pipe will
+ // guarantee atomic reads or writes atomically.
+ //
+ // This value is immutable.
+ atomicIOBytes int64
+
+ // The dirent backing this pipe. Shared by all readers and writers.
+ //
+ // This value is immutable.
+ Dirent *fs.Dirent
+
+ // The number of active readers for this pipe.
+ //
+ // Access atomically.
+ readers int32
+
+ // The number of active writes for this pipe.
+ //
+ // Access atomically.
+ writers int32
+
+ // mu protects all pipe internal state below.
+ mu sync.Mutex `state:"nosave"`
+
+ // data is the buffer queue of pipe contents.
+ //
+ // This is protected by mu.
+ data bufferList
+
+ // max is the maximum size of the pipe in bytes. When this max has been
+ // reached, writers will get EWOULDBLOCK.
+ //
+ // This is protected by mu.
+ max int64
+
+ // size is the current size of the pipe in bytes.
+ //
+ // This is protected by mu.
+ size int64
+
+ // hadWriter indicates if this pipe ever had a writer. Note that this
+ // does not necessarily indicate there is *currently* a writer, just
+ // that there has been a writer at some point since the pipe was
+ // created.
+ //
+ // This is protected by mu.
+ hadWriter bool
+}
+
+// NewPipe initializes and returns a pipe.
+//
+// N.B. The size and atomicIOBytes will be bounded.
+func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
+ if sizeBytes < MinimumPipeSize {
+ sizeBytes = MinimumPipeSize
+ }
+ if sizeBytes > MaximumPipeSize {
+ sizeBytes = MaximumPipeSize
+ }
+ if atomicIOBytes <= 0 {
+ atomicIOBytes = 1
+ }
+ if atomicIOBytes > sizeBytes {
+ atomicIOBytes = sizeBytes
+ }
+ p := &Pipe{
+ isNamed: isNamed,
+ max: sizeBytes,
+ atomicIOBytes: atomicIOBytes,
+ }
+
+ // Build the fs.Dirent of this pipe, shared by all fs.Files associated
+ // with this pipe.
+ perms := fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+ }
+ iops := NewInodeOperations(ctx, perms, p)
+ ino := pipeDevice.NextIno()
+ sattr := fs.StableAttr{
+ Type: fs.Pipe,
+ DeviceID: pipeDevice.DeviceID(),
+ InodeID: ino,
+ BlockSize: int64(atomicIOBytes),
+ }
+ ms := fs.NewPseudoMountSource()
+ p.Dirent = fs.NewDirent(fs.NewInode(iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino))
+ return p
+}
+
+// NewConnectedPipe initializes a pipe and returns a pair of objects
+// representing the read and write ends of the pipe.
+func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) {
+ p := NewPipe(ctx, false /* isNamed */, sizeBytes, atomicIOBytes)
+ return p.Open(ctx, fs.FileFlags{Read: true}), p.Open(ctx, fs.FileFlags{Write: true})
+}
+
+// Open opens the pipe and returns a new file.
+//
+// Precondition: at least one of flags.Read or flags.Write must be set.
+func (p *Pipe) Open(ctx context.Context, flags fs.FileFlags) *fs.File {
+ switch {
+ case flags.Read && flags.Write:
+ p.rOpen()
+ p.wOpen()
+ return fs.NewFile(ctx, p.Dirent, flags, &ReaderWriter{
+ Pipe: p,
+ })
+ case flags.Read:
+ p.rOpen()
+ return fs.NewFile(ctx, p.Dirent, flags, &Reader{
+ ReaderWriter: ReaderWriter{Pipe: p},
+ })
+ case flags.Write:
+ p.wOpen()
+ return fs.NewFile(ctx, p.Dirent, flags, &Writer{
+ ReaderWriter: ReaderWriter{Pipe: p},
+ })
+ default:
+ // Precondition violated.
+ panic("invalid pipe flags")
+ }
+}
+
+// read reads data from the pipe into dst and returns the number of bytes
+// read, or returns ErrWouldBlock if the pipe is empty.
+//
+// Precondition: this pipe must have readers.
+func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ // Don't block for a zero-length read even if the pipe is empty.
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Is the pipe empty?
+ if p.size == 0 {
+ if !p.HasWriters() {
+ // There are no writers, return EOF.
+ return 0, nil
+ }
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Limit how much we consume.
+ if dst.NumBytes() > p.size {
+ dst = dst.TakeFirst64(p.size)
+ }
+
+ done := int64(0)
+ for dst.NumBytes() > 0 {
+ // Pop the first buffer.
+ first := p.data.Front()
+ if first == nil {
+ break
+ }
+
+ // Copy user data.
+ n, err := dst.CopyOutFrom(ctx, first)
+ done += int64(n)
+ p.size -= n
+ dst = dst.DropFirst64(n)
+
+ // Empty buffer?
+ if first.Empty() {
+ // Push to the free list.
+ p.data.Remove(first)
+ bufferPool.Put(first)
+ }
+
+ // Handle errors.
+ if err != nil {
+ return done, err
+ }
+ }
+
+ return done, nil
+}
+
+// write writes data from sv into the pipe and returns the number of bytes
+// written. If no bytes are written because the pipe is full (or has less than
+// atomicIOBytes free capacity), write returns ErrWouldBlock.
+//
+// Precondition: this pipe must have writers.
+func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Can't write to a pipe with no readers.
+ if !p.HasReaders() {
+ return 0, syscall.EPIPE
+ }
+
+ // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
+ // atomic, but requires no atomicity for writes larger than this.
+ wanted := src.NumBytes()
+ if avail := p.max - p.size; wanted > avail {
+ if wanted <= p.atomicIOBytes {
+ return 0, syserror.ErrWouldBlock
+ }
+ // Limit to the available capacity.
+ src = src.TakeFirst64(avail)
+ }
+
+ done := int64(0)
+ for src.NumBytes() > 0 {
+ // Need a new buffer?
+ last := p.data.Back()
+ if last == nil || last.Full() {
+ // Add a new buffer to the data list.
+ last = newBuffer()
+ p.data.PushBack(last)
+ }
+
+ // Copy user data.
+ n, err := src.CopyInTo(ctx, last)
+ done += int64(n)
+ p.size += n
+ src = src.DropFirst64(n)
+
+ // Handle errors.
+ if err != nil {
+ return done, err
+ }
+ }
+ if wanted > done {
+ // Partial write due to full pipe.
+ return done, syserror.ErrWouldBlock
+ }
+
+ return done, nil
+}
+
+// rOpen signals a new reader of the pipe.
+func (p *Pipe) rOpen() {
+ atomic.AddInt32(&p.readers, 1)
+}
+
+// wOpen signals a new writer of the pipe.
+func (p *Pipe) wOpen() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.hadWriter = true
+ atomic.AddInt32(&p.writers, 1)
+}
+
+// rClose signals that a reader has closed their end of the pipe.
+func (p *Pipe) rClose() {
+ newReaders := atomic.AddInt32(&p.readers, -1)
+ if newReaders < 0 {
+ panic(fmt.Sprintf("Refcounting bug, pipe has negative readers: %v", newReaders))
+ }
+}
+
+// wClose signals that a writer has closed their end of the pipe.
+func (p *Pipe) wClose() {
+ newWriters := atomic.AddInt32(&p.writers, -1)
+ if newWriters < 0 {
+ panic(fmt.Sprintf("Refcounting bug, pipe has negative writers: %v.", newWriters))
+ }
+}
+
+// HasReaders returns whether the pipe has any active readers.
+func (p *Pipe) HasReaders() bool {
+ return atomic.LoadInt32(&p.readers) > 0
+}
+
+// HasWriters returns whether the pipe has any active writers.
+func (p *Pipe) HasWriters() bool {
+ return atomic.LoadInt32(&p.writers) > 0
+}
+
+// rReadinessLocked calculates the read readiness.
+//
+// Precondition: mu must be held.
+func (p *Pipe) rReadinessLocked() waiter.EventMask {
+ ready := waiter.EventMask(0)
+ if p.HasReaders() && p.data.Front() != nil {
+ ready |= waiter.EventIn
+ }
+ if !p.HasWriters() && p.hadWriter {
+ // POLLHUP must be suppressed until the pipe has had at least one writer
+ // at some point. Otherwise a reader thread may poll and immediately get
+ // a POLLHUP before the writer ever opens the pipe, which the reader may
+ // interpret as the writer opening then closing the pipe.
+ ready |= waiter.EventHUp
+ }
+ return ready
+}
+
+// rReadiness returns a mask that states whether the read end of the pipe is
+// ready for reading.
+func (p *Pipe) rReadiness() waiter.EventMask {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.rReadinessLocked()
+}
+
+// wReadinessLocked calculates the write readiness.
+//
+// Precondition: mu must be held.
+func (p *Pipe) wReadinessLocked() waiter.EventMask {
+ ready := waiter.EventMask(0)
+ if p.HasWriters() && p.size < p.max {
+ ready |= waiter.EventOut
+ }
+ if !p.HasReaders() {
+ ready |= waiter.EventErr
+ }
+ return ready
+}
+
+// wReadiness returns a mask that states whether the write end of the pipe
+// is ready for writing.
+func (p *Pipe) wReadiness() waiter.EventMask {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.wReadinessLocked()
+}
+
+// rwReadiness returns a mask that states whether a read-write handle to the
+// pipe is ready for IO.
+func (p *Pipe) rwReadiness() waiter.EventMask {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.rReadinessLocked() | p.wReadinessLocked()
+}
+
+// queued returns the amount of queued data.
+func (p *Pipe) queued() int64 {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.size
+}
+
+// PipeSize implements PipeSizer.PipeSize.
+func (p *Pipe) PipeSize() int64 {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.max
+}
+
+// SetPipeSize implements PipeSize.SetPipeSize.
+func (p *Pipe) SetPipeSize(size int64) (int64, error) {
+ if size < 0 {
+ return 0, syserror.EINVAL
+ }
+ if size < MinimumPipeSize {
+ size = MinimumPipeSize // Per spec.
+ }
+ if size > MaximumPipeSize {
+ return 0, syserror.EPERM
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if size < p.size {
+ return 0, syserror.EBUSY
+ }
+ p.max = size
+ return size, nil
+}
diff --git a/pkg/sentry/kernel/pipe/pipe_state_autogen.go b/pkg/sentry/kernel/pipe/pipe_state_autogen.go
new file mode 100755
index 000000000..5d3686109
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_state_autogen.go
@@ -0,0 +1,134 @@
+// automatically generated by stateify.
+
+package pipe
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *buffer) beforeSave() {}
+func (x *buffer) save(m state.Map) {
+ x.beforeSave()
+ m.Save("data", &x.data)
+ m.Save("read", &x.read)
+ m.Save("write", &x.write)
+ m.Save("bufferEntry", &x.bufferEntry)
+}
+
+func (x *buffer) afterLoad() {}
+func (x *buffer) load(m state.Map) {
+ m.Load("data", &x.data)
+ m.Load("read", &x.read)
+ m.Load("write", &x.write)
+ m.Load("bufferEntry", &x.bufferEntry)
+}
+
+func (x *bufferList) beforeSave() {}
+func (x *bufferList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *bufferList) afterLoad() {}
+func (x *bufferList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *bufferEntry) beforeSave() {}
+func (x *bufferEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *bufferEntry) afterLoad() {}
+func (x *bufferEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *inodeOperations) beforeSave() {}
+func (x *inodeOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Save("p", &x.p)
+}
+
+func (x *inodeOperations) afterLoad() {}
+func (x *inodeOperations) load(m state.Map) {
+ m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes)
+ m.Load("p", &x.p)
+}
+
+func (x *Pipe) beforeSave() {}
+func (x *Pipe) save(m state.Map) {
+ x.beforeSave()
+ m.Save("isNamed", &x.isNamed)
+ m.Save("atomicIOBytes", &x.atomicIOBytes)
+ m.Save("Dirent", &x.Dirent)
+ m.Save("readers", &x.readers)
+ m.Save("writers", &x.writers)
+ m.Save("data", &x.data)
+ m.Save("max", &x.max)
+ m.Save("size", &x.size)
+ m.Save("hadWriter", &x.hadWriter)
+}
+
+func (x *Pipe) afterLoad() {}
+func (x *Pipe) load(m state.Map) {
+ m.Load("isNamed", &x.isNamed)
+ m.Load("atomicIOBytes", &x.atomicIOBytes)
+ m.Load("Dirent", &x.Dirent)
+ m.Load("readers", &x.readers)
+ m.Load("writers", &x.writers)
+ m.Load("data", &x.data)
+ m.Load("max", &x.max)
+ m.Load("size", &x.size)
+ m.Load("hadWriter", &x.hadWriter)
+}
+
+func (x *Reader) beforeSave() {}
+func (x *Reader) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ReaderWriter", &x.ReaderWriter)
+}
+
+func (x *Reader) afterLoad() {}
+func (x *Reader) load(m state.Map) {
+ m.Load("ReaderWriter", &x.ReaderWriter)
+}
+
+func (x *ReaderWriter) beforeSave() {}
+func (x *ReaderWriter) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Pipe", &x.Pipe)
+}
+
+func (x *ReaderWriter) afterLoad() {}
+func (x *ReaderWriter) load(m state.Map) {
+ m.Load("Pipe", &x.Pipe)
+}
+
+func (x *Writer) beforeSave() {}
+func (x *Writer) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ReaderWriter", &x.ReaderWriter)
+}
+
+func (x *Writer) afterLoad() {}
+func (x *Writer) load(m state.Map) {
+ m.Load("ReaderWriter", &x.ReaderWriter)
+}
+
+func init() {
+ state.Register("pipe.buffer", (*buffer)(nil), state.Fns{Save: (*buffer).save, Load: (*buffer).load})
+ state.Register("pipe.bufferList", (*bufferList)(nil), state.Fns{Save: (*bufferList).save, Load: (*bufferList).load})
+ state.Register("pipe.bufferEntry", (*bufferEntry)(nil), state.Fns{Save: (*bufferEntry).save, Load: (*bufferEntry).load})
+ state.Register("pipe.inodeOperations", (*inodeOperations)(nil), state.Fns{Save: (*inodeOperations).save, Load: (*inodeOperations).load})
+ state.Register("pipe.Pipe", (*Pipe)(nil), state.Fns{Save: (*Pipe).save, Load: (*Pipe).load})
+ state.Register("pipe.Reader", (*Reader)(nil), state.Fns{Save: (*Reader).save, Load: (*Reader).load})
+ state.Register("pipe.ReaderWriter", (*ReaderWriter)(nil), state.Fns{Save: (*ReaderWriter).save, Load: (*ReaderWriter).load})
+ state.Register("pipe.Writer", (*Writer)(nil), state.Fns{Save: (*Writer).save, Load: (*Writer).load})
+}
diff --git a/pkg/sentry/kernel/pipe/reader.go b/pkg/sentry/kernel/pipe/reader.go
new file mode 100644
index 000000000..656be824d
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/reader.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Reader satisfies the fs.FileOperations interface for read-only pipes.
+// Reader should be used with !fs.FileFlags.Write to reject writes.
+//
+// +stateify savable
+type Reader struct {
+ ReaderWriter
+}
+
+// Release implements fs.FileOperations.Release.
+//
+// This overrides ReaderWriter.Release.
+func (r *Reader) Release() {
+ r.Pipe.rClose()
+
+ // Wake up writers.
+ r.Pipe.Notify(waiter.EventOut)
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (r *Reader) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return r.Pipe.rReadiness() & mask
+}
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
new file mode 100644
index 000000000..e560b9be9
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -0,0 +1,96 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "math"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// ReaderWriter satisfies the FileOperations interface and services both
+// read and write requests. This should only be used directly for named pipes.
+// pipe(2) and pipe2(2) only support unidirectional pipes and should use
+// either pipe.Reader or pipe.Writer.
+//
+// +stateify savable
+type ReaderWriter struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ *Pipe
+}
+
+// Release implements fs.FileOperations.Release.
+func (rw *ReaderWriter) Release() {
+ rw.Pipe.rClose()
+ rw.Pipe.wClose()
+
+ // Wake up readers and writers.
+ rw.Pipe.Notify(waiter.EventIn | waiter.EventOut)
+}
+
+// Read implements fs.FileOperations.Read.
+func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ n, err := rw.Pipe.read(ctx, dst)
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// Write implements fs.FileOperations.Write.
+func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ n, err := rw.Pipe.write(ctx, src)
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (rw *ReaderWriter) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return rw.Pipe.rwReadiness() & mask
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (rw *ReaderWriter) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Switch on ioctl request.
+ switch int(args[1].Int()) {
+ case linux.FIONREAD:
+ v := rw.queued()
+ if v > math.MaxInt32 {
+ v = math.MaxInt32 // Silently truncate.
+ }
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ default:
+ return 0, syscall.ENOTTY
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/writer.go b/pkg/sentry/kernel/pipe/writer.go
new file mode 100644
index 000000000..8d5b68541
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/writer.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Writer satisfies the fs.FileOperations interface for write-only pipes.
+// Writer should be used with !fs.FileFlags.Read to reject reads.
+//
+// +stateify savable
+type Writer struct {
+ ReaderWriter
+}
+
+// Release implements fs.FileOperations.Release.
+//
+// This overrides ReaderWriter.Release.
+func (w *Writer) Release() {
+ w.Pipe.wClose()
+
+ // Wake up readers.
+ w.Pipe.Notify(waiter.EventHUp)
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (w *Writer) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return w.Pipe.wReadiness() & mask
+}
diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go
new file mode 100644
index 000000000..a016b4087
--- /dev/null
+++ b/pkg/sentry/kernel/posixtimer.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// IntervalTimer represents a POSIX interval timer as described by
+// timer_create(2).
+//
+// +stateify savable
+type IntervalTimer struct {
+ timer *ktime.Timer
+
+ // If target is not nil, it receives signo from timer expirations. If group
+ // is true, these signals are thread-group-directed. These fields are
+ // immutable.
+ target *Task
+ signo linux.Signal
+ id linux.TimerID
+ sigval uint64
+ group bool
+
+ // If sigpending is true, a signal to target is already queued, and timer
+ // expirations should increment overrunCur instead of sending another
+ // signal. sigpending is protected by target's signal mutex. (If target is
+ // nil, the timer will never send signals, so sigpending will be unused.)
+ sigpending bool
+
+ // If sigorphan is true, timer's setting has been changed since sigpending
+ // last became true, such that overruns should no longer be counted in the
+ // pending signals si_overrun. sigorphan is protected by target's signal
+ // mutex.
+ sigorphan bool
+
+ // overrunCur is the number of overruns that have occurred since the last
+ // time a signal was sent. overrunCur is protected by target's signal
+ // mutex.
+ overrunCur uint64
+
+ // Consider the last signal sent by this timer that has been dequeued.
+ // overrunLast is the number of overruns that occurred between when this
+ // signal was sent and when it was dequeued. Equivalently, overrunLast was
+ // the value of overrunCur when this signal was dequeued. overrunLast is
+ // protected by target's signal mutex.
+ overrunLast uint64
+}
+
+// DestroyTimer releases it's resources.
+func (it *IntervalTimer) DestroyTimer() {
+ it.timer.Destroy()
+ it.timerSettingChanged()
+ // A destroyed IntervalTimer is still potentially reachable via a
+ // pendingSignal; nil out timer so that it won't be saved.
+ it.timer = nil
+}
+
+func (it *IntervalTimer) timerSettingChanged() {
+ if it.target == nil {
+ return
+ }
+ it.target.tg.pidns.owner.mu.RLock()
+ defer it.target.tg.pidns.owner.mu.RUnlock()
+ it.target.tg.signalHandlers.mu.Lock()
+ defer it.target.tg.signalHandlers.mu.Unlock()
+ it.sigorphan = true
+ it.overrunCur = 0
+ it.overrunLast = 0
+}
+
+// PauseTimer pauses the associated Timer.
+func (it *IntervalTimer) PauseTimer() {
+ it.timer.Pause()
+}
+
+// ResumeTimer resumes the associated Timer.
+func (it *IntervalTimer) ResumeTimer() {
+ it.timer.Resume()
+}
+
+// Preconditions: it.target's signal mutex must be locked.
+func (it *IntervalTimer) updateDequeuedSignalLocked(si *arch.SignalInfo) {
+ it.sigpending = false
+ if it.sigorphan {
+ return
+ }
+ it.overrunLast = it.overrunCur
+ it.overrunCur = 0
+ si.SetOverrun(saturateI32FromU64(it.overrunLast))
+}
+
+// Preconditions: it.target's signal mutex must be locked.
+func (it *IntervalTimer) signalRejectedLocked() {
+ it.sigpending = false
+ if it.sigorphan {
+ return
+ }
+ it.overrunCur++
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (it *IntervalTimer) Notify(exp uint64) {
+ if it.target == nil {
+ return
+ }
+
+ it.target.tg.pidns.owner.mu.RLock()
+ defer it.target.tg.pidns.owner.mu.RUnlock()
+ it.target.tg.signalHandlers.mu.Lock()
+ defer it.target.tg.signalHandlers.mu.Unlock()
+
+ if it.sigpending {
+ it.overrunCur += exp
+ return
+ }
+
+ // sigpending must be set before sendSignalTimerLocked() so that it can be
+ // unset if the signal is discarded (in which case sendSignalTimerLocked()
+ // will return nil).
+ it.sigpending = true
+ it.sigorphan = false
+ it.overrunCur += exp - 1
+ si := &arch.SignalInfo{
+ Signo: int32(it.signo),
+ Code: arch.SignalInfoTimer,
+ }
+ si.SetTimerID(it.id)
+ si.SetSigval(it.sigval)
+ // si_overrun is set when the signal is dequeued.
+ if err := it.target.sendSignalTimerLocked(si, it.group, it); err != nil {
+ it.signalRejectedLocked()
+ }
+}
+
+// Destroy implements ktime.TimerListener.Destroy. Users of Timer should call
+// DestroyTimer instead.
+func (it *IntervalTimer) Destroy() {
+}
+
+// IntervalTimerCreate implements timer_create(2).
+func (t *Task) IntervalTimerCreate(c ktime.Clock, sigev *linux.Sigevent) (linux.TimerID, error) {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+
+ // Allocate a timer ID.
+ var id linux.TimerID
+ end := t.tg.nextTimerID
+ for {
+ id = t.tg.nextTimerID
+ _, ok := t.tg.timers[id]
+ t.tg.nextTimerID++
+ if t.tg.nextTimerID < 0 {
+ t.tg.nextTimerID = 0
+ }
+ if !ok {
+ break
+ }
+ if t.tg.nextTimerID == end {
+ return 0, syserror.EAGAIN
+ }
+ }
+
+ // "The implementation of the default case where evp [sic] is NULL is
+ // handled inside glibc, which invokes the underlying system call with a
+ // suitably populated sigevent structure." - timer_create(2). This is
+ // misleading; the timer_create syscall also handles a NULL sevp as
+ // described by the man page
+ // (kernel/time/posix-timers.c:sys_timer_create(), do_timer_create()). This
+ // must be handled here instead of the syscall wrapper since sigval is the
+ // timer ID, which isn't available until we allocate it in this function.
+ if sigev == nil {
+ sigev = &linux.Sigevent{
+ Signo: int32(linux.SIGALRM),
+ Notify: linux.SIGEV_SIGNAL,
+ Value: uint64(id),
+ }
+ }
+
+ // Construct the timer.
+ it := &IntervalTimer{
+ id: id,
+ sigval: sigev.Value,
+ }
+ switch sigev.Notify {
+ case linux.SIGEV_NONE:
+ // leave it.target = nil
+ case linux.SIGEV_SIGNAL, linux.SIGEV_THREAD:
+ // POSIX SIGEV_THREAD semantics are implemented in userspace by libc;
+ // to the kernel, SIGEV_THREAD and SIGEV_SIGNAL are equivalent. (See
+ // Linux's kernel/time/posix-timers.c:good_sigevent().)
+ it.target = t.tg.leader
+ it.group = true
+ case linux.SIGEV_THREAD_ID:
+ t.tg.pidns.owner.mu.RLock()
+ target, ok := t.tg.pidns.tasks[ThreadID(sigev.Tid)]
+ t.tg.pidns.owner.mu.RUnlock()
+ if !ok || target.tg != t.tg {
+ return 0, syserror.EINVAL
+ }
+ it.target = target
+ default:
+ return 0, syserror.EINVAL
+ }
+ if sigev.Notify != linux.SIGEV_NONE {
+ it.signo = linux.Signal(sigev.Signo)
+ if !it.signo.IsValid() {
+ return 0, syserror.EINVAL
+ }
+ }
+ it.timer = ktime.NewTimer(c, it)
+
+ t.tg.timers[id] = it
+ return id, nil
+}
+
+// IntervalTimerDelete implements timer_delete(2).
+func (t *Task) IntervalTimerDelete(id linux.TimerID) error {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+ it := t.tg.timers[id]
+ if it == nil {
+ return syserror.EINVAL
+ }
+ delete(t.tg.timers, id)
+ it.DestroyTimer()
+ return nil
+}
+
+// IntervalTimerSettime implements timer_settime(2).
+func (t *Task) IntervalTimerSettime(id linux.TimerID, its linux.Itimerspec, abs bool) (linux.Itimerspec, error) {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+ it := t.tg.timers[id]
+ if it == nil {
+ return linux.Itimerspec{}, syserror.EINVAL
+ }
+
+ newS, err := ktime.SettingFromItimerspec(its, abs, it.timer.Clock())
+ if err != nil {
+ return linux.Itimerspec{}, err
+ }
+ tm, oldS := it.timer.SwapAnd(newS, it.timerSettingChanged)
+ its = ktime.ItimerspecFromSetting(tm, oldS)
+ return its, nil
+}
+
+// IntervalTimerGettime implements timer_gettime(2).
+func (t *Task) IntervalTimerGettime(id linux.TimerID) (linux.Itimerspec, error) {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+ it := t.tg.timers[id]
+ if it == nil {
+ return linux.Itimerspec{}, syserror.EINVAL
+ }
+
+ tm, s := it.timer.Get()
+ its := ktime.ItimerspecFromSetting(tm, s)
+ return its, nil
+}
+
+// IntervalTimerGetoverrun implements timer_getoverrun(2).
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) IntervalTimerGetoverrun(id linux.TimerID) (int32, error) {
+ t.tg.timerMu.Lock()
+ defer t.tg.timerMu.Unlock()
+ it := t.tg.timers[id]
+ if it == nil {
+ return 0, syserror.EINVAL
+ }
+ // By timer_create(2) invariant, either it.target == nil (in which case
+ // it.overrunLast is immutably 0) or t.tg == it.target.tg; and the fact
+ // that t is executing timer_getoverrun(2) means that t.tg can't be
+ // completing execve, so t.tg.signalHandlers can't be changing, allowing us
+ // to lock t.tg.signalHandlers.mu without holding the TaskSet mutex.
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ // This is consistent with Linux after 78c9c4dfbf8c ("posix-timers:
+ // Sanitize overrun handling").
+ return saturateI32FromU64(it.overrunLast), nil
+}
+
+func saturateI32FromU64(x uint64) int32 {
+ if x > math.MaxInt32 {
+ return math.MaxInt32
+ }
+ return int32(x)
+}
diff --git a/pkg/sentry/kernel/process_group_list.go b/pkg/sentry/kernel/process_group_list.go
new file mode 100755
index 000000000..853145237
--- /dev/null
+++ b/pkg/sentry/kernel/process_group_list.go
@@ -0,0 +1,173 @@
+package kernel
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type processGroupElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (processGroupElementMapper) linkerFor(elem *ProcessGroup) *ProcessGroup { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type processGroupList struct {
+ head *ProcessGroup
+ tail *ProcessGroup
+}
+
+// Reset resets list l to the empty state.
+func (l *processGroupList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *processGroupList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *processGroupList) Front() *ProcessGroup {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *processGroupList) Back() *ProcessGroup {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *processGroupList) PushFront(e *ProcessGroup) {
+ processGroupElementMapper{}.linkerFor(e).SetNext(l.head)
+ processGroupElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ processGroupElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *processGroupList) PushBack(e *ProcessGroup) {
+ processGroupElementMapper{}.linkerFor(e).SetNext(nil)
+ processGroupElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ processGroupElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *processGroupList) PushBackList(m *processGroupList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ processGroupElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ processGroupElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *processGroupList) InsertAfter(b, e *ProcessGroup) {
+ a := processGroupElementMapper{}.linkerFor(b).Next()
+ processGroupElementMapper{}.linkerFor(e).SetNext(a)
+ processGroupElementMapper{}.linkerFor(e).SetPrev(b)
+ processGroupElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ processGroupElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *processGroupList) InsertBefore(a, e *ProcessGroup) {
+ b := processGroupElementMapper{}.linkerFor(a).Prev()
+ processGroupElementMapper{}.linkerFor(e).SetNext(a)
+ processGroupElementMapper{}.linkerFor(e).SetPrev(b)
+ processGroupElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ processGroupElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *processGroupList) Remove(e *ProcessGroup) {
+ prev := processGroupElementMapper{}.linkerFor(e).Prev()
+ next := processGroupElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ processGroupElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ processGroupElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type processGroupEntry struct {
+ next *ProcessGroup
+ prev *ProcessGroup
+}
+
+// Next returns the entry that follows e in the list.
+func (e *processGroupEntry) Next() *ProcessGroup {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *processGroupEntry) Prev() *ProcessGroup {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *processGroupEntry) SetNext(elem *ProcessGroup) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *processGroupEntry) SetPrev(elem *ProcessGroup) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
new file mode 100644
index 000000000..4423e7efd
--- /dev/null
+++ b/pkg/sentry/kernel/ptrace.go
@@ -0,0 +1,1105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// ptraceOptions are the subset of options controlling a task's ptrace behavior
+// that are set by ptrace(PTRACE_SETOPTIONS).
+//
+// +stateify savable
+type ptraceOptions struct {
+ // ExitKill is true if the tracee should be sent SIGKILL when the tracer
+ // exits.
+ ExitKill bool
+
+ // If SysGood is true, set bit 7 in the signal number for
+ // syscall-entry-stop and syscall-exit-stop traps delivered to this task's
+ // tracer.
+ SysGood bool
+
+ // TraceClone is true if the tracer wants to receive PTRACE_EVENT_CLONE
+ // events.
+ TraceClone bool
+
+ // TraceExec is true if the tracer wants to receive PTRACE_EVENT_EXEC
+ // events.
+ TraceExec bool
+
+ // TraceExit is true if the tracer wants to receive PTRACE_EVENT_EXIT
+ // events.
+ TraceExit bool
+
+ // TraceFork is true if the tracer wants to receive PTRACE_EVENT_FORK
+ // events.
+ TraceFork bool
+
+ // TraceSeccomp is true if the tracer wants to receive PTRACE_EVENT_SECCOMP
+ // events.
+ TraceSeccomp bool
+
+ // TraceVfork is true if the tracer wants to receive PTRACE_EVENT_VFORK
+ // events.
+ TraceVfork bool
+
+ // TraceVforkDone is true if the tracer wants to receive
+ // PTRACE_EVENT_VFORK_DONE events.
+ TraceVforkDone bool
+}
+
+// ptraceSyscallMode controls the behavior of a ptraced task at syscall entry
+// and exit.
+type ptraceSyscallMode int
+
+const (
+ // ptraceSyscallNone indicates that the task has never ptrace-stopped, or
+ // that it was resumed from its last ptrace-stop by PTRACE_CONT or
+ // PTRACE_DETACH. The task's syscalls will not be intercepted.
+ ptraceSyscallNone ptraceSyscallMode = iota
+
+ // ptraceSyscallIntercept indicates that the task was resumed from its last
+ // ptrace-stop by PTRACE_SYSCALL. The next time the task enters or exits a
+ // syscall, a ptrace-stop will occur.
+ ptraceSyscallIntercept
+
+ // ptraceSyscallEmu indicates that the task was resumed from its last
+ // ptrace-stop by PTRACE_SYSEMU or PTRACE_SYSEMU_SINGLESTEP. The next time
+ // the task enters a syscall, the syscall will be skipped, and a
+ // ptrace-stop will occur.
+ ptraceSyscallEmu
+)
+
+// CanTrace checks that t is permitted to access target's state, as defined by
+// ptrace(2), subsection "Ptrace access mode checking". If attach is true, it
+// checks for access mode PTRACE_MODE_ATTACH; otherwise, it checks for access
+// mode PTRACE_MODE_READ.
+func (t *Task) CanTrace(target *Task, attach bool) bool {
+ // "1. If the calling thread and the target thread are in the same thread
+ // group, access is always allowed." - ptrace(2)
+ //
+ // Note: Strictly speaking, prior to 73af963f9f30 ("__ptrace_may_access()
+ // should not deny sub-threads", first released in Linux 3.12), the rule
+ // only applies if t and target are the same task. But, as that commit
+ // message puts it, "[any] security check is pointless when the tasks share
+ // the same ->mm."
+ if t.tg == target.tg {
+ return true
+ }
+
+ // """
+ // 2. If the access mode specifies PTRACE_MODE_FSCREDS (ED: snipped,
+ // doesn't exist until Linux 4.5).
+ //
+ // Otherwise, the access mode specifies PTRACE_MODE_REALCREDS, so use the
+ // caller's real UID and GID for the checks in the next step. (Most APIs
+ // that check the caller's UID and GID use the effective IDs. For
+ // historical reasons, the PTRACE_MODE_REALCREDS check uses the real IDs
+ // instead.)
+ //
+ // 3. Deny access if neither of the following is true:
+ //
+ // - The real, effective, and saved-set user IDs of the target match the
+ // caller's user ID, *and* the real, effective, and saved-set group IDs of
+ // the target match the caller's group ID.
+ //
+ // - The caller has the CAP_SYS_PTRACE capability in the user namespace of
+ // the target.
+ //
+ // 4. Deny access if the target process "dumpable" attribute has a value
+ // other than 1 (SUID_DUMP_USER; see the discussion of PR_SET_DUMPABLE in
+ // prctl(2)), and the caller does not have the CAP_SYS_PTRACE capability in
+ // the user namespace of the target process.
+ //
+ // 5. The kernel LSM security_ptrace_access_check() interface is invoked to
+ // see if ptrace access is permitted. The results depend on the LSM(s). The
+ // implementation of this interface in the commoncap LSM performs the
+ // following steps:
+ //
+ // a) If the access mode includes PTRACE_MODE_FSCREDS, then use the
+ // caller's effective capability set; otherwise (the access mode specifies
+ // PTRACE_MODE_REALCREDS, so) use the caller's permitted capability set.
+ //
+ // b) Deny access if neither of the following is true:
+ //
+ // - The caller and the target process are in the same user namespace, and
+ // the caller's capabilities are a proper superset of the target process's
+ // permitted capabilities.
+ //
+ // - The caller has the CAP_SYS_PTRACE capability in the target process's
+ // user namespace.
+ //
+ // Note that the commoncap LSM does not distinguish between
+ // PTRACE_MODE_READ and PTRACE_MODE_ATTACH. (ED: From earlier in this
+ // section: "the commoncap LSM ... is always invoked".)
+ // """
+ callerCreds := t.Credentials()
+ targetCreds := target.Credentials()
+ if callerCreds.HasCapabilityIn(linux.CAP_SYS_PTRACE, targetCreds.UserNamespace) {
+ return true
+ }
+ if cuid := callerCreds.RealKUID; cuid != targetCreds.RealKUID || cuid != targetCreds.EffectiveKUID || cuid != targetCreds.SavedKUID {
+ return false
+ }
+ if cgid := callerCreds.RealKGID; cgid != targetCreds.RealKGID || cgid != targetCreds.EffectiveKGID || cgid != targetCreds.SavedKGID {
+ return false
+ }
+ // TODO(b/31916171): dumpability check
+ if callerCreds.UserNamespace != targetCreds.UserNamespace {
+ return false
+ }
+ if targetCreds.PermittedCaps&^callerCreds.PermittedCaps != 0 {
+ return false
+ }
+ // TODO: Yama LSM
+ return true
+}
+
+// Tracer returns t's ptrace Tracer.
+func (t *Task) Tracer() *Task {
+ return t.ptraceTracer.Load().(*Task)
+}
+
+// hasTracer returns true if t has a ptrace tracer attached.
+func (t *Task) hasTracer() bool {
+ // This isn't just inlined into callers so that if Task.Tracer() turns out
+ // to be too expensive because of e.g. interface conversion, we can switch
+ // to having a separate atomic flag more easily.
+ return t.Tracer() != nil
+}
+
+// ptraceStop is a TaskStop placed on tasks in a ptrace-stop.
+//
+// +stateify savable
+type ptraceStop struct {
+ // If frozen is true, the stopped task's tracer is currently operating on
+ // it, so Task.Kill should not remove the stop.
+ frozen bool
+
+ // If listen is true, the stopped task's tracer invoked PTRACE_LISTEN, so
+ // ptraceFreeze should fail.
+ listen bool
+}
+
+// Killable implements TaskStop.Killable.
+func (s *ptraceStop) Killable() bool {
+ return !s.frozen
+}
+
+// beginPtraceStopLocked initiates an unfrozen ptrace-stop on t. If t has been
+// killed, the stop is skipped, and beginPtraceStopLocked returns false.
+//
+// beginPtraceStopLocked does not signal t's tracer or wake it if it is
+// waiting.
+//
+// Preconditions: The TaskSet mutex must be locked. The caller must be running
+// on the task goroutine.
+func (t *Task) beginPtraceStopLocked() bool {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ // This is analogous to Linux's kernel/signal.c:ptrace_stop() => ... =>
+ // kernel/sched/core.c:__schedule() => signal_pending_state() check, which
+ // is what prevents tasks from entering ptrace-stops after being killed.
+ // Note that if t was SIGKILLed and beingPtraceStopLocked is being called
+ // for PTRACE_EVENT_EXIT, the task will have dequeued the signal before
+ // entering the exit path, so t.killedLocked() will no longer return true.
+ // This is consistent with Linux: "Bugs: ... A SIGKILL signal may still
+ // cause a PTRACE_EVENT_EXIT stop before actual signal death. This may be
+ // changed in the future; SIGKILL is meant to always immediately kill tasks
+ // even under ptrace. Last confirmed on Linux 3.13." - ptrace(2)
+ if t.killedLocked() {
+ return false
+ }
+ t.beginInternalStopLocked(&ptraceStop{})
+ return true
+}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) ptraceTrapLocked(code int32) {
+ // This is unconditional in ptrace_stop().
+ t.tg.signalHandlers.mu.Lock()
+ t.trapStopPending = false
+ t.tg.signalHandlers.mu.Unlock()
+ t.ptraceCode = code
+ t.ptraceSiginfo = &arch.SignalInfo{
+ Signo: int32(linux.SIGTRAP),
+ Code: code,
+ }
+ t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ if t.beginPtraceStopLocked() {
+ tracer := t.Tracer()
+ tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP))
+ tracer.tg.eventQueue.Notify(EventTraceeStop)
+ }
+}
+
+// ptraceFreeze checks if t is in a ptraceStop. If so, it freezes the
+// ptraceStop, temporarily preventing it from being removed by a concurrent
+// Task.Kill, and returns true. Otherwise it returns false.
+//
+// Preconditions: The TaskSet mutex must be locked. The caller must be running
+// on the task goroutine of t's tracer.
+func (t *Task) ptraceFreeze() bool {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if t.stop == nil {
+ return false
+ }
+ s, ok := t.stop.(*ptraceStop)
+ if !ok {
+ return false
+ }
+ if s.listen {
+ return false
+ }
+ s.frozen = true
+ return true
+}
+
+// ptraceUnfreeze ends the effect of a previous successful call to
+// ptraceFreeze.
+//
+// Preconditions: t must be in a frozen ptraceStop.
+func (t *Task) ptraceUnfreeze() {
+ // t.tg.signalHandlers is stable because t is in a frozen ptrace-stop,
+ // preventing its thread group from completing execve.
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.ptraceUnfreezeLocked()
+}
+
+// Preconditions: t must be in a frozen ptraceStop. t's signal mutex must be
+// locked.
+func (t *Task) ptraceUnfreezeLocked() {
+ // Do this even if the task has been killed to ensure a panic if t.stop is
+ // nil or not a ptraceStop.
+ t.stop.(*ptraceStop).frozen = false
+ if t.killedLocked() {
+ t.endInternalStopLocked()
+ }
+}
+
+// ptraceUnstop implements ptrace request PTRACE_CONT, PTRACE_SYSCALL,
+// PTRACE_SINGLESTEP, PTRACE_SYSEMU, or PTRACE_SYSEMU_SINGLESTEP depending on
+// mode and singlestep.
+//
+// Preconditions: t must be in a frozen ptrace stop.
+//
+// Postconditions: If ptraceUnstop returns nil, t will no longer be in a ptrace
+// stop.
+func (t *Task) ptraceUnstop(mode ptraceSyscallMode, singlestep bool, sig linux.Signal) error {
+ if sig != 0 && !sig.IsValid() {
+ return syserror.EIO
+ }
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ t.ptraceCode = int32(sig)
+ t.ptraceSyscallMode = mode
+ t.ptraceSinglestep = singlestep
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.endInternalStopLocked()
+ return nil
+}
+
+func (t *Task) ptraceTraceme() error {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ if t.hasTracer() {
+ return syserror.EPERM
+ }
+ if t.parent == nil {
+ // In Linux, only init can not have a parent, and init is assumed never
+ // to invoke PTRACE_TRACEME. In the sentry, TGID 1 is an arbitrary user
+ // application that may invoke PTRACE_TRACEME; having no parent can
+ // also occur if all tasks in the parent thread group have exited, and
+ // failed to find a living thread group to reparent to. The former case
+ // is treated as if TGID 1 has an exited parent in an invisible
+ // ancestor PID namespace that is an owner of the root user namespace
+ // (and consequently has CAP_SYS_PTRACE), and the latter case is a
+ // special form of the exited parent case below. In either case,
+ // returning nil here is correct.
+ return nil
+ }
+ if !t.parent.CanTrace(t, true) {
+ return syserror.EPERM
+ }
+ if t.parent.exitState != TaskExitNone {
+ // Fail silently, as if we were successfully attached but then
+ // immediately detached. This is consistent with Linux.
+ return nil
+ }
+ t.ptraceTracer.Store(t.parent)
+ t.parent.ptraceTracees[t] = struct{}{}
+ return nil
+}
+
+// ptraceAttach implements ptrace(PTRACE_ATTACH, target) if seize is false, and
+// ptrace(PTRACE_SEIZE, target, 0, opts) if seize is true. t is the caller.
+func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error {
+ if t.tg == target.tg {
+ return syserror.EPERM
+ }
+ if !t.CanTrace(target, true) {
+ return syserror.EPERM
+ }
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ if target.hasTracer() {
+ return syserror.EPERM
+ }
+ // Attaching to zombies and dead tasks is not permitted; the exit
+ // notification logic relies on this. Linux allows attaching to PF_EXITING
+ // tasks, though.
+ if target.exitState >= TaskExitZombie {
+ return syserror.EPERM
+ }
+ if seize {
+ if err := target.ptraceSetOptionsLocked(opts); err != nil {
+ return syserror.EIO
+ }
+ }
+ target.ptraceTracer.Store(t)
+ t.ptraceTracees[target] = struct{}{}
+ target.ptraceSeized = seize
+ target.tg.signalHandlers.mu.Lock()
+ // "Unlike PTRACE_ATTACH, PTRACE_SEIZE does not stop the process." -
+ // ptrace(2)
+ if !seize {
+ target.sendSignalLocked(&arch.SignalInfo{
+ Signo: int32(linux.SIGSTOP),
+ Code: arch.SignalInfoUser,
+ }, false /* group */)
+ }
+ // Undocumented Linux feature: If the tracee is already group-stopped (and
+ // consequently will not report the SIGSTOP just sent), force it to leave
+ // and re-enter the stop so that it will switch to a ptrace-stop.
+ if target.stop == (*groupStop)(nil) {
+ target.trapStopPending = true
+ target.endInternalStopLocked()
+ // TODO(jamieliu): Linux blocks ptrace_attach() until the task has
+ // entered the ptrace-stop (or exited) via JOBCTL_TRAPPING.
+ }
+ target.tg.signalHandlers.mu.Unlock()
+ return nil
+}
+
+// ptraceDetach implements ptrace(PTRACE_DETACH, target, 0, sig). t is the
+// caller.
+//
+// Preconditions: target must be a tracee of t in a frozen ptrace stop.
+//
+// Postconditions: If ptraceDetach returns nil, target will no longer be in a
+// ptrace stop.
+func (t *Task) ptraceDetach(target *Task, sig linux.Signal) error {
+ if sig != 0 && !sig.IsValid() {
+ return syserror.EIO
+ }
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ target.ptraceCode = int32(sig)
+ target.forgetTracerLocked()
+ delete(t.ptraceTracees, target)
+ return nil
+}
+
+// exitPtrace is called in the exit path to detach all of t's tracees.
+func (t *Task) exitPtrace() {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ for target := range t.ptraceTracees {
+ if target.ptraceOpts.ExitKill {
+ target.tg.signalHandlers.mu.Lock()
+ target.sendSignalLocked(&arch.SignalInfo{
+ Signo: int32(linux.SIGKILL),
+ }, false /* group */)
+ target.tg.signalHandlers.mu.Unlock()
+ }
+ // Leave ptraceCode unchanged so that if the task is ptrace-stopped, it
+ // observes the ptraceCode it set before it entered the stop. I believe
+ // this is consistent with Linux.
+ target.forgetTracerLocked()
+ }
+ // "nil maps cannot be saved"
+ t.ptraceTracees = make(map[*Task]struct{})
+}
+
+// forgetTracerLocked detaches t's tracer and ensures that t is no longer
+// ptrace-stopped.
+//
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) forgetTracerLocked() {
+ t.ptraceSeized = false
+ t.ptraceOpts = ptraceOptions{}
+ t.ptraceSyscallMode = ptraceSyscallNone
+ t.ptraceSinglestep = false
+ t.ptraceTracer.Store((*Task)(nil))
+ if t.exitTracerNotified && !t.exitTracerAcked {
+ t.exitTracerAcked = true
+ t.exitNotifyLocked(true)
+ }
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ // Unset t.trapStopPending, which might have been set by PTRACE_INTERRUPT. If
+ // it wasn't, it will be reset via t.groupStopPending after the following.
+ t.trapStopPending = false
+ // If t's thread group is in a group stop and t is eligible to participate,
+ // make it do so. This is essentially the reverse of the special case in
+ // ptraceAttach, which converts a group stop to a ptrace stop. ("Handling
+ // of restart from group-stop is currently buggy, but the "as planned"
+ // behavior is to leave tracee stopped and waiting for SIGCONT." -
+ // ptrace(2))
+ if (t.tg.groupStopComplete || t.tg.groupStopPendingCount != 0) && !t.groupStopPending && t.exitState < TaskExitInitiated {
+ t.groupStopPending = true
+ // t already participated in the group stop when it unset
+ // groupStopPending.
+ t.groupStopAcknowledged = true
+ t.interrupt()
+ }
+ if _, ok := t.stop.(*ptraceStop); ok {
+ t.endInternalStopLocked()
+ }
+}
+
+// ptraceSignalLocked is called after signal dequeueing to check if t should
+// enter ptrace signal-delivery-stop.
+//
+// Preconditions: The signal mutex must be locked. The caller must be running
+// on the task goroutine.
+func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool {
+ if linux.Signal(info.Signo) == linux.SIGKILL {
+ return false
+ }
+ if !t.hasTracer() {
+ return false
+ }
+ // The tracer might change this signal into a stop signal, in which case
+ // any SIGCONT received after the signal was originally dequeued should
+ // cancel it. This is consistent with Linux.
+ t.tg.groupStopDequeued = true
+ // This is unconditional in ptrace_stop().
+ t.trapStopPending = false
+ // Can't lock the TaskSet mutex while holding a signal mutex.
+ t.tg.signalHandlers.mu.Unlock()
+ defer t.tg.signalHandlers.mu.Lock()
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ tracer := t.Tracer()
+ if tracer == nil {
+ return false
+ }
+ t.ptraceCode = info.Signo
+ t.ptraceSiginfo = info
+ t.Debugf("Entering signal-delivery-stop for signal %d", info.Signo)
+ if t.beginPtraceStopLocked() {
+ tracer.signalStop(t, arch.CLD_TRAPPED, info.Signo)
+ tracer.tg.eventQueue.Notify(EventTraceeStop)
+ }
+ return true
+}
+
+// ptraceSeccomp is called when a seccomp-bpf filter returns action
+// SECCOMP_RET_TRACE to check if t should enter PTRACE_EVENT_SECCOMP stop. data
+// is the lower 16 bits of the filter's return value.
+func (t *Task) ptraceSeccomp(data uint16) bool {
+ if !t.hasTracer() {
+ return false
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if !t.ptraceOpts.TraceSeccomp {
+ return false
+ }
+ t.Debugf("Entering PTRACE_EVENT_SECCOMP stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_SECCOMP, uint64(data))
+ return true
+}
+
+// ptraceSyscallEnter is called immediately before entering a syscall to check
+// if t should enter ptrace syscall-enter-stop.
+func (t *Task) ptraceSyscallEnter() (taskRunState, bool) {
+ if !t.hasTracer() {
+ return nil, false
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ switch t.ptraceSyscallMode {
+ case ptraceSyscallNone:
+ return nil, false
+ case ptraceSyscallIntercept:
+ t.Debugf("Entering syscall-enter-stop from PTRACE_SYSCALL")
+ t.ptraceSyscallStopLocked()
+ return (*runSyscallAfterSyscallEnterStop)(nil), true
+ case ptraceSyscallEmu:
+ t.Debugf("Entering syscall-enter-stop from PTRACE_SYSEMU")
+ t.ptraceSyscallStopLocked()
+ return (*runSyscallAfterSysemuStop)(nil), true
+ }
+ panic(fmt.Sprintf("Unknown ptraceSyscallMode: %v", t.ptraceSyscallMode))
+}
+
+// ptraceSyscallExit is called immediately after leaving a syscall to check if
+// t should enter ptrace syscall-exit-stop.
+func (t *Task) ptraceSyscallExit() {
+ if !t.hasTracer() {
+ return
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if t.ptraceSyscallMode != ptraceSyscallIntercept {
+ return
+ }
+ t.Debugf("Entering syscall-exit-stop")
+ t.ptraceSyscallStopLocked()
+}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) ptraceSyscallStopLocked() {
+ code := int32(linux.SIGTRAP)
+ if t.ptraceOpts.SysGood {
+ code |= 0x80
+ }
+ t.ptraceTrapLocked(code)
+}
+
+type ptraceCloneKind int32
+
+const (
+ // ptraceCloneKindClone represents a call to Task.Clone where
+ // TerminationSignal is not SIGCHLD and Vfork is false.
+ ptraceCloneKindClone ptraceCloneKind = iota
+
+ // ptraceCloneKindFork represents a call to Task.Clone where
+ // TerminationSignal is SIGCHLD and Vfork is false.
+ ptraceCloneKindFork
+
+ // ptraceCloneKindVfork represents a call to Task.Clone where Vfork is
+ // true.
+ ptraceCloneKindVfork
+)
+
+// ptraceClone is called at the end of a clone or fork syscall to check if t
+// should enter PTRACE_EVENT_CLONE, PTRACE_EVENT_FORK, or PTRACE_EVENT_VFORK
+// stop. child is the new task.
+func (t *Task) ptraceClone(kind ptraceCloneKind, child *Task, opts *CloneOptions) bool {
+ if !t.hasTracer() {
+ return false
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ event := false
+ if !opts.Untraced {
+ switch kind {
+ case ptraceCloneKindClone:
+ if t.ptraceOpts.TraceClone {
+ t.Debugf("Entering PTRACE_EVENT_CLONE stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_CLONE, uint64(t.tg.pidns.tids[child]))
+ event = true
+ }
+ case ptraceCloneKindFork:
+ if t.ptraceOpts.TraceFork {
+ t.Debugf("Entering PTRACE_EVENT_FORK stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_FORK, uint64(t.tg.pidns.tids[child]))
+ event = true
+ }
+ case ptraceCloneKindVfork:
+ if t.ptraceOpts.TraceVfork {
+ t.Debugf("Entering PTRACE_EVENT_VFORK stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_VFORK, uint64(t.tg.pidns.tids[child]))
+ event = true
+ }
+ default:
+ panic(fmt.Sprintf("Unknown ptraceCloneKind: %v", kind))
+ }
+ }
+ // "If the PTRACE_O_TRACEFORK, PTRACE_O_TRACEVFORK, or PTRACE_O_TRACECLONE
+ // options are in effect, then children created by, respectively, vfork(2)
+ // or clone(2) with the CLONE_VFORK flag, fork(2) or clone(2) with the exit
+ // signal set to SIGCHLD, and other kinds of clone(2), are automatically
+ // attached to the same tracer which traced their parent. SIGSTOP is
+ // delivered to the children, causing them to enter signal-delivery-stop
+ // after they exit the system call which created them." - ptrace(2)
+ //
+ // clone(2)'s documentation of CLONE_UNTRACED and CLONE_PTRACE is
+ // confusingly wrong; see kernel/fork.c:_do_fork() => copy_process() =>
+ // include/linux/ptrace.h:ptrace_init_task().
+ if event || opts.InheritTracer {
+ tracer := t.Tracer()
+ if tracer != nil {
+ child.ptraceTracer.Store(tracer)
+ tracer.ptraceTracees[child] = struct{}{}
+ // "The "seized" behavior ... is inherited by children that are
+ // automatically attached using PTRACE_O_TRACEFORK,
+ // PTRACE_O_TRACEVFORK, and PTRACE_O_TRACECLONE." - ptrace(2)
+ child.ptraceSeized = t.ptraceSeized
+ // "Flags are inherited by new tracees created and "auto-attached"
+ // via active PTRACE_O_TRACEFORK, PTRACE_O_TRACEVFORK, or
+ // PTRACE_O_TRACECLONE options." - ptrace(2)
+ child.ptraceOpts = t.ptraceOpts
+ child.tg.signalHandlers.mu.Lock()
+ // "PTRACE_SEIZE: ... Automatically attached children stop with
+ // PTRACE_EVENT_STOP and WSTOPSIG(status) returns SIGTRAP instead
+ // of having SIGSTOP signal delivered to them." - ptrace(2)
+ if child.ptraceSeized {
+ child.trapStopPending = true
+ } else {
+ child.pendingSignals.enqueue(&arch.SignalInfo{
+ Signo: int32(linux.SIGSTOP),
+ }, nil)
+ }
+ // The child will self-interrupt() when its task goroutine starts
+ // running, so we don't have to.
+ child.tg.signalHandlers.mu.Unlock()
+ }
+ }
+ return event
+}
+
+// ptraceVforkDone is called after the end of a vfork stop to check if t should
+// enter PTRACE_EVENT_VFORK_DONE stop. child is the new task's thread ID in t's
+// PID namespace.
+func (t *Task) ptraceVforkDone(child ThreadID) bool {
+ if !t.hasTracer() {
+ return false
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if !t.ptraceOpts.TraceVforkDone {
+ return false
+ }
+ t.Debugf("Entering PTRACE_EVENT_VFORK_DONE stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_VFORK_DONE, uint64(child))
+ return true
+}
+
+// ptraceExec is called at the end of an execve syscall to check if t should
+// enter PTRACE_EVENT_EXEC stop. oldTID is t's thread ID, in its *tracer's* PID
+// namespace, prior to the execve. (If t did not have a tracer at the time
+// oldTID was read, oldTID may be 0. This is consistent with Linux.)
+func (t *Task) ptraceExec(oldTID ThreadID) {
+ if !t.hasTracer() {
+ return
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ // Recheck with the TaskSet mutex locked. Most ptrace points don't need to
+ // do this because detaching resets ptrace options, but PTRACE_EVENT_EXEC
+ // is special because both TraceExec and !TraceExec do something if a
+ // tracer is attached.
+ if !t.hasTracer() {
+ return
+ }
+ if t.ptraceOpts.TraceExec {
+ t.Debugf("Entering PTRACE_EVENT_EXEC stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_EXEC, uint64(oldTID))
+ return
+ }
+ // "If the PTRACE_O_TRACEEXEC option is not in effect for the execing
+ // tracee, and if the tracee was PTRACE_ATTACHed rather that [sic]
+ // PTRACE_SEIZEd, the kernel delivers an extra SIGTRAP to the tracee after
+ // execve(2) returns. This is an ordinary signal (similar to one which can
+ // be generated by `kill -TRAP`, not a special kind of ptrace-stop.
+ // Employing PTRACE_GETSIGINFO for this signal returns si_code set to 0
+ // (SI_USER). This signal may be blocked by signal mask, and thus may be
+ // delivered (much) later." - ptrace(2)
+ if t.ptraceSeized {
+ return
+ }
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.sendSignalLocked(&arch.SignalInfo{
+ Signo: int32(linux.SIGTRAP),
+ Code: arch.SignalInfoUser,
+ }, false /* group */)
+}
+
+// ptraceExit is called early in the task exit path to check if t should enter
+// PTRACE_EVENT_EXIT stop.
+func (t *Task) ptraceExit() {
+ if !t.hasTracer() {
+ return
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if !t.ptraceOpts.TraceExit {
+ return
+ }
+ t.tg.signalHandlers.mu.Lock()
+ status := t.exitStatus.Status()
+ t.tg.signalHandlers.mu.Unlock()
+ t.Debugf("Entering PTRACE_EVENT_EXIT stop")
+ t.ptraceEventLocked(linux.PTRACE_EVENT_EXIT, uint64(status))
+}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) ptraceEventLocked(event int32, msg uint64) {
+ t.ptraceEventMsg = msg
+ // """
+ // PTRACE_EVENT stops are observed by the tracer as waitpid(2) returning
+ // with WIFSTOPPED(status), and WSTOPSIG(status) returns SIGTRAP. An
+ // additional bit is set in the higher byte of the status word: the value
+ // status>>8 will be
+ //
+ // (SIGTRAP | PTRACE_EVENT_foo << 8).
+ //
+ // ...
+ //
+ // """ - ptrace(2)
+ t.ptraceTrapLocked(int32(linux.SIGTRAP) | (event << 8))
+}
+
+// ptraceKill implements ptrace(PTRACE_KILL, target). t is the caller.
+func (t *Task) ptraceKill(target *Task) error {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ if target.Tracer() != t {
+ return syserror.ESRCH
+ }
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ // "This operation is deprecated; do not use it! Instead, send a SIGKILL
+ // directly using kill(2) or tgkill(2). The problem with PTRACE_KILL is
+ // that it requires the tracee to be in signal-delivery-stop, otherwise it
+ // may not work (i.e., may complete successfully but won't kill the
+ // tracee)." - ptrace(2)
+ if target.stop == nil {
+ return nil
+ }
+ if _, ok := target.stop.(*ptraceStop); !ok {
+ return nil
+ }
+ target.ptraceCode = int32(linux.SIGKILL)
+ target.endInternalStopLocked()
+ return nil
+}
+
+func (t *Task) ptraceInterrupt(target *Task) error {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ if target.Tracer() != t {
+ return syserror.ESRCH
+ }
+ if !target.ptraceSeized {
+ return syserror.EIO
+ }
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if target.killedLocked() || target.exitState >= TaskExitInitiated {
+ return nil
+ }
+ target.trapStopPending = true
+ if s, ok := target.stop.(*ptraceStop); ok && s.listen {
+ target.endInternalStopLocked()
+ }
+ target.interrupt()
+ return nil
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing. t must have a
+// tracer.
+func (t *Task) ptraceSetOptionsLocked(opts uintptr) error {
+ const valid = uintptr(linux.PTRACE_O_EXITKILL |
+ linux.PTRACE_O_TRACESYSGOOD |
+ linux.PTRACE_O_TRACECLONE |
+ linux.PTRACE_O_TRACEEXEC |
+ linux.PTRACE_O_TRACEEXIT |
+ linux.PTRACE_O_TRACEFORK |
+ linux.PTRACE_O_TRACESECCOMP |
+ linux.PTRACE_O_TRACEVFORK |
+ linux.PTRACE_O_TRACEVFORKDONE)
+ if opts&^valid != 0 {
+ return syserror.EINVAL
+ }
+ t.ptraceOpts = ptraceOptions{
+ ExitKill: opts&linux.PTRACE_O_EXITKILL != 0,
+ SysGood: opts&linux.PTRACE_O_TRACESYSGOOD != 0,
+ TraceClone: opts&linux.PTRACE_O_TRACECLONE != 0,
+ TraceExec: opts&linux.PTRACE_O_TRACEEXEC != 0,
+ TraceExit: opts&linux.PTRACE_O_TRACEEXIT != 0,
+ TraceFork: opts&linux.PTRACE_O_TRACEFORK != 0,
+ TraceSeccomp: opts&linux.PTRACE_O_TRACESECCOMP != 0,
+ TraceVfork: opts&linux.PTRACE_O_TRACEVFORK != 0,
+ TraceVforkDone: opts&linux.PTRACE_O_TRACEVFORKDONE != 0,
+ }
+ return nil
+}
+
+// Ptrace implements the ptrace system call.
+func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
+ // PTRACE_TRACEME ignores all other arguments.
+ if req == linux.PTRACE_TRACEME {
+ return t.ptraceTraceme()
+ }
+ // All other ptrace requests operate on a current or future tracee
+ // specified by pid.
+ target := t.tg.pidns.TaskWithID(pid)
+ if target == nil {
+ return syserror.ESRCH
+ }
+
+ // PTRACE_ATTACH and PTRACE_SEIZE do not require that target is not already
+ // a tracee.
+ if req == linux.PTRACE_ATTACH || req == linux.PTRACE_SEIZE {
+ seize := req == linux.PTRACE_SEIZE
+ if seize && addr != 0 {
+ return syserror.EIO
+ }
+ return t.ptraceAttach(target, seize, uintptr(data))
+ }
+ // PTRACE_KILL and PTRACE_INTERRUPT require that the target is a tracee,
+ // but does not require that it is ptrace-stopped.
+ if req == linux.PTRACE_KILL {
+ return t.ptraceKill(target)
+ }
+ if req == linux.PTRACE_INTERRUPT {
+ return t.ptraceInterrupt(target)
+ }
+ // All other ptrace requests require that the target is a ptrace-stopped
+ // tracee, and freeze the ptrace-stop so the tracee can be operated on.
+ t.tg.pidns.owner.mu.RLock()
+ if target.Tracer() != t {
+ t.tg.pidns.owner.mu.RUnlock()
+ return syserror.ESRCH
+ }
+ if !target.ptraceFreeze() {
+ t.tg.pidns.owner.mu.RUnlock()
+ // "Most ptrace commands (all except PTRACE_ATTACH, PTRACE_SEIZE,
+ // PTRACE_TRACEME, PTRACE_INTERRUPT, and PTRACE_KILL) require the
+ // tracee to be in a ptrace-stop, otherwise they fail with ESRCH." -
+ // ptrace(2)
+ return syserror.ESRCH
+ }
+ t.tg.pidns.owner.mu.RUnlock()
+ // Even if the target has a ptrace-stop active, the tracee's task goroutine
+ // may not yet have reached Task.doStop; wait for it to do so. This is safe
+ // because there's no way for target to initiate a ptrace-stop and then
+ // block (by calling Task.block) before entering it.
+ //
+ // Caveat: If tasks were just restored, the tracee's first call to
+ // Task.Activate (in Task.run) occurs before its first call to Task.doStop,
+ // which may block if the tracer's address space is active.
+ t.UninterruptibleSleepStart(true)
+ target.waitGoroutineStoppedOrExited()
+ t.UninterruptibleSleepFinish(true)
+
+ // Resuming commands end the ptrace stop, but only if successful.
+ // PTRACE_LISTEN ends the ptrace stop if trapNotifyPending is already set on the
+ // target.
+ switch req {
+ case linux.PTRACE_DETACH:
+ if err := t.ptraceDetach(target, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_CONT:
+ if err := target.ptraceUnstop(ptraceSyscallNone, false, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_SYSCALL:
+ if err := target.ptraceUnstop(ptraceSyscallIntercept, false, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_SINGLESTEP:
+ if err := target.ptraceUnstop(ptraceSyscallNone, true, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_SYSEMU:
+ if err := target.ptraceUnstop(ptraceSyscallEmu, false, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_SYSEMU_SINGLESTEP:
+ if err := target.ptraceUnstop(ptraceSyscallEmu, true, linux.Signal(data)); err != nil {
+ target.ptraceUnfreeze()
+ return err
+ }
+ return nil
+
+ case linux.PTRACE_LISTEN:
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if !target.ptraceSeized {
+ return syserror.EIO
+ }
+ if target.ptraceSiginfo == nil {
+ return syserror.EIO
+ }
+ if target.ptraceSiginfo.Code>>8 != linux.PTRACE_EVENT_STOP {
+ return syserror.EIO
+ }
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if target.trapNotifyPending {
+ target.endInternalStopLocked()
+ } else {
+ target.stop.(*ptraceStop).listen = true
+ target.ptraceUnfreezeLocked()
+ }
+ return nil
+ }
+
+ // All other ptrace requests expect us to unfreeze the stop.
+ defer target.ptraceUnfreeze()
+
+ switch req {
+ case linux.PTRACE_PEEKTEXT, linux.PTRACE_PEEKDATA:
+ // "At the system call level, the PTRACE_PEEKTEXT, PTRACE_PEEKDATA, and
+ // PTRACE_PEEKUSER requests have a different API: they store the result
+ // at the address specified by the data parameter, and the return value
+ // is the error flag." - ptrace(2)
+ word := t.Arch().Native(0)
+ if _, err := usermem.CopyObjectIn(t, target.MemoryManager(), addr, word, usermem.IOOpts{
+ IgnorePermissions: true,
+ }); err != nil {
+ return err
+ }
+ _, err := t.CopyOut(data, word)
+ return err
+
+ case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA:
+ _, err := usermem.CopyObjectOut(t, target.MemoryManager(), addr, t.Arch().Native(uintptr(data)), usermem.IOOpts{
+ IgnorePermissions: true,
+ })
+ return err
+
+ case linux.PTRACE_GETREGSET:
+ // "Read the tracee's registers. addr specifies, in an
+ // architecture-dependent way, the type of registers to be read. ...
+ // data points to a struct iovec, which describes the destination
+ // buffer's location and length. On return, the kernel modifies iov.len
+ // to indicate the actual number of bytes returned." - ptrace(2)
+ ars, err := t.CopyInIovecs(data, 1)
+ if err != nil {
+ return err
+ }
+ ar := ars.Head()
+ n, err := target.Arch().PtraceGetRegSet(uintptr(addr), &usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: ar.Start,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ }, int(ar.Length()))
+ if err != nil {
+ return err
+ }
+
+ // Update iovecs to represent the range of the written register set.
+ end, ok := ar.Start.AddLength(uint64(n))
+ if !ok {
+ panic(fmt.Sprintf("%#x + %#x overflows. Invalid reg size > %#x", ar.Start, n, ar.Length()))
+ }
+ ar.End = end
+ return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar))
+
+ case linux.PTRACE_SETREGSET:
+ ars, err := t.CopyInIovecs(data, 1)
+ if err != nil {
+ return err
+ }
+ ar := ars.Head()
+ n, err := target.Arch().PtraceSetRegSet(uintptr(addr), &usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: ar.Start,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ }, int(ar.Length()))
+ if err != nil {
+ return err
+ }
+ ar.End -= usermem.Addr(n)
+ return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar))
+
+ case linux.PTRACE_GETSIGINFO:
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if target.ptraceSiginfo == nil {
+ return syserror.EINVAL
+ }
+ _, err := t.CopyOut(data, target.ptraceSiginfo)
+ return err
+
+ case linux.PTRACE_SETSIGINFO:
+ var info arch.SignalInfo
+ if _, err := t.CopyIn(data, &info); err != nil {
+ return err
+ }
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if target.ptraceSiginfo == nil {
+ return syserror.EINVAL
+ }
+ target.ptraceSiginfo = &info
+ return nil
+
+ case linux.PTRACE_GETSIGMASK:
+ if addr != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ _, err := t.CopyOut(data, target.SignalMask())
+ return err
+
+ case linux.PTRACE_SETSIGMASK:
+ if addr != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ var mask linux.SignalSet
+ if _, err := t.CopyIn(data, &mask); err != nil {
+ return err
+ }
+ // The target's task goroutine is stopped, so this is safe:
+ target.SetSignalMask(mask &^ UnblockableSignals)
+ return nil
+
+ case linux.PTRACE_SETOPTIONS:
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ return target.ptraceSetOptionsLocked(uintptr(data))
+
+ case linux.PTRACE_GETEVENTMSG:
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ _, err := t.CopyOut(usermem.Addr(data), target.ptraceEventMsg)
+ return err
+
+ // PEEKSIGINFO is unimplemented but seems to have no users anywhere.
+
+ default:
+ return t.ptraceArch(target, req, addr, data)
+ }
+}
diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go
new file mode 100644
index 000000000..048eeaa3f
--- /dev/null
+++ b/pkg/sentry/kernel/ptrace_amd64.go
@@ -0,0 +1,89 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// ptraceArch implements arch-specific ptrace commands.
+func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) error {
+ switch req {
+ case linux.PTRACE_PEEKUSR: // aka PTRACE_PEEKUSER
+ n, err := target.Arch().PtracePeekUser(uintptr(addr))
+ if err != nil {
+ return err
+ }
+ _, err = t.CopyOut(data, n)
+ return err
+
+ case linux.PTRACE_POKEUSR: // aka PTRACE_POKEUSER
+ return target.Arch().PtracePokeUser(uintptr(addr), uintptr(data))
+
+ case linux.PTRACE_GETREGS:
+ // "Copy the tracee's general-purpose ... registers ... to the address
+ // data in the tracer. ... (addr is ignored.) Note that SPARC systems
+ // have the meaning of data and addr reversed ..."
+ _, err := target.Arch().PtraceGetRegs(&usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: data,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ })
+ return err
+
+ case linux.PTRACE_GETFPREGS:
+ _, err := target.Arch().PtraceGetFPRegs(&usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: data,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ })
+ return err
+
+ case linux.PTRACE_SETREGS:
+ _, err := target.Arch().PtraceSetRegs(&usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: data,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ })
+ return err
+
+ case linux.PTRACE_SETFPREGS:
+ _, err := target.Arch().PtraceSetFPRegs(&usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: data,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ })
+ return err
+
+ default:
+ return syserror.EIO
+ }
+}
diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go
new file mode 100644
index 000000000..4899c813f
--- /dev/null
+++ b/pkg/sentry/kernel/ptrace_arm64.go
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// ptraceArch implements arch-specific ptrace commands.
+func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) error {
+ return syserror.EIO
+}
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
new file mode 100644
index 000000000..c4fb2c56c
--- /dev/null
+++ b/pkg/sentry/kernel/rseq.go
@@ -0,0 +1,120 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/hostcpu"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Restartable sequences, as described in https://lwn.net/Articles/650333/.
+
+// RSEQCriticalRegion describes a restartable sequence critical region.
+//
+// +stateify savable
+type RSEQCriticalRegion struct {
+ // When a task in this thread group has its CPU preempted (as defined by
+ // platform.ErrContextCPUPreempted) or has a signal delivered to an
+ // application handler while its instruction pointer is in CriticalSection,
+ // set the instruction pointer to Restart and application register r10 (on
+ // amd64) to the former instruction pointer.
+ CriticalSection usermem.AddrRange
+ Restart usermem.Addr
+}
+
+// RSEQAvailable returns true if t supports restartable sequences.
+func (t *Task) RSEQAvailable() bool {
+ return t.k.useHostCores && t.k.Platform.DetectsCPUPreemption()
+}
+
+// RSEQCriticalRegion returns a copy of t's thread group's current restartable
+// sequence.
+func (t *Task) RSEQCriticalRegion() RSEQCriticalRegion {
+ return *t.tg.rscr.Load().(*RSEQCriticalRegion)
+}
+
+// SetRSEQCriticalRegion replaces t's thread group's restartable sequence.
+//
+// Preconditions: t.RSEQAvailable() == true.
+func (t *Task) SetRSEQCriticalRegion(rscr RSEQCriticalRegion) error {
+ // These checks are somewhat more lenient than in Linux, which (bizarrely)
+ // requires rscr.CriticalSection to be non-empty and rscr.Restart to be
+ // outside of rscr.CriticalSection, even if rscr.CriticalSection.Start == 0
+ // (which disables the critical region).
+ if rscr.CriticalSection.Start == 0 {
+ rscr.CriticalSection.End = 0
+ rscr.Restart = 0
+ t.tg.rscr.Store(&rscr)
+ return nil
+ }
+ if rscr.CriticalSection.Start >= rscr.CriticalSection.End {
+ return syserror.EINVAL
+ }
+ if rscr.CriticalSection.Contains(rscr.Restart) {
+ return syserror.EINVAL
+ }
+ // TODO(jamieliu): check that rscr.CriticalSection and rscr.Restart are in
+ // the application address range, for consistency with Linux
+ t.tg.rscr.Store(&rscr)
+ return nil
+}
+
+// RSEQCPUAddr returns the address that RSEQ will keep updated with t's CPU
+// number.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) RSEQCPUAddr() usermem.Addr {
+ return t.rseqCPUAddr
+}
+
+// SetRSEQCPUAddr replaces the address that RSEQ will keep updated with t's CPU
+// number.
+//
+// Preconditions: t.RSEQAvailable() == true. The caller must be running on the
+// task goroutine. t's AddressSpace must be active.
+func (t *Task) SetRSEQCPUAddr(addr usermem.Addr) error {
+ t.rseqCPUAddr = addr
+ if addr != 0 {
+ t.rseqCPU = int32(hostcpu.GetCPU())
+ if err := t.rseqCopyOutCPU(); err != nil {
+ t.rseqCPUAddr = 0
+ t.rseqCPU = -1
+ return syserror.EINVAL // yes, EINVAL, not err or EFAULT
+ }
+ } else {
+ t.rseqCPU = -1
+ }
+ return nil
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqCopyOutCPU() error {
+ buf := t.CopyScratchBuffer(4)
+ usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU))
+ _, err := t.CopyOutBytes(t.rseqCPUAddr, buf)
+ return err
+}
+
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) rseqInterrupt() {
+ rscr := t.tg.rscr.Load().(*RSEQCriticalRegion)
+ if ip := t.Arch().IP(); rscr.CriticalSection.Contains(usermem.Addr(ip)) {
+ t.Debugf("Interrupted RSEQ critical section at %#x; restarting at %#x", ip, rscr.Restart)
+ t.Arch().SetIP(uintptr(rscr.Restart))
+ t.Arch().SetRSEQInterruptedIP(ip)
+ }
+}
diff --git a/pkg/sentry/kernel/sched/cpuset.go b/pkg/sentry/kernel/sched/cpuset.go
new file mode 100644
index 000000000..c6c436690
--- /dev/null
+++ b/pkg/sentry/kernel/sched/cpuset.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sched
+
+import "math/bits"
+
+const (
+ bitsPerByte = 8
+ bytesPerLong = 8 // only for 64-bit architectures
+)
+
+// CPUSet contains a bitmap to record CPU information.
+//
+// Note that this definition is only correct for little-endian architectures,
+// since Linux's cpumask_t uses unsigned long.
+type CPUSet []byte
+
+// CPUSetSize returns the size in bytes of a CPUSet that can contain num cpus.
+func CPUSetSize(num uint) uint {
+ // NOTE(b/68859821): Applications may expect that the size of a CPUSet in
+ // bytes is always a multiple of sizeof(unsigned long), since this is true
+ // in Linux. Thus we always round up.
+ bytes := (num + bitsPerByte - 1) / bitsPerByte
+ longs := (bytes + bytesPerLong - 1) / bytesPerLong
+ return longs * bytesPerLong
+}
+
+// NewCPUSet returns a CPUSet for the given number of CPUs which initially
+// contains no CPUs.
+func NewCPUSet(num uint) CPUSet {
+ return CPUSet(make([]byte, CPUSetSize(num)))
+}
+
+// NewFullCPUSet returns a CPUSet for the given number of CPUs, all of which
+// are present in the set.
+func NewFullCPUSet(num uint) CPUSet {
+ c := NewCPUSet(num)
+ var i uint
+ for ; i < num/bitsPerByte; i++ {
+ c[i] = 0xff
+ }
+ if rem := num % bitsPerByte; rem != 0 {
+ c[i] = (1 << rem) - 1
+ }
+ return c
+}
+
+// Size returns the size of 'c' in bytes.
+func (c CPUSet) Size() uint {
+ return uint(len(c))
+}
+
+// NumCPUs returns how many cpus are set in the CPUSet.
+func (c CPUSet) NumCPUs() uint {
+ var n int
+ for _, b := range c {
+ n += bits.OnesCount8(b)
+ }
+ return uint(n)
+}
+
+// Copy returns a copy of the CPUSet.
+func (c CPUSet) Copy() CPUSet {
+ return append(CPUSet(nil), c...)
+}
+
+// Set sets the bit corresponding to cpu.
+func (c *CPUSet) Set(cpu uint) {
+ (*c)[cpu/bitsPerByte] |= 1 << (cpu % bitsPerByte)
+}
+
+// ClearAbove clears bits corresponding to cpu and all higher cpus.
+func (c *CPUSet) ClearAbove(cpu uint) {
+ i := cpu / bitsPerByte
+ if i >= c.Size() {
+ return
+ }
+ (*c)[i] &^= 0xff << (cpu % bitsPerByte)
+ for i++; i < c.Size(); i++ {
+ (*c)[i] = 0
+ }
+}
+
+// ForEachCPU iterates over the CPUSet and calls fn with the cpu index if
+// it's set.
+func (c CPUSet) ForEachCPU(fn func(uint)) {
+ for i := uint(0); i < c.Size()*bitsPerByte; i++ {
+ bit := uint(1) << (i & (bitsPerByte - 1))
+ if uint(c[i/bitsPerByte])&bit == bit {
+ fn(i)
+ }
+ }
+}
diff --git a/pkg/sentry/kernel/sched/sched.go b/pkg/sentry/kernel/sched/sched.go
new file mode 100644
index 000000000..de18c9d02
--- /dev/null
+++ b/pkg/sentry/kernel/sched/sched.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sched implements scheduler related features.
+package sched
diff --git a/pkg/sentry/kernel/sched/sched_state_autogen.go b/pkg/sentry/kernel/sched/sched_state_autogen.go
new file mode 100755
index 000000000..2a482732e
--- /dev/null
+++ b/pkg/sentry/kernel/sched/sched_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package sched
+
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
new file mode 100644
index 000000000..cc75eb08a
--- /dev/null
+++ b/pkg/sentry/kernel/seccomp.go
@@ -0,0 +1,217 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/bpf"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const maxSyscallFilterInstructions = 1 << 15
+
+// seccompData is equivalent to struct seccomp_data, which contains the data
+// passed to seccomp-bpf filters.
+type seccompData struct {
+ // nr is the system call number.
+ nr int32
+
+ // arch is an AUDIT_ARCH_* value indicating the system call convention.
+ arch uint32
+
+ // instructionPointer is the value of the instruction pointer at the time
+ // of the system call.
+ instructionPointer uint64
+
+ // args contains the first 6 system call arguments.
+ args [6]uint64
+}
+
+func (d *seccompData) asBPFInput() bpf.Input {
+ return bpf.InputBytes{binary.Marshal(nil, usermem.ByteOrder, d), usermem.ByteOrder}
+}
+
+func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo {
+ si := &arch.SignalInfo{
+ Signo: int32(linux.SIGSYS),
+ Errno: errno,
+ Code: arch.SYS_SECCOMP,
+ }
+ si.SetCallAddr(uint64(ip))
+ si.SetSyscall(sysno)
+ si.SetArch(t.SyscallTable().AuditNumber)
+ return si
+}
+
+// checkSeccompSyscall applies the task's seccomp filters before the execution
+// of syscall sysno at instruction pointer ip. (These parameters must be passed
+// in because vsyscalls do not use the values in t.Arch().)
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip usermem.Addr) linux.BPFAction {
+ result := linux.BPFAction(t.evaluateSyscallFilters(sysno, args, ip))
+ action := result & linux.SECCOMP_RET_ACTION
+ switch action {
+ case linux.SECCOMP_RET_TRAP:
+ // "Results in the kernel sending a SIGSYS signal to the triggering
+ // task without executing the system call. ... The SECCOMP_RET_DATA
+ // portion of the return value will be passed as si_errno." -
+ // Documentation/prctl/seccomp_filter.txt
+ t.SendSignal(seccompSiginfo(t, int32(result.Data()), sysno, ip))
+ // "The return value register will contain an arch-dependent value." In
+ // practice, it's ~always the syscall number.
+ t.Arch().SetReturn(uintptr(sysno))
+
+ case linux.SECCOMP_RET_ERRNO:
+ // "Results in the lower 16-bits of the return value being passed to
+ // userland as the errno without executing the system call."
+ t.Arch().SetReturn(-uintptr(result.Data()))
+
+ case linux.SECCOMP_RET_TRACE:
+ // "When returned, this value will cause the kernel to attempt to
+ // notify a ptrace()-based tracer prior to executing the system call.
+ // If there is no tracer present, -ENOSYS is returned to userland and
+ // the system call is not executed."
+ if !t.ptraceSeccomp(result.Data()) {
+ // This useless-looking temporary is needed because Go.
+ tmp := uintptr(syscall.ENOSYS)
+ t.Arch().SetReturn(-tmp)
+ return linux.SECCOMP_RET_ERRNO
+ }
+
+ case linux.SECCOMP_RET_ALLOW:
+ // "Results in the system call being executed."
+
+ case linux.SECCOMP_RET_KILL_THREAD:
+ // "Results in the task exiting immediately without executing the
+ // system call. The exit status of the task will be SIGSYS, not
+ // SIGKILL."
+
+ default:
+ // consistent with Linux
+ return linux.SECCOMP_RET_KILL_THREAD
+ }
+ return action
+}
+
+func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 {
+ data := seccompData{
+ nr: sysno,
+ arch: t.tc.st.AuditNumber,
+ instructionPointer: uint64(ip),
+ }
+ // data.args is []uint64 and args is []arch.SyscallArgument (uintptr), so
+ // we can't do any slicing tricks or even use copy/append here.
+ for i, arg := range args {
+ if i >= len(data.args) {
+ break
+ }
+ data.args[i] = arg.Uint64()
+ }
+ input := data.asBPFInput()
+
+ ret := uint32(linux.SECCOMP_RET_ALLOW)
+ f := t.syscallFilters.Load()
+ if f == nil {
+ return ret
+ }
+
+ // "Every filter successfully installed will be evaluated (in reverse
+ // order) for each system call the task makes." - kernel/seccomp.c
+ for i := len(f.([]bpf.Program)) - 1; i >= 0; i-- {
+ thisRet, err := bpf.Exec(f.([]bpf.Program)[i], input)
+ if err != nil {
+ t.Debugf("seccomp-bpf filter %d returned error: %v", i, err)
+ thisRet = uint32(linux.SECCOMP_RET_KILL_THREAD)
+ }
+ // "If multiple filters exist, the return value for the evaluation of a
+ // given system call will always use the highest precedent value." -
+ // Documentation/prctl/seccomp_filter.txt
+ //
+ // (Note that this contradicts prctl(2): "If the filters permit prctl()
+ // calls, then additional filters can be added; they are run in order
+ // until the first non-allow result is seen." prctl(2) is incorrect.)
+ //
+ // "The ordering ensures that a min_t() over composed return values
+ // always selects the least permissive choice." -
+ // include/uapi/linux/seccomp.h
+ if (thisRet & linux.SECCOMP_RET_ACTION) < (ret & linux.SECCOMP_RET_ACTION) {
+ ret = thisRet
+ }
+ }
+
+ return ret
+}
+
+// AppendSyscallFilter adds BPF program p as a system call filter.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) AppendSyscallFilter(p bpf.Program, syncAll bool) error {
+ // While syscallFilters are an atomic.Value we must take the mutex to prevent
+ // our read-copy-update from happening while another task is syncing syscall
+ // filters to us, this keeps the filters in a consistent state.
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+
+ // Cap the combined length of all syscall filters (plus a penalty of 4
+ // instructions per filter beyond the first) to maxSyscallFilterInstructions.
+ // This restriction is inherited from Linux.
+ totalLength := p.Length()
+ var newFilters []bpf.Program
+
+ if sf := t.syscallFilters.Load(); sf != nil {
+ oldFilters := sf.([]bpf.Program)
+ for _, f := range oldFilters {
+ totalLength += f.Length() + 4
+ }
+ newFilters = append(newFilters, oldFilters...)
+ }
+
+ if totalLength > maxSyscallFilterInstructions {
+ return syserror.ENOMEM
+ }
+
+ newFilters = append(newFilters, p)
+ t.syscallFilters.Store(newFilters)
+
+ if syncAll {
+ // Note: No new privs is always assumed to be set.
+ for ot := t.tg.tasks.Front(); ot != nil; ot = ot.Next() {
+ if ot != t {
+ var copiedFilters []bpf.Program
+ copiedFilters = append(copiedFilters, newFilters...)
+ ot.syscallFilters.Store(copiedFilters)
+ }
+ }
+ }
+
+ return nil
+}
+
+// SeccompMode returns a SECCOMP_MODE_* constant indicating the task's current
+// seccomp syscall filtering mode, appropriate for both prctl(PR_GET_SECCOMP)
+// and /proc/[pid]/status.
+func (t *Task) SeccompMode() int {
+ f := t.syscallFilters.Load()
+ if f != nil && len(f.([]bpf.Program)) > 0 {
+ return linux.SECCOMP_MODE_FILTER
+ }
+ return linux.SECCOMP_MODE_NONE
+}
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
new file mode 100644
index 000000000..9d0620e02
--- /dev/null
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -0,0 +1,571 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package semaphore implements System V semaphores.
+package semaphore
+
+import (
+ "fmt"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const (
+ valueMax = 32767 // SEMVMX
+
+ // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL).
+ semaphoresMax = 32000
+
+ // setMax is "system-wide limit on the number of semaphore sets" (SEMMNI).
+ setsMax = 32000
+
+ // semaphoresTotalMax is "system-wide limit on the number of semaphores"
+ // (SEMMNS = SEMMNI*SEMMSL).
+ semaphoresTotalMax = 1024000000
+)
+
+// Registry maintains a set of semaphores that can be found by key or ID.
+//
+// +stateify savable
+type Registry struct {
+ // userNS owning the ipc name this registry belongs to. Immutable.
+ userNS *auth.UserNamespace
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+ semaphores map[int32]*Set
+ lastIDUsed int32
+}
+
+// Set represents a set of semaphores that can be operated atomically.
+//
+// +stateify savable
+type Set struct {
+ // registry owning this sem set. Immutable.
+ registry *Registry
+
+ // Id is a handle that identifies the set.
+ ID int32
+
+ // key is an user provided key that can be shared between processes.
+ key int32
+
+ // creator is the user that created the set. Immutable.
+ creator fs.FileOwner
+
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+ owner fs.FileOwner
+ perms fs.FilePermissions
+ opTime ktime.Time
+ changeTime ktime.Time
+
+ // sems holds all semaphores in the set. The slice itself is immutable after
+ // it's been set, however each 'sem' object in the slice requires 'mu' lock.
+ sems []sem
+
+ // dead is set to true when the set is removed and can't be reached anymore.
+ // All waiters must wake up and fail when set is dead.
+ dead bool
+}
+
+// sem represents a single semanphore from a set.
+//
+// +stateify savable
+type sem struct {
+ value int16
+ waiters waiterList `state:"zerovalue"`
+ pid int32
+}
+
+// waiter represents a caller that is waiting for the semaphore value to
+// become positive or zero.
+//
+// +stateify savable
+type waiter struct {
+ waiterEntry
+
+ // value represents how much resource the waiter needs to wake up.
+ value int16
+ ch chan struct{}
+}
+
+// NewRegistry creates a new semaphore set registry.
+func NewRegistry(userNS *auth.UserNamespace) *Registry {
+ return &Registry{
+ userNS: userNS,
+ semaphores: make(map[int32]*Set),
+ }
+}
+
+// FindOrCreate searches for a semaphore set that matches 'key'. If not found,
+// it may create a new one if requested. If private is true, key is ignored and
+// a new set is always created. If create is false, it fails if a set cannot
+// be found. If exclusive is true, it fails if a set with the same key already
+// exists.
+func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) {
+ if nsems < 0 || nsems > semaphoresMax {
+ return nil, syserror.EINVAL
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if !private {
+ // Look up an existing semaphore.
+ if set := r.findByKey(key); set != nil {
+ set.mu.Lock()
+ defer set.mu.Unlock()
+
+ // Check that caller can access semaphore set.
+ creds := auth.CredentialsFromContext(ctx)
+ if !set.checkPerms(creds, fs.PermsFromMode(mode)) {
+ return nil, syserror.EACCES
+ }
+
+ // Validate parameters.
+ if nsems > int32(set.Size()) {
+ return nil, syserror.EINVAL
+ }
+ if create && exclusive {
+ return nil, syserror.EEXIST
+ }
+ return set, nil
+ }
+
+ if !create {
+ // Semaphore not found and should not be created.
+ return nil, syserror.ENOENT
+ }
+ }
+
+ // Zero is only valid if an existing set is found.
+ if nsems == 0 {
+ return nil, syserror.EINVAL
+ }
+
+ // Apply system limits.
+ if len(r.semaphores) >= setsMax {
+ return nil, syserror.EINVAL
+ }
+ if r.totalSems() > int(semaphoresTotalMax-nsems) {
+ return nil, syserror.EINVAL
+ }
+
+ // Finally create a new set.
+ owner := fs.FileOwnerFromContext(ctx)
+ perms := fs.FilePermsFromMode(mode)
+ return r.newSet(ctx, key, owner, owner, perms, nsems)
+}
+
+// RemoveID removes set with give 'id' from the registry and marks the set as
+// dead. All waiters will be awakened and fail.
+func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ set := r.semaphores[id]
+ if set == nil {
+ return syserror.EINVAL
+ }
+
+ set.mu.Lock()
+ defer set.mu.Unlock()
+
+ // "The effective user ID of the calling process must match the creator or
+ // owner of the semaphore set, or the caller must be privileged."
+ if !set.checkCredentials(creds) && !set.checkCapability(creds) {
+ return syserror.EACCES
+ }
+
+ delete(r.semaphores, set.ID)
+ set.destroy()
+ return nil
+}
+
+func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.FileOwner, perms fs.FilePermissions, nsems int32) (*Set, error) {
+ set := &Set{
+ registry: r,
+ key: key,
+ owner: owner,
+ creator: owner,
+ perms: perms,
+ changeTime: ktime.NowFromContext(ctx),
+ sems: make([]sem, nsems),
+ }
+
+ // Find the next available ID.
+ for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ {
+ // Handle wrap around.
+ if id < 0 {
+ id = 0
+ continue
+ }
+ if r.semaphores[id] == nil {
+ r.lastIDUsed = id
+ r.semaphores[id] = set
+ set.ID = id
+ return set, nil
+ }
+ }
+
+ log.Warningf("Semaphore map is full, they must be leaking")
+ return nil, syserror.ENOMEM
+}
+
+// FindByID looks up a set given an ID.
+func (r *Registry) FindByID(id int32) *Set {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return r.semaphores[id]
+}
+
+func (r *Registry) findByKey(key int32) *Set {
+ for _, v := range r.semaphores {
+ if v.key == key {
+ return v
+ }
+ }
+ return nil
+}
+
+func (r *Registry) totalSems() int {
+ totalSems := 0
+ for _, v := range r.semaphores {
+ totalSems += v.Size()
+ }
+ return totalSems
+}
+
+func (s *Set) findSem(num int32) *sem {
+ if num < 0 || int(num) >= s.Size() {
+ return nil
+ }
+ return &s.sems[num]
+}
+
+// Size returns the number of semaphores in the set. Size is immutable.
+func (s *Set) Size() int {
+ return len(s.sems)
+}
+
+// Change changes some fields from the set atomically.
+func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.FileOwner, perms fs.FilePermissions) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The effective UID of the calling process must match the owner or creator
+ // of the semaphore set, or the caller must be privileged."
+ if !s.checkCredentials(creds) && !s.checkCapability(creds) {
+ return syserror.EACCES
+ }
+
+ s.owner = owner
+ s.perms = perms
+ s.changeTime = ktime.NowFromContext(ctx)
+ return nil
+}
+
+// SetVal overrides a semaphore value, waking up waiters as needed.
+func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error {
+ if val < 0 || val > valueMax {
+ return syserror.ERANGE
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have alter permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Write: true}) {
+ return syserror.EACCES
+ }
+
+ sem := s.findSem(num)
+ if sem == nil {
+ return syserror.ERANGE
+ }
+
+ // TODO(b/29354920): Clear undo entries in all processes
+ sem.value = val
+ sem.pid = pid
+ s.changeTime = ktime.NowFromContext(ctx)
+ sem.wakeWaiters()
+ return nil
+}
+
+// SetValAll overrides all semaphores values, waking up waiters as needed. It also
+// sets semaphore's PID which was fixed in Linux 4.6.
+//
+// 'len(vals)' must be equal to 's.Size()'.
+func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credentials, pid int32) error {
+ if len(vals) != s.Size() {
+ panic(fmt.Sprintf("vals length (%d) different that Set.Size() (%d)", len(vals), s.Size()))
+ }
+
+ for _, val := range vals {
+ if val < 0 || val > valueMax {
+ return syserror.ERANGE
+ }
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have alter permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Write: true}) {
+ return syserror.EACCES
+ }
+
+ for i, val := range vals {
+ sem := &s.sems[i]
+
+ // TODO(b/29354920): Clear undo entries in all processes
+ sem.value = int16(val)
+ sem.pid = pid
+ sem.wakeWaiters()
+ }
+ s.changeTime = ktime.NowFromContext(ctx)
+ return nil
+}
+
+// GetVal returns a semaphore value.
+func (s *Set) GetVal(num int32, creds *auth.Credentials) (int16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return 0, syserror.EACCES
+ }
+
+ sem := s.findSem(num)
+ if sem == nil {
+ return 0, syserror.ERANGE
+ }
+ return sem.value, nil
+}
+
+// GetValAll returns value for all semaphores.
+func (s *Set) GetValAll(creds *auth.Credentials) ([]uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return nil, syserror.EACCES
+ }
+
+ vals := make([]uint16, s.Size())
+ for i, sem := range s.sems {
+ vals[i] = uint16(sem.value)
+ }
+ return vals, nil
+}
+
+// GetPID returns the PID set when performing operations in the semaphore.
+func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return 0, syserror.EACCES
+ }
+
+ sem := s.findSem(num)
+ if sem == nil {
+ return 0, syserror.ERANGE
+ }
+ return sem.pid, nil
+}
+
+// ExecuteOps attempts to execute a list of operations to the set. It only
+// succeeds when all operations can be applied. No changes are made if it fails.
+//
+// On failure, it may return an error (retries are hopeless) or it may return
+// a channel that can be waited on before attempting again.
+func (s *Set) ExecuteOps(ctx context.Context, ops []linux.Sembuf, creds *auth.Credentials, pid int32) (chan struct{}, int32, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Did it race with a removal operation?
+ if s.dead {
+ return nil, 0, syserror.EIDRM
+ }
+
+ // Validate the operations.
+ readOnly := true
+ for _, op := range ops {
+ if s.findSem(int32(op.SemNum)) == nil {
+ return nil, 0, syserror.EFBIG
+ }
+ if op.SemOp != 0 {
+ readOnly = false
+ }
+ }
+
+ if !s.checkPerms(creds, fs.PermMask{Read: readOnly, Write: !readOnly}) {
+ return nil, 0, syserror.EACCES
+ }
+
+ ch, num, err := s.executeOps(ctx, ops, pid)
+ if err != nil {
+ return nil, 0, err
+ }
+ return ch, num, nil
+}
+
+func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (chan struct{}, int32, error) {
+ // Changes to semaphores go to this slice temporarily until they all succeed.
+ tmpVals := make([]int16, len(s.sems))
+ for i := range s.sems {
+ tmpVals[i] = s.sems[i].value
+ }
+
+ for _, op := range ops {
+ sem := &s.sems[op.SemNum]
+ if op.SemOp == 0 {
+ // Handle 'wait for zero' operation.
+ if tmpVals[op.SemNum] != 0 {
+ // Semaphore isn't 0, must wait.
+ if op.SemFlg&linux.IPC_NOWAIT != 0 {
+ return nil, 0, syserror.ErrWouldBlock
+ }
+
+ w := newWaiter(op.SemOp)
+ sem.waiters.PushBack(w)
+ return w.ch, int32(op.SemNum), nil
+ }
+ } else {
+ if op.SemOp < 0 {
+ // Handle 'wait' operation.
+ if -op.SemOp > valueMax {
+ return nil, 0, syserror.ERANGE
+ }
+ if -op.SemOp > tmpVals[op.SemNum] {
+ // Not enough resources, must wait.
+ if op.SemFlg&linux.IPC_NOWAIT != 0 {
+ return nil, 0, syserror.ErrWouldBlock
+ }
+
+ w := newWaiter(op.SemOp)
+ sem.waiters.PushBack(w)
+ return w.ch, int32(op.SemNum), nil
+ }
+ } else {
+ // op.SemOp > 0: Handle 'signal' operation.
+ if tmpVals[op.SemNum] > valueMax-op.SemOp {
+ return nil, 0, syserror.ERANGE
+ }
+ }
+
+ tmpVals[op.SemNum] += op.SemOp
+ }
+ }
+
+ // All operations succeeded, apply them.
+ // TODO(b/29354920): handle undo operations.
+ for i, v := range tmpVals {
+ s.sems[i].value = v
+ s.sems[i].wakeWaiters()
+ s.sems[i].pid = pid
+ }
+ s.opTime = ktime.NowFromContext(ctx)
+ return nil, 0, nil
+}
+
+// AbortWait notifies that a waiter is giving up and will not wait on the
+// channel anymore.
+func (s *Set) AbortWait(num int32, ch chan struct{}) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ sem := &s.sems[num]
+ for w := sem.waiters.Front(); w != nil; w = w.Next() {
+ if w.ch == ch {
+ sem.waiters.Remove(w)
+ return
+ }
+ }
+ // Waiter may not be found in case it raced with wakeWaiters().
+}
+
+func (s *Set) checkCredentials(creds *auth.Credentials) bool {
+ return s.owner.UID == creds.EffectiveKUID ||
+ s.owner.GID == creds.EffectiveKGID ||
+ s.creator.UID == creds.EffectiveKUID ||
+ s.creator.GID == creds.EffectiveKGID
+}
+
+func (s *Set) checkCapability(creds *auth.Credentials) bool {
+ return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, s.registry.userNS) && creds.UserNamespace.MapFromKUID(s.owner.UID).Ok()
+}
+
+func (s *Set) checkPerms(creds *auth.Credentials, reqPerms fs.PermMask) bool {
+ // Are we owner, or in group, or other?
+ p := s.perms.Other
+ if s.owner.UID == creds.EffectiveKUID {
+ p = s.perms.User
+ } else if creds.InGroup(s.owner.GID) {
+ p = s.perms.Group
+ }
+
+ // Are permissions satisfied without capability checks?
+ if p.SupersetOf(reqPerms) {
+ return true
+ }
+
+ return s.checkCapability(creds)
+}
+
+// destroy destroys the set. Caller must hold 's.mu'.
+func (s *Set) destroy() {
+ // Notify all waiters. They will fail on the next attempt to execute
+ // operations and return error.
+ s.dead = true
+ for _, s := range s.sems {
+ for w := s.waiters.Front(); w != nil; w = w.Next() {
+ w.ch <- struct{}{}
+ }
+ s.waiters.Reset()
+ }
+}
+
+// wakeWaiters goes over all waiters and checks which of them can be notified.
+func (s *sem) wakeWaiters() {
+ // Note that this will release all waiters waiting for 0 too.
+ for w := s.waiters.Front(); w != nil; {
+ if s.value < w.value {
+ // Still blocked, skip it.
+ continue
+ }
+ w.ch <- struct{}{}
+ old := w
+ w = w.Next()
+ s.waiters.Remove(old)
+ }
+}
+
+func newWaiter(val int16) *waiter {
+ return &waiter{
+ value: val,
+ ch: make(chan struct{}, 1),
+ }
+}
diff --git a/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go
new file mode 100755
index 000000000..1551f792e
--- /dev/null
+++ b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go
@@ -0,0 +1,115 @@
+// automatically generated by stateify.
+
+package semaphore
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Registry) beforeSave() {}
+func (x *Registry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("userNS", &x.userNS)
+ m.Save("semaphores", &x.semaphores)
+ m.Save("lastIDUsed", &x.lastIDUsed)
+}
+
+func (x *Registry) afterLoad() {}
+func (x *Registry) load(m state.Map) {
+ m.Load("userNS", &x.userNS)
+ m.Load("semaphores", &x.semaphores)
+ m.Load("lastIDUsed", &x.lastIDUsed)
+}
+
+func (x *Set) beforeSave() {}
+func (x *Set) save(m state.Map) {
+ x.beforeSave()
+ m.Save("registry", &x.registry)
+ m.Save("ID", &x.ID)
+ m.Save("key", &x.key)
+ m.Save("creator", &x.creator)
+ m.Save("owner", &x.owner)
+ m.Save("perms", &x.perms)
+ m.Save("opTime", &x.opTime)
+ m.Save("changeTime", &x.changeTime)
+ m.Save("sems", &x.sems)
+ m.Save("dead", &x.dead)
+}
+
+func (x *Set) afterLoad() {}
+func (x *Set) load(m state.Map) {
+ m.Load("registry", &x.registry)
+ m.Load("ID", &x.ID)
+ m.Load("key", &x.key)
+ m.Load("creator", &x.creator)
+ m.Load("owner", &x.owner)
+ m.Load("perms", &x.perms)
+ m.Load("opTime", &x.opTime)
+ m.Load("changeTime", &x.changeTime)
+ m.Load("sems", &x.sems)
+ m.Load("dead", &x.dead)
+}
+
+func (x *sem) beforeSave() {}
+func (x *sem) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.waiters) { m.Failf("waiters is %v, expected zero", x.waiters) }
+ m.Save("value", &x.value)
+ m.Save("pid", &x.pid)
+}
+
+func (x *sem) afterLoad() {}
+func (x *sem) load(m state.Map) {
+ m.Load("value", &x.value)
+ m.Load("pid", &x.pid)
+}
+
+func (x *waiter) beforeSave() {}
+func (x *waiter) save(m state.Map) {
+ x.beforeSave()
+ m.Save("waiterEntry", &x.waiterEntry)
+ m.Save("value", &x.value)
+ m.Save("ch", &x.ch)
+}
+
+func (x *waiter) afterLoad() {}
+func (x *waiter) load(m state.Map) {
+ m.Load("waiterEntry", &x.waiterEntry)
+ m.Load("value", &x.value)
+ m.Load("ch", &x.ch)
+}
+
+func (x *waiterList) beforeSave() {}
+func (x *waiterList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *waiterList) afterLoad() {}
+func (x *waiterList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *waiterEntry) beforeSave() {}
+func (x *waiterEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *waiterEntry) afterLoad() {}
+func (x *waiterEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("semaphore.Registry", (*Registry)(nil), state.Fns{Save: (*Registry).save, Load: (*Registry).load})
+ state.Register("semaphore.Set", (*Set)(nil), state.Fns{Save: (*Set).save, Load: (*Set).load})
+ state.Register("semaphore.sem", (*sem)(nil), state.Fns{Save: (*sem).save, Load: (*sem).load})
+ state.Register("semaphore.waiter", (*waiter)(nil), state.Fns{Save: (*waiter).save, Load: (*waiter).load})
+ state.Register("semaphore.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load})
+ state.Register("semaphore.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load})
+}
diff --git a/pkg/sentry/kernel/semaphore/waiter_list.go b/pkg/sentry/kernel/semaphore/waiter_list.go
new file mode 100755
index 000000000..33e29fb55
--- /dev/null
+++ b/pkg/sentry/kernel/semaphore/waiter_list.go
@@ -0,0 +1,173 @@
+package semaphore
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type waiterElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (waiterElementMapper) linkerFor(elem *waiter) *waiter { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type waiterList struct {
+ head *waiter
+ tail *waiter
+}
+
+// Reset resets list l to the empty state.
+func (l *waiterList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *waiterList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *waiterList) Front() *waiter {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *waiterList) Back() *waiter {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *waiterList) PushFront(e *waiter) {
+ waiterElementMapper{}.linkerFor(e).SetNext(l.head)
+ waiterElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ waiterElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *waiterList) PushBack(e *waiter) {
+ waiterElementMapper{}.linkerFor(e).SetNext(nil)
+ waiterElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ waiterElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *waiterList) PushBackList(m *waiterList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *waiterList) InsertAfter(b, e *waiter) {
+ a := waiterElementMapper{}.linkerFor(b).Next()
+ waiterElementMapper{}.linkerFor(e).SetNext(a)
+ waiterElementMapper{}.linkerFor(e).SetPrev(b)
+ waiterElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ waiterElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *waiterList) InsertBefore(a, e *waiter) {
+ b := waiterElementMapper{}.linkerFor(a).Prev()
+ waiterElementMapper{}.linkerFor(e).SetNext(a)
+ waiterElementMapper{}.linkerFor(e).SetPrev(b)
+ waiterElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ waiterElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *waiterList) Remove(e *waiter) {
+ prev := waiterElementMapper{}.linkerFor(e).Prev()
+ next := waiterElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ waiterElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ waiterElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type waiterEntry struct {
+ next *waiter
+ prev *waiter
+}
+
+// Next returns the entry that follows e in the list.
+func (e *waiterEntry) Next() *waiter {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *waiterEntry) Prev() *waiter {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *waiterEntry) SetNext(elem *waiter) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *waiterEntry) SetPrev(elem *waiter) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo.go b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo.go
new file mode 100755
index 000000000..4bf8719f2
--- /dev/null
+++ b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo.go
@@ -0,0 +1,55 @@
+package kernel
+
+import (
+ "reflect"
+ "strings"
+ "unsafe"
+
+ "fmt"
+ "gvisor.googlesource.com/gvisor/third_party/gvsync"
+)
+
+// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
+// with any writer critical sections in sc.
+func SeqAtomicLoadTaskGoroutineSchedInfo(sc *gvsync.SeqCount, ptr *TaskGoroutineSchedInfo) TaskGoroutineSchedInfo {
+ // This function doesn't use SeqAtomicTryLoad because doing so is
+ // measurably, significantly (~20%) slower; Go is awful at inlining.
+ var val TaskGoroutineSchedInfo
+ for {
+ epoch := sc.BeginRead()
+ if gvsync.RaceEnabled {
+
+ gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ } else {
+
+ val = *ptr
+ }
+ if sc.ReadOk(epoch) {
+ break
+ }
+ }
+ return val
+}
+
+// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section
+// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read
+// would race with a writer critical section, SeqAtomicTryLoad returns
+// (unspecified, false).
+func SeqAtomicTryLoadTaskGoroutineSchedInfo(sc *gvsync.SeqCount, epoch gvsync.SeqCountEpoch, ptr *TaskGoroutineSchedInfo) (TaskGoroutineSchedInfo, bool) {
+ var val TaskGoroutineSchedInfo
+ if gvsync.RaceEnabled {
+ gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ } else {
+ val = *ptr
+ }
+ return val, sc.ReadOk(epoch)
+}
+
+func initTaskGoroutineSchedInfo() {
+ var val TaskGoroutineSchedInfo
+ typ := reflect.TypeOf(val)
+ name := typ.Name()
+ if ptrs := gvsync.PointersInType(typ, name); len(ptrs) != 0 {
+ panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
+ }
+}
diff --git a/pkg/sentry/kernel/session_list.go b/pkg/sentry/kernel/session_list.go
new file mode 100755
index 000000000..9ba27b164
--- /dev/null
+++ b/pkg/sentry/kernel/session_list.go
@@ -0,0 +1,173 @@
+package kernel
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type sessionElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (sessionElementMapper) linkerFor(elem *Session) *Session { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type sessionList struct {
+ head *Session
+ tail *Session
+}
+
+// Reset resets list l to the empty state.
+func (l *sessionList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *sessionList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *sessionList) Front() *Session {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *sessionList) Back() *Session {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *sessionList) PushFront(e *Session) {
+ sessionElementMapper{}.linkerFor(e).SetNext(l.head)
+ sessionElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ sessionElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *sessionList) PushBack(e *Session) {
+ sessionElementMapper{}.linkerFor(e).SetNext(nil)
+ sessionElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ sessionElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *sessionList) PushBackList(m *sessionList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ sessionElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ sessionElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *sessionList) InsertAfter(b, e *Session) {
+ a := sessionElementMapper{}.linkerFor(b).Next()
+ sessionElementMapper{}.linkerFor(e).SetNext(a)
+ sessionElementMapper{}.linkerFor(e).SetPrev(b)
+ sessionElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ sessionElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *sessionList) InsertBefore(a, e *Session) {
+ b := sessionElementMapper{}.linkerFor(a).Prev()
+ sessionElementMapper{}.linkerFor(e).SetNext(a)
+ sessionElementMapper{}.linkerFor(e).SetPrev(b)
+ sessionElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ sessionElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *sessionList) Remove(e *Session) {
+ prev := sessionElementMapper{}.linkerFor(e).Prev()
+ next := sessionElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ sessionElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ sessionElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type sessionEntry struct {
+ next *Session
+ prev *Session
+}
+
+// Next returns the entry that follows e in the list.
+func (e *sessionEntry) Next() *Session {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *sessionEntry) Prev() *Session {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *sessionEntry) SetNext(elem *Session) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *sessionEntry) SetPrev(elem *Session) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
new file mode 100644
index 000000000..610e199da
--- /dev/null
+++ b/pkg/sentry/kernel/sessions.go
@@ -0,0 +1,508 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// SessionID is the public identifier.
+type SessionID ThreadID
+
+// ProcessGroupID is the public identifier.
+type ProcessGroupID ThreadID
+
+// Session contains a leader threadgroup and a list of ProcessGroups.
+//
+// +stateify savable
+type Session struct {
+ refs refs.AtomicRefCount
+
+ // leader is the originator of the Session.
+ //
+ // Note that this may no longer be running (and may be reaped), so the
+ // ID is cached upon initial creation. The leader is still required
+ // however, since its PIDNamespace defines the scope of the Session.
+ //
+ // The leader is immutable.
+ leader *ThreadGroup
+
+ // id is the cached identifier in the leader's namespace.
+ //
+ // The id is immutable.
+ id SessionID
+
+ // ProcessGroups is a list of process groups in this Session. This is
+ // protected by TaskSet.mu.
+ processGroups processGroupList
+
+ // sessionEntry is the embed for TaskSet.sessions. This is protected by
+ // TaskSet.mu.
+ sessionEntry
+}
+
+// incRef grabs a reference.
+func (s *Session) incRef() {
+ s.refs.IncRef()
+}
+
+// decRef drops a reference.
+//
+// Precondition: callers must hold TaskSet.mu for writing.
+func (s *Session) decRef() {
+ s.refs.DecRefWithDestructor(func() {
+ // Remove translations from the leader.
+ for ns := s.leader.pidns; ns != nil; ns = ns.parent {
+ id := ns.sids[s]
+ delete(ns.sids, s)
+ delete(ns.sessions, id)
+ }
+
+ // Remove from the list of global Sessions.
+ s.leader.pidns.owner.sessions.Remove(s)
+ })
+}
+
+// ProcessGroup contains an originator threadgroup and a parent Session.
+//
+// +stateify savable
+type ProcessGroup struct {
+ refs refs.AtomicRefCount // not exported.
+
+ // originator is the originator of the group.
+ //
+ // See note re: leader in Session. The same applies here.
+ //
+ // The originator is immutable.
+ originator *ThreadGroup
+
+ // id is the cached identifier in the originator's namespace.
+ //
+ // The id is immutable.
+ id ProcessGroupID
+
+ // Session is the parent Session.
+ //
+ // The session is immutable.
+ session *Session
+
+ // ancestors is the number of thread groups in this process group whose
+ // parent is in a different process group in the same session.
+ //
+ // The name is derived from the fact that process groups where
+ // ancestors is zero are considered "orphans".
+ //
+ // ancestors is protected by TaskSet.mu.
+ ancestors uint32
+
+ // processGroupEntry is the embedded entry for Sessions.groups. This is
+ // protected by TaskSet.mu.
+ processGroupEntry
+}
+
+// Originator retuns the originator of the process group.
+func (pg *ProcessGroup) Originator() *ThreadGroup {
+ return pg.originator
+}
+
+// IsOrphan returns true if this process group is an orphan.
+func (pg *ProcessGroup) IsOrphan() bool {
+ pg.originator.TaskSet().mu.RLock()
+ defer pg.originator.TaskSet().mu.RUnlock()
+ return pg.ancestors == 0
+}
+
+// incRefWithParent grabs a reference.
+//
+// This function is called when this ProcessGroup is being associated with some
+// new ThreadGroup, tg. parentPG is the ProcessGroup of tg's parent
+// ThreadGroup. If tg is init, then parentPG may be nil.
+//
+// Precondition: callers must hold TaskSet.mu for writing.
+func (pg *ProcessGroup) incRefWithParent(parentPG *ProcessGroup) {
+ // We acquire an "ancestor" reference in the case of a nil parent.
+ // This is because the process being associated is init, and init can
+ // never be orphaned (we count it as always having an ancestor).
+ if pg != parentPG && (parentPG == nil || pg.session == parentPG.session) {
+ pg.ancestors++
+ }
+
+ pg.refs.IncRef()
+}
+
+// decRefWithParent drops a reference.
+//
+// parentPG is per incRefWithParent.
+//
+// Precondition: callers must hold TaskSet.mu for writing.
+func (pg *ProcessGroup) decRefWithParent(parentPG *ProcessGroup) {
+ // See incRefWithParent regarding parent == nil.
+ if pg != parentPG && (parentPG == nil || pg.session == parentPG.session) {
+ pg.ancestors--
+ }
+
+ alive := true
+ pg.refs.DecRefWithDestructor(func() {
+ alive = false // don't bother with handleOrphan.
+
+ // Remove translations from the originator.
+ for ns := pg.originator.pidns; ns != nil; ns = ns.parent {
+ id := ns.pgids[pg]
+ delete(ns.pgids, pg)
+ delete(ns.processGroups, id)
+ }
+
+ // Remove the list of process groups.
+ pg.session.processGroups.Remove(pg)
+ pg.session.decRef()
+ })
+ if alive {
+ pg.handleOrphan()
+ }
+}
+
+// parentPG returns the parent process group.
+//
+// Precondition: callers must hold TaskSet.mu.
+func (tg *ThreadGroup) parentPG() *ProcessGroup {
+ if tg.leader.parent != nil {
+ return tg.leader.parent.tg.processGroup
+ }
+ return nil
+}
+
+// handleOrphan checks whether the process group is an orphan and has any
+// stopped jobs. If yes, then appropriate signals are delivered to each thread
+// group within the process group.
+//
+// Precondition: callers must hold TaskSet.mu for writing.
+func (pg *ProcessGroup) handleOrphan() {
+ // Check if this process is an orphan.
+ if pg.ancestors != 0 {
+ return
+ }
+
+ // See if there are any stopped jobs.
+ hasStopped := false
+ pg.originator.pidns.owner.forEachThreadGroupLocked(func(tg *ThreadGroup) {
+ if tg.processGroup != pg {
+ return
+ }
+ tg.signalHandlers.mu.Lock()
+ if tg.groupStopComplete {
+ hasStopped = true
+ }
+ tg.signalHandlers.mu.Unlock()
+ })
+ if !hasStopped {
+ return
+ }
+
+ // Deliver appropriate signals to all thread groups.
+ pg.originator.pidns.owner.forEachThreadGroupLocked(func(tg *ThreadGroup) {
+ if tg.processGroup != pg {
+ return
+ }
+ tg.signalHandlers.mu.Lock()
+ tg.leader.sendSignalLocked(SignalInfoPriv(linux.SIGHUP), true /* group */)
+ tg.leader.sendSignalLocked(SignalInfoPriv(linux.SIGCONT), true /* group */)
+ tg.signalHandlers.mu.Unlock()
+ })
+
+ return
+}
+
+// Session returns the process group's session without taking a reference.
+func (pg *ProcessGroup) Session() *Session {
+ return pg.session
+}
+
+// SendSignal sends a signal to all processes inside the process group. It is
+// analagous to kernel/signal.c:kill_pgrp.
+func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
+ tasks := pg.originator.TaskSet()
+ tasks.mu.RLock()
+ defer tasks.mu.RUnlock()
+
+ var lastErr error
+ for tg := range tasks.Root.tgids {
+ if tg.ProcessGroup() == pg {
+ tg.signalHandlers.mu.Lock()
+ infoCopy := *info
+ if err := tg.leader.sendSignalLocked(&infoCopy, true /*group*/); err != nil {
+ lastErr = err
+ }
+ tg.signalHandlers.mu.Unlock()
+ }
+ }
+ return lastErr
+}
+
+// CreateSession creates a new Session, with the ThreadGroup as the leader.
+//
+// EPERM may be returned if either the given ThreadGroup is already a Session
+// leader, or a ProcessGroup already exists for the ThreadGroup's ID.
+func (tg *ThreadGroup) CreateSession() error {
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ return tg.createSession()
+}
+
+// createSession creates a new session for a threadgroup.
+//
+// Precondition: callers must hold TaskSet.mu for writing.
+func (tg *ThreadGroup) createSession() error {
+ // Get the ID for this thread in the current namespace.
+ id := tg.pidns.tgids[tg]
+
+ // Check if this ThreadGroup already leads a Session, or
+ // if the proposed group is already taken.
+ for s := tg.pidns.owner.sessions.Front(); s != nil; s = s.Next() {
+ if s.leader.pidns != tg.pidns {
+ continue
+ }
+ if s.leader == tg {
+ return syserror.EPERM
+ }
+ if s.id == SessionID(id) {
+ return syserror.EPERM
+ }
+ for pg := s.processGroups.Front(); pg != nil; pg = pg.Next() {
+ if pg.id == ProcessGroupID(id) {
+ return syserror.EPERM
+ }
+ }
+ }
+
+ // Create a new Session, with a single reference.
+ s := &Session{
+ id: SessionID(id),
+ leader: tg,
+ }
+
+ // Create a new ProcessGroup, belonging to that Session.
+ // This also has a single reference (assigned below).
+ //
+ // Note that since this is a new session and a new process group, there
+ // will be zero ancestors for this process group. (It is an orphan at
+ // this point.)
+ pg := &ProcessGroup{
+ id: ProcessGroupID(id),
+ originator: tg,
+ session: s,
+ ancestors: 0,
+ }
+
+ // Tie them and return the result.
+ s.processGroups.PushBack(pg)
+ tg.pidns.owner.sessions.PushBack(s)
+
+ // Leave the current group, and assign the new one.
+ if tg.processGroup != nil {
+ oldParentPG := tg.parentPG()
+ tg.forEachChildThreadGroupLocked(func(childTG *ThreadGroup) {
+ childTG.processGroup.incRefWithParent(pg)
+ childTG.processGroup.decRefWithParent(oldParentPG)
+ })
+ tg.processGroup.decRefWithParent(oldParentPG)
+ tg.processGroup = pg
+ } else {
+ // The current process group may be nil only in the case of an
+ // unparented thread group (i.e. the init process). This would
+ // not normally occur, but we allow it for the convenience of
+ // CreateSession working from that point. There will be no
+ // child processes. We always say that the very first group
+ // created has ancestors (avoids checks elsewhere).
+ //
+ // Note that this mirrors the parent == nil logic in
+ // incRef/decRef/reparent, which counts nil as an ancestor.
+ tg.processGroup = pg
+ tg.processGroup.ancestors++
+ }
+
+ // Ensure a translation is added to all namespaces.
+ for ns := tg.pidns; ns != nil; ns = ns.parent {
+ local := ns.tgids[tg]
+ ns.sids[s] = SessionID(local)
+ ns.sessions[SessionID(local)] = s
+ ns.pgids[pg] = ProcessGroupID(local)
+ ns.processGroups[ProcessGroupID(local)] = pg
+ }
+
+ return nil
+}
+
+// CreateProcessGroup creates a new process group.
+//
+// An EPERM error will be returned if the ThreadGroup belongs to a different
+// Session, is a Session leader or the group already exists.
+func (tg *ThreadGroup) CreateProcessGroup() error {
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+
+ // Get the ID for this thread in the current namespace.
+ id := tg.pidns.tgids[tg]
+
+ // Per above, check for a Session leader or existing group.
+ for s := tg.pidns.owner.sessions.Front(); s != nil; s = s.Next() {
+ if s.leader.pidns != tg.pidns {
+ continue
+ }
+ if s.leader == tg {
+ return syserror.EPERM
+ }
+ for pg := s.processGroups.Front(); pg != nil; pg = pg.Next() {
+ if pg.id == ProcessGroupID(id) {
+ return syserror.EPERM
+ }
+ }
+ }
+
+ // Create a new ProcessGroup, belonging to the current Session.
+ //
+ // We manually adjust the ancestors if the parent is in the same
+ // session.
+ tg.processGroup.session.incRef()
+ pg := &ProcessGroup{
+ id: ProcessGroupID(id),
+ originator: tg,
+ session: tg.processGroup.session,
+ }
+ if tg.leader.parent != nil && tg.leader.parent.tg.processGroup.session == pg.session {
+ pg.ancestors++
+ }
+
+ // Assign the new process group; adjust children.
+ oldParentPG := tg.parentPG()
+ tg.forEachChildThreadGroupLocked(func(childTG *ThreadGroup) {
+ childTG.processGroup.incRefWithParent(pg)
+ childTG.processGroup.decRefWithParent(oldParentPG)
+ })
+ tg.processGroup.decRefWithParent(oldParentPG)
+ tg.processGroup = pg
+
+ // Add the new process group to the session.
+ pg.session.processGroups.PushBack(pg)
+
+ // Ensure this translation is added to all namespaces.
+ for ns := tg.pidns; ns != nil; ns = ns.parent {
+ local := ns.tgids[tg]
+ ns.pgids[pg] = ProcessGroupID(local)
+ ns.processGroups[ProcessGroupID(local)] = pg
+ }
+
+ return nil
+}
+
+// JoinProcessGroup joins an existing process group.
+//
+// This function will return EACCES if an exec has been performed since fork
+// by the given ThreadGroup, and EPERM if the Sessions are not the same or the
+// group does not exist.
+//
+// If checkExec is set, then the join is not permitted after the process has
+// executed exec at least once.
+func (tg *ThreadGroup) JoinProcessGroup(pidns *PIDNamespace, pgid ProcessGroupID, checkExec bool) error {
+ pidns.owner.mu.Lock()
+ defer pidns.owner.mu.Unlock()
+
+ // Lookup the ProcessGroup.
+ pg := pidns.processGroups[pgid]
+ if pg == nil {
+ return syserror.EPERM
+ }
+
+ // Disallow the join if an execve has performed, per POSIX.
+ if checkExec && tg.execed {
+ return syserror.EACCES
+ }
+
+ // See if it's in the same session as ours.
+ if pg.session != tg.processGroup.session {
+ return syserror.EPERM
+ }
+
+ // Join the group; adjust children.
+ parentPG := tg.parentPG()
+ pg.incRefWithParent(parentPG)
+ tg.forEachChildThreadGroupLocked(func(childTG *ThreadGroup) {
+ childTG.processGroup.incRefWithParent(pg)
+ childTG.processGroup.decRefWithParent(tg.processGroup)
+ })
+ tg.processGroup.decRefWithParent(parentPG)
+ tg.processGroup = pg
+
+ return nil
+}
+
+// Session returns the ThreadGroup's Session.
+//
+// A reference is not taken on the session.
+func (tg *ThreadGroup) Session() *Session {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.processGroup.session
+}
+
+// IDOfSession returns the Session assigned to s in PID namespace ns.
+//
+// If this group isn't visible in this namespace, zero will be returned. It is
+// the callers responsibility to check that before using this function.
+func (pidns *PIDNamespace) IDOfSession(s *Session) SessionID {
+ pidns.owner.mu.RLock()
+ defer pidns.owner.mu.RUnlock()
+ return pidns.sids[s]
+}
+
+// SessionWithID returns the Session with the given ID in the PID namespace ns,
+// or nil if that given ID is not defined in this namespace.
+//
+// A reference is not taken on the session.
+func (pidns *PIDNamespace) SessionWithID(id SessionID) *Session {
+ pidns.owner.mu.RLock()
+ defer pidns.owner.mu.RUnlock()
+ return pidns.sessions[id]
+}
+
+// ProcessGroup returns the ThreadGroup's ProcessGroup.
+//
+// A reference is not taken on the process group.
+func (tg *ThreadGroup) ProcessGroup() *ProcessGroup {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.processGroup
+}
+
+// IDOfProcessGroup returns the process group assigned to pg in PID namespace ns.
+//
+// The same constraints apply as IDOfSession.
+func (pidns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID {
+ pidns.owner.mu.RLock()
+ defer pidns.owner.mu.RUnlock()
+ return pidns.pgids[pg]
+}
+
+// ProcessGroupWithID returns the ProcessGroup with the given ID in the PID
+// namespace ns, or nil if that given ID is not defined in this namespace.
+//
+// A reference is not taken on the process group.
+func (pidns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup {
+ pidns.owner.mu.RLock()
+ defer pidns.owner.mu.RUnlock()
+ return pidns.processGroups[id]
+}
diff --git a/pkg/sentry/kernel/shm/device.go b/pkg/sentry/kernel/shm/device.go
new file mode 100644
index 000000000..3cb759072
--- /dev/null
+++ b/pkg/sentry/kernel/shm/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shm
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// shmDevice is the kernel shm device.
+var shmDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
new file mode 100644
index 000000000..00393b5f0
--- /dev/null
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -0,0 +1,671 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package shm implements sysv shared memory segments.
+//
+// Known missing features:
+//
+// - SHM_LOCK/SHM_UNLOCK are no-ops. The sentry currently doesn't implement
+// memory locking in general.
+//
+// - SHM_HUGETLB and related flags for shmget(2) are ignored. There's no easy
+// way to implement hugetlb support on a per-map basis, and it has no impact
+// on correctness.
+//
+// - SHM_NORESERVE for shmget(2) is ignored, the sentry doesn't implement swap
+// so it's meaningless to reserve space for swap.
+//
+// - No per-process segment size enforcement. This feature probably isn't used
+// much anyways, since Linux sets the per-process limits to the system-wide
+// limits by default.
+//
+// Lock ordering: mm.mappingMu -> shm registry lock -> shm lock
+package shm
+
+import (
+ "fmt"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Key represents a shm segment key. Analogous to a file name.
+type Key int32
+
+// ID represents the opaque handle for a shm segment. Analogous to an fd.
+type ID int32
+
+// Registry tracks all shared memory segments in an IPC namespace. The registry
+// provides the mechanisms for creating and finding segments, and reporting
+// global shm parameters.
+//
+// +stateify savable
+type Registry struct {
+ // userNS owns the IPC namespace this registry belong to. Immutable.
+ userNS *auth.UserNamespace
+
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // shms maps segment ids to segments.
+ shms map[ID]*Shm
+
+ // keysToShms maps segment keys to segments.
+ keysToShms map[Key]*Shm
+
+ // Sum of the sizes of all existing segments rounded up to page size, in
+ // units of page size.
+ totalPages uint64
+
+ // ID assigned to the last created segment. Used to quickly find the next
+ // unused ID.
+ lastIDUsed ID
+}
+
+// NewRegistry creates a new shm registry.
+func NewRegistry(userNS *auth.UserNamespace) *Registry {
+ return &Registry{
+ userNS: userNS,
+ shms: make(map[ID]*Shm),
+ keysToShms: make(map[Key]*Shm),
+ }
+}
+
+// FindByID looks up a segment given an ID.
+func (r *Registry) FindByID(id ID) *Shm {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return r.shms[id]
+}
+
+// dissociateKey removes the association between a segment and its key,
+// preventing it from being discovered in the registry. This doesn't necessarily
+// mean the segment is about to be destroyed. This is analogous to unlinking a
+// file; the segment can still be used by a process already referencing it, but
+// cannot be discovered by a new process.
+func (r *Registry) dissociateKey(s *Shm) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.key != linux.IPC_PRIVATE {
+ delete(r.keysToShms, s.key)
+ s.key = linux.IPC_PRIVATE
+ }
+}
+
+// FindOrCreate looks up or creates a segment in the registry. It's functionally
+// analogous to open(2).
+func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
+ if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) {
+ // "A new segment was to be created and size is less than SHMMIN or
+ // greater than SHMMAX." - man shmget(2)
+ //
+ // Note that 'private' always implies the creation of a new segment
+ // whether IPC_CREAT is specified or not.
+ return nil, syserror.EINVAL
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if len(r.shms) >= linux.SHMMNI {
+ // "All possible shared memory IDs have been taken (SHMMNI) ..."
+ // - man shmget(2)
+ return nil, syserror.ENOSPC
+ }
+
+ if !private {
+ // Look up an existing segment.
+ if shm := r.keysToShms[key]; shm != nil {
+ shm.mu.Lock()
+ defer shm.mu.Unlock()
+
+ // Check that caller can access the segment.
+ if !shm.checkPermissions(ctx, fs.PermsFromMode(mode)) {
+ // "The user does not have permission to access the shared
+ // memory segment, and does not have the CAP_IPC_OWNER
+ // capability in the user namespace that governs its IPC
+ // namespace." - man shmget(2)
+ return nil, syserror.EACCES
+ }
+
+ if size > shm.size {
+ // "A segment for the given key exists, but size is greater than
+ // the size of that segment." - man shmget(2)
+ return nil, syserror.EINVAL
+ }
+
+ if create && exclusive {
+ // "IPC_CREAT and IPC_EXCL were specified in shmflg, but a
+ // shared memory segment already exists for key."
+ // - man shmget(2)
+ return nil, syserror.EEXIST
+ }
+
+ return shm, nil
+ }
+
+ if !create {
+ // "No segment exists for the given key, and IPC_CREAT was not
+ // specified." - man shmget(2)
+ return nil, syserror.ENOENT
+ }
+ }
+
+ var sizeAligned uint64
+ if val, ok := usermem.Addr(size).RoundUp(); ok {
+ sizeAligned = uint64(val)
+ } else {
+ return nil, syserror.EINVAL
+ }
+
+ if numPages := sizeAligned / usermem.PageSize; r.totalPages+numPages > linux.SHMALL {
+ // "... allocating a segment of the requested size would cause the
+ // system to exceed the system-wide limit on shared memory (SHMALL)."
+ // - man shmget(2)
+ return nil, syserror.ENOSPC
+ }
+
+ // Need to create a new segment.
+ creator := fs.FileOwnerFromContext(ctx)
+ perms := fs.FilePermsFromMode(mode)
+ return r.newShm(ctx, pid, key, creator, perms, size)
+}
+
+// newShm creates a new segment in the registry.
+//
+// Precondition: Caller must hold r.mu.
+func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.FileOwner, perms fs.FilePermissions, size uint64) (*Shm, error) {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
+ panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
+ }
+
+ effectiveSize := uint64(usermem.Addr(size).MustRoundUp())
+ fr, err := mfp.MemoryFile().Allocate(effectiveSize, usage.Anonymous)
+ if err != nil {
+ return nil, err
+ }
+
+ shm := &Shm{
+ mfp: mfp,
+ registry: r,
+ creator: creator,
+ size: size,
+ effectiveSize: effectiveSize,
+ fr: fr,
+ key: key,
+ perms: perms,
+ owner: creator,
+ creatorPID: pid,
+ changeTime: ktime.NowFromContext(ctx),
+ }
+
+ // Find the next available ID.
+ for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ {
+ // Handle wrap around.
+ if id < 0 {
+ id = 0
+ continue
+ }
+ if r.shms[id] == nil {
+ r.lastIDUsed = id
+
+ shm.ID = id
+ r.shms[id] = shm
+ r.keysToShms[key] = shm
+
+ r.totalPages += effectiveSize / usermem.PageSize
+
+ return shm, nil
+ }
+ }
+
+ log.Warningf("Shm ids exhuasted, they may be leaking")
+ return nil, syserror.ENOSPC
+}
+
+// IPCInfo reports global parameters for sysv shared memory segments on this
+// system. See shmctl(IPC_INFO).
+func (r *Registry) IPCInfo() *linux.ShmParams {
+ return &linux.ShmParams{
+ ShmMax: linux.SHMMAX,
+ ShmMin: linux.SHMMIN,
+ ShmMni: linux.SHMMNI,
+ ShmSeg: linux.SHMSEG,
+ ShmAll: linux.SHMALL,
+ }
+}
+
+// ShmInfo reports linux-specific global parameters for sysv shared memory
+// segments on this system. See shmctl(SHM_INFO).
+func (r *Registry) ShmInfo() *linux.ShmInfo {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ return &linux.ShmInfo{
+ UsedIDs: int32(r.lastIDUsed),
+ ShmTot: r.totalPages,
+ ShmRss: r.totalPages, // We could probably get a better estimate from memory accounting.
+ ShmSwp: 0, // No reclaim at the moment.
+ }
+}
+
+// remove deletes a segment from this registry, deaccounting the memory used by
+// the segment.
+//
+// Precondition: Must follow a call to r.dissociateKey(s).
+func (r *Registry) remove(s *Shm) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.key != linux.IPC_PRIVATE {
+ panic(fmt.Sprintf("Attempted to remove %s from the registry whose key is still associated", s.debugLocked()))
+ }
+
+ delete(r.shms, s.ID)
+ r.totalPages -= s.effectiveSize / usermem.PageSize
+}
+
+// Shm represents a single shared memory segment.
+//
+// Shm segment are backed directly by an allocation from platform
+// memory. Segments are always mapped as a whole, greatly simplifying how
+// mappings are tracked. However note that mremap and munmap calls may cause the
+// vma for a segment to become fragmented; which requires special care when
+// unmapping a segment. See mm/shm.go.
+//
+// Segments persist until they are explicitly marked for destruction via
+// shmctl(SHM_RMID).
+//
+// Shm implements memmap.Mappable and memmap.MappingIdentity.
+//
+// +stateify savable
+type Shm struct {
+ // AtomicRefCount tracks the number of references to this segment from
+ // maps. A segment always holds a reference to itself, until it's marked for
+ // destruction.
+ refs.AtomicRefCount
+
+ mfp pgalloc.MemoryFileProvider
+
+ // registry points to the shm registry containing this segment. Immutable.
+ registry *Registry
+
+ // ID is the kernel identifier for this segment. Immutable.
+ ID ID
+
+ // creator is the user that created the segment. Immutable.
+ creator fs.FileOwner
+
+ // size is the requested size of the segment at creation, in
+ // bytes. Immutable.
+ size uint64
+
+ // effectiveSize of the segment, rounding up to the next page
+ // boundary. Immutable.
+ //
+ // Invariant: effectiveSize must be a multiple of usermem.PageSize.
+ effectiveSize uint64
+
+ // fr is the offset into mfp.MemoryFile() that backs this contents of this
+ // segment. Immutable.
+ fr platform.FileRange
+
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // key is the public identifier for this segment.
+ key Key
+
+ // perms is the access permissions for the segment.
+ perms fs.FilePermissions
+
+ // owner of this segment.
+ owner fs.FileOwner
+ // attachTime is updated on every successful shmat.
+ attachTime ktime.Time
+ // detachTime is updated on every successful shmdt.
+ detachTime ktime.Time
+ // changeTime is updated on every successful changes to the segment via
+ // shmctl(IPC_SET).
+ changeTime ktime.Time
+
+ // creatorPID is the PID of the process that created the segment.
+ creatorPID int32
+ // lastAttachDetachPID is the pid of the process that issued the last shmat
+ // or shmdt syscall.
+ lastAttachDetachPID int32
+
+ // pendingDestruction indicates the segment was marked as destroyed through
+ // shmctl(IPC_RMID). When marked as destroyed, the segment will not be found
+ // in the registry and can no longer be attached. When the last user
+ // detaches from the segment, it is destroyed.
+ pendingDestruction bool
+}
+
+// Precondition: Caller must hold s.mu.
+func (s *Shm) debugLocked() string {
+ return fmt.Sprintf("Shm{id: %d, key: %d, size: %d bytes, refs: %d, destroyed: %v}",
+ s.ID, s.key, s.size, s.ReadRefs(), s.pendingDestruction)
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (s *Shm) MappedName(ctx context.Context) string {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return fmt.Sprintf("SYSV%08d", s.key)
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (s *Shm) DeviceID() uint64 {
+ return shmDevice.DeviceID()
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (s *Shm) InodeID() uint64 {
+ // "shmid gets reported as "inode#" in /proc/pid/maps. proc-ps tools use
+ // this. Changing this will break them." -- Linux, ipc/shm.c:newseg()
+ return uint64(s.ID)
+}
+
+// DecRef overrides refs.RefCount.DecRef with a destructor.
+//
+// Precondition: Caller must not hold s.mu.
+func (s *Shm) DecRef() {
+ s.DecRefWithDestructor(s.destroy)
+}
+
+// Msync implements memmap.MappingIdentity.Msync. Msync is a no-op for shm
+// segments.
+func (s *Shm) Msync(context.Context, memmap.MappableRange) error {
+ return nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.attachTime = ktime.NowFromContext(ctx)
+ if pid, ok := context.ThreadGroupIDFromContext(ctx); ok {
+ s.lastAttachDetachPID = pid
+ } else {
+ // AddMapping is called during a syscall, so ctx should always be a task
+ // context.
+ log.Warningf("Adding mapping to %s but couldn't get the current pid; not updating the last attach pid", s.debugLocked())
+ }
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ // TODO(b/38173783): RemoveMapping may be called during task exit, when ctx
+ // is context.Background. Gracefully handle missing clocks. Failing to
+ // update the detach time in these cases is ok, since no one can observe the
+ // omission.
+ if clock := ktime.RealtimeClockFromContext(ctx); clock != nil {
+ s.detachTime = clock.Now()
+ }
+
+ // If called from a non-task context we also won't have a threadgroup
+ // id. Silently skip updating the lastAttachDetachPid in that case.
+ if pid, ok := context.ThreadGroupIDFromContext(ctx); ok {
+ s.lastAttachDetachPID = pid
+ } else {
+ log.Debugf("Couldn't obtain pid when removing mapping to %s, not updating the last detach pid.", s.debugLocked())
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (*Shm) CopyMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, usermem.AddrRange, uint64, bool) error {
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (s *Shm) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ var err error
+ if required.End > s.fr.Length() {
+ err = &memmap.BusError{syserror.EFAULT}
+ }
+ if source := optional.Intersect(memmap.MappableRange{0, s.fr.Length()}); source.Length() != 0 {
+ return []memmap.Translation{
+ {
+ Source: source,
+ File: s.mfp.MemoryFile(),
+ Offset: s.fr.Start + source.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, err
+ }
+ return nil, err
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (s *Shm) InvalidateUnsavable(ctx context.Context) error {
+ return nil
+}
+
+// AttachOpts describes various flags passed to shmat(2).
+type AttachOpts struct {
+ Execute bool
+ Readonly bool
+ Remap bool
+}
+
+// ConfigureAttach creates an mmap configuration for the segment with the
+// requested attach options.
+//
+// ConfigureAttach returns with a ref on s on success. The caller should drop
+// this once the map is installed. This reference prevents s from being
+// destroyed before the returned configuration is used.
+func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts AttachOpts) (memmap.MMapOpts, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.pendingDestruction && s.ReadRefs() == 0 {
+ return memmap.MMapOpts{}, syserror.EIDRM
+ }
+
+ if !s.checkPermissions(ctx, fs.PermMask{
+ Read: true,
+ Write: !opts.Readonly,
+ Execute: opts.Execute,
+ }) {
+ // "The calling process does not have the required permissions for the
+ // requested attach type, and does not have the CAP_IPC_OWNER capability
+ // in the user namespace that governs its IPC namespace." - man shmat(2)
+ return memmap.MMapOpts{}, syserror.EACCES
+ }
+ s.IncRef()
+ return memmap.MMapOpts{
+ Length: s.size,
+ Offset: 0,
+ Addr: addr,
+ Fixed: opts.Remap,
+ Perms: usermem.AccessType{
+ Read: true,
+ Write: !opts.Readonly,
+ Execute: opts.Execute,
+ },
+ MaxPerms: usermem.AnyAccess,
+ Mappable: s,
+ MappingIdentity: s,
+ }, nil
+}
+
+// EffectiveSize returns the size of the underlying shared memory segment. This
+// may be larger than the requested size at creation, due to rounding to page
+// boundaries.
+func (s *Shm) EffectiveSize() uint64 {
+ return s.effectiveSize
+}
+
+// IPCStat returns information about a shm. See shmctl(IPC_STAT).
+func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The caller must have read permission on the shared memory segment."
+ // - man shmctl(2)
+ if !s.checkPermissions(ctx, fs.PermMask{Read: true}) {
+ // "IPC_STAT or SHM_STAT is requested and shm_perm.mode does not allow
+ // read access for shmid, and the calling process does not have the
+ // CAP_IPC_OWNER capability in the user namespace that governs its IPC
+ // namespace." - man shmctl(2)
+ return nil, syserror.EACCES
+ }
+
+ var mode uint16
+ if s.pendingDestruction {
+ mode |= linux.SHM_DEST
+ }
+ creds := auth.CredentialsFromContext(ctx)
+
+ nattach := uint64(s.ReadRefs())
+ // Don't report the self-reference we keep prior to being marked for
+ // destruction. However, also don't report a count of -1 for segments marked
+ // as destroyed, with no mappings.
+ if !s.pendingDestruction {
+ nattach--
+ }
+
+ ds := &linux.ShmidDS{
+ ShmPerm: linux.IPCPerm{
+ Key: uint32(s.key),
+ UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)),
+ GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)),
+ CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)),
+ CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)),
+ Mode: mode | uint16(s.perms.LinuxMode()),
+ Seq: 0, // IPC sequences not supported.
+ },
+ ShmSegsz: s.size,
+ ShmAtime: s.attachTime.TimeT(),
+ ShmDtime: s.detachTime.TimeT(),
+ ShmCtime: s.changeTime.TimeT(),
+ ShmCpid: s.creatorPID,
+ ShmLpid: s.lastAttachDetachPID,
+ ShmNattach: nattach,
+ }
+
+ return ds, nil
+}
+
+// Set modifies attributes for a segment. See shmctl(IPC_SET).
+func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if !s.checkOwnership(ctx) {
+ return syserror.EPERM
+ }
+
+ creds := auth.CredentialsFromContext(ctx)
+ uid := creds.UserNamespace.MapToKUID(auth.UID(ds.ShmPerm.UID))
+ gid := creds.UserNamespace.MapToKGID(auth.GID(ds.ShmPerm.GID))
+ if !uid.Ok() || !gid.Ok() {
+ return syserror.EINVAL
+ }
+
+ // User may only modify the lower 9 bits of the mode. All the other bits are
+ // always 0 for the underlying inode.
+ mode := linux.FileMode(ds.ShmPerm.Mode & 0x1ff)
+ s.perms = fs.FilePermsFromMode(mode)
+
+ s.owner.UID = uid
+ s.owner.GID = gid
+
+ s.changeTime = ktime.NowFromContext(ctx)
+ return nil
+}
+
+func (s *Shm) destroy() {
+ s.mfp.MemoryFile().DecRef(s.fr)
+ s.registry.remove(s)
+}
+
+// MarkDestroyed marks a segment for destruction. The segment is actually
+// destroyed once it has no references. MarkDestroyed may be called multiple
+// times, and is safe to call after a segment has already been destroyed. See
+// shmctl(IPC_RMID).
+func (s *Shm) MarkDestroyed() {
+ s.registry.dissociateKey(s)
+
+ s.mu.Lock()
+ // Only drop the segment's self-reference once, when destruction is
+ // requested. Otherwise, repeated calls to shmctl(IPC_RMID) would force a
+ // segment to be destroyed prematurely, potentially with active maps to the
+ // segment's address range. Remaining references are dropped when the
+ // segment is detached or unmaped.
+ if !s.pendingDestruction {
+ s.pendingDestruction = true
+ s.mu.Unlock() // Must release s.mu before calling s.DecRef.
+ s.DecRef()
+ return
+ }
+ s.mu.Unlock()
+}
+
+// checkOwnership verifies whether a segment may be accessed by ctx as an
+// owner. See ipc/util.c:ipcctl_pre_down_nolock() in Linux.
+//
+// Precondition: Caller must hold s.mu.
+func (s *Shm) checkOwnership(ctx context.Context) bool {
+ creds := auth.CredentialsFromContext(ctx)
+ if s.owner.UID == creds.EffectiveKUID || s.creator.UID == creds.EffectiveKUID {
+ return true
+ }
+
+ // Tasks with CAP_SYS_ADMIN may bypass ownership checks. Strangely, Linux
+ // doesn't use CAP_IPC_OWNER for this despite CAP_IPC_OWNER being documented
+ // for use to "override IPC ownership checks".
+ return creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, s.registry.userNS)
+}
+
+// checkPermissions verifies whether a segment is accessible by ctx for access
+// described by req. See ipc/util.c:ipcperms() in Linux.
+//
+// Precondition: Caller must hold s.mu.
+func (s *Shm) checkPermissions(ctx context.Context, req fs.PermMask) bool {
+ creds := auth.CredentialsFromContext(ctx)
+
+ p := s.perms.Other
+ if s.owner.UID == creds.EffectiveKUID {
+ p = s.perms.User
+ } else if creds.InGroup(s.owner.GID) {
+ p = s.perms.Group
+ }
+ if p.SupersetOf(req) {
+ return true
+ }
+
+ // Tasks with CAP_IPC_OWNER may bypass permission checks.
+ return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, s.registry.userNS)
+}
diff --git a/pkg/sentry/kernel/shm/shm_state_autogen.go b/pkg/sentry/kernel/shm/shm_state_autogen.go
new file mode 100755
index 000000000..d94d01fce
--- /dev/null
+++ b/pkg/sentry/kernel/shm/shm_state_autogen.go
@@ -0,0 +1,74 @@
+// automatically generated by stateify.
+
+package shm
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Registry) beforeSave() {}
+func (x *Registry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("userNS", &x.userNS)
+ m.Save("shms", &x.shms)
+ m.Save("keysToShms", &x.keysToShms)
+ m.Save("totalPages", &x.totalPages)
+ m.Save("lastIDUsed", &x.lastIDUsed)
+}
+
+func (x *Registry) afterLoad() {}
+func (x *Registry) load(m state.Map) {
+ m.Load("userNS", &x.userNS)
+ m.Load("shms", &x.shms)
+ m.Load("keysToShms", &x.keysToShms)
+ m.Load("totalPages", &x.totalPages)
+ m.Load("lastIDUsed", &x.lastIDUsed)
+}
+
+func (x *Shm) beforeSave() {}
+func (x *Shm) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("mfp", &x.mfp)
+ m.Save("registry", &x.registry)
+ m.Save("ID", &x.ID)
+ m.Save("creator", &x.creator)
+ m.Save("size", &x.size)
+ m.Save("effectiveSize", &x.effectiveSize)
+ m.Save("fr", &x.fr)
+ m.Save("key", &x.key)
+ m.Save("perms", &x.perms)
+ m.Save("owner", &x.owner)
+ m.Save("attachTime", &x.attachTime)
+ m.Save("detachTime", &x.detachTime)
+ m.Save("changeTime", &x.changeTime)
+ m.Save("creatorPID", &x.creatorPID)
+ m.Save("lastAttachDetachPID", &x.lastAttachDetachPID)
+ m.Save("pendingDestruction", &x.pendingDestruction)
+}
+
+func (x *Shm) afterLoad() {}
+func (x *Shm) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("mfp", &x.mfp)
+ m.Load("registry", &x.registry)
+ m.Load("ID", &x.ID)
+ m.Load("creator", &x.creator)
+ m.Load("size", &x.size)
+ m.Load("effectiveSize", &x.effectiveSize)
+ m.Load("fr", &x.fr)
+ m.Load("key", &x.key)
+ m.Load("perms", &x.perms)
+ m.Load("owner", &x.owner)
+ m.Load("attachTime", &x.attachTime)
+ m.Load("detachTime", &x.detachTime)
+ m.Load("changeTime", &x.changeTime)
+ m.Load("creatorPID", &x.creatorPID)
+ m.Load("lastAttachDetachPID", &x.lastAttachDetachPID)
+ m.Load("pendingDestruction", &x.pendingDestruction)
+}
+
+func init() {
+ state.Register("shm.Registry", (*Registry)(nil), state.Fns{Save: (*Registry).save, Load: (*Registry).load})
+ state.Register("shm.Shm", (*Shm)(nil), state.Fns{Save: (*Shm).save, Load: (*Shm).load})
+}
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
new file mode 100644
index 000000000..b528ec0dc
--- /dev/null
+++ b/pkg/sentry/kernel/signal.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+)
+
+// SignalPanic is used to panic the running threads. It is a signal which
+// cannot be used by the application: it must be caught and ignored by the
+// runtime (in order to catch possible races).
+const SignalPanic = linux.SIGUSR2
+
+// sendExternalSignal is called when an asynchronous signal is sent to the
+// sentry ("in sentry context"). On some platforms, it may also be called when
+// an asynchronous signal is sent to sandboxed application threads ("in
+// application context").
+//
+// context is used only for debugging to differentiate these cases.
+//
+// Preconditions: Kernel must have an init process.
+func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) {
+ switch linux.Signal(info.Signo) {
+ case platform.SignalInterrupt:
+ // Assume that a call to platform.Context.Interrupt() misfired.
+
+ case SignalPanic:
+ // SignalPanic is also specially handled in sentry setup to ensure that
+ // it causes a panic even after tasks exit, but SignalPanic may also
+ // be sent here if it is received while in app context.
+ panic("Signal-induced panic")
+
+ default:
+ log.Infof("Received external signal %d in %s context", info.Signo, context)
+ if k.globalInit == nil {
+ panic(fmt.Sprintf("Received external signal %d before init created", info.Signo))
+ }
+ k.globalInit.SendSignal(info)
+ }
+}
+
+// SignalInfoPriv returns a SignalInfo equivalent to Linux's SEND_SIG_PRIV.
+func SignalInfoPriv(sig linux.Signal) *arch.SignalInfo {
+ return &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoKernel,
+ }
+}
+
+// SignalInfoNoInfo returns a SignalInfo equivalent to Linux's SEND_SIG_NOINFO.
+func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo {
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
+ info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ return info
+}
diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go
new file mode 100644
index 000000000..ce8bcb5e5
--- /dev/null
+++ b/pkg/sentry/kernel/signal_handlers.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+)
+
+// SignalHandlers holds information about signal actions.
+//
+// +stateify savable
+type SignalHandlers struct {
+ // mu protects actions, as well as the signal state of all tasks and thread
+ // groups using this SignalHandlers object. (See comment on
+ // ThreadGroup.signalHandlers.)
+ mu sync.Mutex `state:"nosave"`
+
+ // actions is the action to be taken upon receiving each signal.
+ actions map[linux.Signal]arch.SignalAct
+}
+
+// NewSignalHandlers returns a new SignalHandlers specifying all default
+// actions.
+func NewSignalHandlers() *SignalHandlers {
+ return &SignalHandlers{
+ actions: make(map[linux.Signal]arch.SignalAct),
+ }
+}
+
+// Fork returns a copy of sh for a new thread group.
+func (sh *SignalHandlers) Fork() *SignalHandlers {
+ sh2 := NewSignalHandlers()
+ sh.mu.Lock()
+ defer sh.mu.Unlock()
+ for sig, act := range sh.actions {
+ sh2.actions[sig] = act
+ }
+ return sh2
+}
+
+// CopyForExec returns a copy of sh for a thread group that is undergoing an
+// execve. (See comments in Task.finishExec.)
+func (sh *SignalHandlers) CopyForExec() *SignalHandlers {
+ sh2 := NewSignalHandlers()
+ sh.mu.Lock()
+ defer sh.mu.Unlock()
+ for sig, act := range sh.actions {
+ if act.Handler == arch.SignalActIgnore {
+ sh2.actions[sig] = arch.SignalAct{
+ Handler: arch.SignalActIgnore,
+ }
+ }
+ }
+ return sh2
+}
+
+// IsIgnored returns true if the signal is ignored.
+func (sh *SignalHandlers) IsIgnored(sig linux.Signal) bool {
+ sh.mu.Lock()
+ defer sh.mu.Unlock()
+ sa, ok := sh.actions[sig]
+ return ok && sa.Handler == arch.SignalActIgnore
+}
+
+// dequeueActionLocked returns the SignalAct that should be used to handle sig.
+//
+// Preconditions: sh.mu must be locked.
+func (sh *SignalHandlers) dequeueAction(sig linux.Signal) arch.SignalAct {
+ act := sh.actions[sig]
+ if act.IsResetHandler() {
+ delete(sh.actions, sig)
+ }
+ return act
+}
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
new file mode 100644
index 000000000..0572053db
--- /dev/null
+++ b/pkg/sentry/kernel/syscalls.go
@@ -0,0 +1,307 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/bits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// maxSyscallNum is the highest supported syscall number.
+//
+// The types below create fast lookup slices for all syscalls. This maximum
+// serves as a sanity check that we don't allocate huge slices for a very large
+// syscall.
+const maxSyscallNum = 2000
+
+// SyscallFn is a syscall implementation.
+type SyscallFn func(t *Task, args arch.SyscallArguments) (uintptr, *SyscallControl, error)
+
+// MissingFn is a syscall to be called when an implementation is missing.
+type MissingFn func(t *Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error)
+
+// Possible flags for SyscallFlagsTable.enable.
+const (
+ // syscallPresent indicates that this is not a missing syscall.
+ //
+ // This flag is used internally in SyscallFlagsTable.
+ syscallPresent = 1 << iota
+
+ // StraceEnableLog enables syscall log tracing.
+ StraceEnableLog
+
+ // StraceEnableEvent enables syscall event tracing.
+ StraceEnableEvent
+
+ // ExternalBeforeEnable enables the external hook before syscall execution.
+ ExternalBeforeEnable
+
+ // ExternalAfterEnable enables the external hook after syscall execution.
+ ExternalAfterEnable
+)
+
+// StraceEnableBits combines both strace log and event flags.
+const StraceEnableBits = StraceEnableLog | StraceEnableEvent
+
+// SyscallFlagsTable manages a set of enable/disable bit fields on a per-syscall
+// basis.
+type SyscallFlagsTable struct {
+ // mu protects writes to the fields below.
+ //
+ // Atomic loads are always allowed. Atomic stores are allowed only
+ // while mu is held.
+ mu sync.Mutex
+
+ // enable contains the enable bits for each syscall.
+ //
+ // missing syscalls have the same value in enable as missingEnable to
+ // avoid an extra branch in Word.
+ enable []uint32
+
+ // missingEnable contains the enable bits for missing syscalls.
+ missingEnable uint32
+}
+
+// Init initializes the struct, with all syscalls in table set to enable.
+//
+// max is the largest syscall number in table.
+func (e *SyscallFlagsTable) init(table map[uintptr]SyscallFn, max uintptr) {
+ e.enable = make([]uint32, max+1)
+ for num := range table {
+ e.enable[num] = syscallPresent
+ }
+}
+
+// Word returns the enable bitfield for sysno.
+func (e *SyscallFlagsTable) Word(sysno uintptr) uint32 {
+ if sysno < uintptr(len(e.enable)) {
+ return atomic.LoadUint32(&e.enable[sysno])
+ }
+
+ return atomic.LoadUint32(&e.missingEnable)
+}
+
+// Enable sets enable bit bit for all syscalls based on s.
+//
+// Syscalls missing from s are disabled.
+//
+// Syscalls missing from the initial table passed to Init cannot be added as
+// individual syscalls. If present in s they will be ignored.
+//
+// Callers to Word may see either the old or new value while this function
+// is executing.
+func (e *SyscallFlagsTable) Enable(bit uint32, s map[uintptr]bool, missingEnable bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ missingVal := atomic.LoadUint32(&e.missingEnable)
+ if missingEnable {
+ missingVal |= bit
+ } else {
+ missingVal &^= bit
+ }
+ atomic.StoreUint32(&e.missingEnable, missingVal)
+
+ for num := range e.enable {
+ val := atomic.LoadUint32(&e.enable[num])
+ if !bits.IsOn32(val, syscallPresent) {
+ // Missing.
+ atomic.StoreUint32(&e.enable[num], missingVal)
+ continue
+ }
+
+ if s[uintptr(num)] {
+ val |= bit
+ } else {
+ val &^= bit
+ }
+ atomic.StoreUint32(&e.enable[num], val)
+ }
+}
+
+// EnableAll sets enable bit bit for all syscalls, present and missing.
+func (e *SyscallFlagsTable) EnableAll(bit uint32) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ missingVal := atomic.LoadUint32(&e.missingEnable)
+ missingVal |= bit
+ atomic.StoreUint32(&e.missingEnable, missingVal)
+
+ for num := range e.enable {
+ val := atomic.LoadUint32(&e.enable[num])
+ if !bits.IsOn32(val, syscallPresent) {
+ // Missing.
+ atomic.StoreUint32(&e.enable[num], missingVal)
+ continue
+ }
+
+ val |= bit
+ atomic.StoreUint32(&e.enable[num], val)
+ }
+}
+
+// Stracer traces syscall execution.
+type Stracer interface {
+ // SyscallEnter is called on syscall entry.
+ //
+ // The returned private data is passed to SyscallExit.
+ //
+ // TODO(gvisor.dev/issue/155): remove kernel imports from the strace
+ // package so that the type can be used directly.
+ SyscallEnter(t *Task, sysno uintptr, args arch.SyscallArguments, flags uint32) interface{}
+
+ // SyscallExit is called on syscall exit.
+ SyscallExit(context interface{}, t *Task, sysno, rval uintptr, err error)
+}
+
+// SyscallTable is a lookup table of system calls. Critically, a SyscallTable
+// is *immutable*. In order to make supporting suspend and resume sane, they
+// must be uniquely registered and may not change during operation.
+//
+// +stateify savable
+type SyscallTable struct {
+ // OS is the operating system that this syscall table implements.
+ OS abi.OS `state:"wait"`
+
+ // Arch is the architecture that this syscall table targets.
+ Arch arch.Arch `state:"wait"`
+
+ // The OS version that this syscall table implements.
+ Version Version `state:"manual"`
+
+ // AuditNumber is a numeric constant that represents the syscall table. If
+ // non-zero, auditNumber must be one of the AUDIT_ARCH_* values defined by
+ // linux/audit.h.
+ AuditNumber uint32 `state:"manual"`
+
+ // Table is the collection of functions.
+ Table map[uintptr]SyscallFn `state:"manual"`
+
+ // lookup is a fixed-size array that holds the syscalls (indexed by
+ // their numbers). It is used for fast look ups.
+ lookup []SyscallFn `state:"manual"`
+
+ // Emulate is a collection of instruction addresses to emulate. The
+ // keys are addresses, and the values are system call numbers.
+ Emulate map[usermem.Addr]uintptr `state:"manual"`
+
+ // The function to call in case of a missing system call.
+ Missing MissingFn `state:"manual"`
+
+ // Stracer traces this syscall table.
+ Stracer Stracer `state:"manual"`
+
+ // External is used to handle an external callback.
+ External func(*Kernel) `state:"manual"`
+
+ // ExternalFilterBefore is called before External is called before the syscall is executed.
+ // External is not called if it returns false.
+ ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"`
+
+ // ExternalFilterAfter is called before External is called after the syscall is executed.
+ // External is not called if it returns false.
+ ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"`
+
+ // FeatureEnable stores the strace and one-shot enable bits.
+ FeatureEnable SyscallFlagsTable `state:"manual"`
+}
+
+// allSyscallTables contains all known tables.
+var allSyscallTables []*SyscallTable
+
+// SyscallTables returns a read-only slice of registered SyscallTables.
+func SyscallTables() []*SyscallTable {
+ return allSyscallTables
+}
+
+// LookupSyscallTable returns the SyscallCall table for the OS/Arch combination.
+func LookupSyscallTable(os abi.OS, a arch.Arch) (*SyscallTable, bool) {
+ for _, s := range allSyscallTables {
+ if s.OS == os && s.Arch == a {
+ return s, true
+ }
+ }
+ return nil, false
+}
+
+// RegisterSyscallTable registers a new syscall table for use by a Kernel.
+func RegisterSyscallTable(s *SyscallTable) {
+ if s.Table == nil {
+ // Ensure non-nil lookup table.
+ s.Table = make(map[uintptr]SyscallFn)
+ }
+ if s.Emulate == nil {
+ // Ensure non-nil emulate table.
+ s.Emulate = make(map[usermem.Addr]uintptr)
+ }
+
+ var max uintptr
+ for num := range s.Table {
+ if num > max {
+ max = num
+ }
+ }
+
+ if max > maxSyscallNum {
+ panic(fmt.Sprintf("SyscallTable %+v contains too large syscall number %d", s, max))
+ }
+
+ s.lookup = make([]SyscallFn, max+1)
+
+ // Initialize the fast-lookup table.
+ for num, fn := range s.Table {
+ s.lookup[num] = fn
+ }
+
+ s.FeatureEnable.init(s.Table, max)
+
+ if _, ok := LookupSyscallTable(s.OS, s.Arch); ok {
+ panic(fmt.Sprintf("Duplicate SyscallTable registered for OS %v Arch %v", s.OS, s.Arch))
+ }
+
+ // Save a reference to this table.
+ //
+ // This is required for a Kernel to find the table and for save/restore
+ // operations below.
+ allSyscallTables = append(allSyscallTables, s)
+}
+
+// Lookup returns the syscall implementation, if one exists.
+func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn {
+ if sysno < uintptr(len(s.lookup)) {
+ return s.lookup[sysno]
+ }
+
+ return nil
+}
+
+// LookupEmulate looks up an emulation syscall number.
+func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) {
+ sysno, ok := s.Emulate[addr]
+ return sysno, ok
+}
+
+// mapLookup is similar to Lookup, except that it only uses the syscall table,
+// that is, it skips the fast look array. This is available for benchmarking.
+func (s *SyscallTable) mapLookup(sysno uintptr) SyscallFn {
+ return s.Table[sysno]
+}
diff --git a/pkg/sentry/kernel/syscalls_state.go b/pkg/sentry/kernel/syscalls_state.go
new file mode 100644
index 000000000..00358326b
--- /dev/null
+++ b/pkg/sentry/kernel/syscalls_state.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "fmt"
+
+// afterLoad is invoked by stateify.
+func (s *SyscallTable) afterLoad() {
+ otherTable, ok := LookupSyscallTable(s.OS, s.Arch)
+ if !ok {
+ // Couldn't find a reference?
+ panic(fmt.Sprintf("syscall table not found for OS %v Arch %v", s.OS, s.Arch))
+ }
+
+ // Copy the table.
+ *s = *otherTable
+}
diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go
new file mode 100644
index 000000000..175d1b247
--- /dev/null
+++ b/pkg/sentry/kernel/syslog.go
@@ -0,0 +1,106 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "math/rand"
+ "sync"
+)
+
+// syslog represents a sentry-global kernel log.
+//
+// Currently, it contains only fun messages for a dmesg easter egg.
+//
+// +stateify savable
+type syslog struct {
+ // mu protects the below.
+ mu sync.Mutex `state:"nosave"`
+
+ // msg is the syslog message buffer. It is lazily initialized.
+ msg []byte
+}
+
+// Log returns a copy of the syslog.
+func (s *syslog) Log() []byte {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.msg != nil {
+ // Already initialized, just return a copy.
+ o := make([]byte, len(s.msg))
+ copy(o, s.msg)
+ return o
+ }
+
+ // Not initialized, create message.
+ allMessages := []string{
+ "Synthesizing system calls...",
+ "Mounting deweydecimalfs...",
+ "Moving files to filing cabinet...",
+ "Digging up root...",
+ "Constructing home...",
+ "Segmenting fault lines...",
+ "Creating bureaucratic processes...",
+ "Searching for needles in stacks...",
+ "Preparing for the zombie uprising...",
+ "Feeding the init monster...",
+ "Creating cloned children...",
+ "Daemonizing children...",
+ "Waiting for children...",
+ "Gathering forks...",
+ "Committing treasure map to memory...",
+ "Reading process obituaries...",
+ "Searching for socket adapter...",
+ "Creating process schedule...",
+ "Generating random numbers by fair dice roll...",
+ "Rewriting operating system in Javascript...",
+ "Consulting tar man page...",
+ "Forking spaghetti code...",
+ "Checking naughty and nice process list...",
+ "Checking naughty and nice process list...", // Check it up to twice.
+ "Granting licence to kill(2)...", // British spelling for British movie.
+ "Letting the watchdogs out...",
+ }
+
+ selectMessage := func() string {
+ i := rand.Intn(len(allMessages))
+ m := allMessages[i]
+
+ // Delete the selected message.
+ allMessages[i] = allMessages[len(allMessages)-1]
+ allMessages = allMessages[:len(allMessages)-1]
+
+ return m
+ }
+
+ const format = "<6>[%11.6f] %s\n"
+
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, 0.0, "Starting gVisor..."))...)
+
+ time := 0.1
+ for i := 0; i < 10; i++ {
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, selectMessage()))...)
+ }
+
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Ready!"))...)
+
+ // Return a copy.
+ o := make([]byte, len(s.msg))
+ copy(o, s.msg)
+ return o
+}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
new file mode 100644
index 000000000..f9378c2de
--- /dev/null
+++ b/pkg/sentry/kernel/task.go
@@ -0,0 +1,723 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/bpf"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/sched"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/third_party/gvsync"
+)
+
+// Task represents a thread of execution in the untrusted app. It
+// includes registers and any thread-specific state that you would
+// normally expect.
+//
+// Each task is associated with a goroutine, called the task goroutine, that
+// executes code (application code, system calls, etc.) on behalf of that task.
+// See Task.run (task_run.go).
+//
+// All fields that are "owned by the task goroutine" can only be mutated by the
+// task goroutine while it is running. The task goroutine does not require
+// synchronization to read these fields, although it still requires
+// synchronization as described for those fields to mutate them.
+//
+// All fields that are "exclusive to the task goroutine" can only be accessed
+// by the task goroutine while it is running. The task goroutine does not
+// require synchronization to read or write these fields.
+//
+// +stateify savable
+type Task struct {
+ taskNode
+
+ // runState is what the task goroutine is executing if it is not stopped.
+ // If runState is nil, the task goroutine should exit or has exited.
+ // runState is exclusive to the task goroutine.
+ runState taskRunState
+
+ // haveSyscallReturn is true if tc.Arch().Return() represents a value
+ // returned by a syscall (or set by ptrace after a syscall).
+ //
+ // haveSyscallReturn is exclusive to the task goroutine.
+ haveSyscallReturn bool
+
+ // interruptChan is notified whenever the task goroutine is interrupted
+ // (usually by a pending signal). interruptChan is effectively a condition
+ // variable that can be used in select statements.
+ //
+ // interruptChan is not saved; because saving interrupts all tasks,
+ // interruptChan is always notified after restore (see Task.run).
+ interruptChan chan struct{} `state:"nosave"`
+
+ // gosched contains the current scheduling state of the task goroutine.
+ //
+ // gosched is protected by goschedSeq. gosched is owned by the task
+ // goroutine.
+ goschedSeq gvsync.SeqCount `state:"nosave"`
+ gosched TaskGoroutineSchedInfo
+
+ // yieldCount is the number of times the task goroutine has called
+ // Task.InterruptibleSleepStart, Task.UninterruptibleSleepStart, or
+ // Task.Yield(), voluntarily ceasing execution.
+ //
+ // yieldCount is accessed using atomic memory operations. yieldCount is
+ // owned by the task goroutine.
+ yieldCount uint64
+
+ // pendingSignals is the set of pending signals that may be handled only by
+ // this task.
+ //
+ // pendingSignals is protected by (taskNode.)tg.signalHandlers.mu
+ // (hereafter "the signal mutex"); see comment on
+ // ThreadGroup.signalHandlers.
+ pendingSignals pendingSignals
+
+ // signalMask is the set of signals whose delivery is currently blocked.
+ //
+ // signalMask is accessed using atomic memory operations, and is protected
+ // by the signal mutex (such that reading signalMask is safe if either the
+ // signal mutex is locked or if atomic memory operations are used, while
+ // writing signalMask requires both). signalMask is owned by the task
+ // goroutine.
+ signalMask linux.SignalSet
+
+ // If the task goroutine is currently executing Task.sigtimedwait,
+ // realSignalMask is the previous value of signalMask, which has temporarily
+ // been replaced by Task.sigtimedwait. Otherwise, realSignalMask is 0.
+ //
+ // realSignalMask is exclusive to the task goroutine.
+ realSignalMask linux.SignalSet
+
+ // If haveSavedSignalMask is true, savedSignalMask is the signal mask that
+ // should be applied after the task has either delivered one signal to a
+ // user handler or is about to resume execution in the untrusted
+ // application.
+ //
+ // Both haveSavedSignalMask and savedSignalMask are exclusive to the task
+ // goroutine.
+ haveSavedSignalMask bool
+ savedSignalMask linux.SignalSet
+
+ // signalStack is the alternate signal stack used by signal handlers for
+ // which the SA_ONSTACK flag is set.
+ //
+ // signalStack is exclusive to the task goroutine.
+ signalStack arch.SignalStack
+
+ // If groupStopPending is true, the task should participate in a group
+ // stop in the interrupt path.
+ //
+ // groupStopPending is analogous to JOBCTL_STOP_PENDING in Linux.
+ //
+ // groupStopPending is protected by the signal mutex.
+ groupStopPending bool
+
+ // If groupStopAcknowledged is true, the task has already acknowledged that
+ // it is entering the most recent group stop that has been initiated on its
+ // thread group.
+ //
+ // groupStopAcknowledged is analogous to !JOBCTL_STOP_CONSUME in Linux.
+ //
+ // groupStopAcknowledged is protected by the signal mutex.
+ groupStopAcknowledged bool
+
+ // If trapStopPending is true, the task goroutine should enter a
+ // PTRACE_INTERRUPT-induced stop from the interrupt path.
+ //
+ // trapStopPending is analogous to JOBCTL_TRAP_STOP in Linux, except that
+ // Linux also sets JOBCTL_TRAP_STOP when a ptraced task detects
+ // JOBCTL_STOP_PENDING.
+ //
+ // trapStopPending is protected by the signal mutex.
+ trapStopPending bool
+
+ // If trapNotifyPending is true, this task is PTRACE_SEIZEd, and a group
+ // stop has begun or ended since the last time the task entered a
+ // ptrace-stop from the group-stop path.
+ //
+ // trapNotifyPending is analogous to JOBCTL_TRAP_NOTIFY in Linux.
+ //
+ // trapNotifyPending is protected by the signal mutex.
+ trapNotifyPending bool
+
+ // If stop is not nil, it is the internally-initiated condition that
+ // currently prevents the task goroutine from running.
+ //
+ // stop is protected by the signal mutex.
+ stop TaskStop
+
+ // stopCount is the number of active external stops (calls to
+ // Task.BeginExternalStop that have not been paired with a call to
+ // Task.EndExternalStop), plus 1 if stop is not nil. Hence stopCount is
+ // non-zero if the task goroutine should stop.
+ //
+ // Mutating stopCount requires both locking the signal mutex and using
+ // atomic memory operations. Reading stopCount requires either locking the
+ // signal mutex or using atomic memory operations. This allows Task.doStop
+ // to require only a single atomic read in the common case where stopCount
+ // is 0.
+ //
+ // stopCount is not saved, because external stops cannot be retained across
+ // a save/restore cycle. (Suppose a sentryctl command issues an external
+ // stop; after a save/restore cycle, the restored sentry has no knowledge
+ // of the pre-save sentryctl command, and the stopped task would remain
+ // stopped forever.)
+ stopCount int32 `state:"nosave"`
+
+ // endStopCond is signaled when stopCount transitions to 0. The combination
+ // of stopCount and endStopCond effectively form a sync.WaitGroup, but
+ // WaitGroup provides no way to read its counter value.
+ //
+ // Invariant: endStopCond.L is the signal mutex. (This is not racy because
+ // sync.Cond.Wait is the only user of sync.Cond.L; only the task goroutine
+ // calls sync.Cond.Wait; and only the task goroutine can change the
+ // identity of the signal mutex, in Task.finishExec.)
+ endStopCond sync.Cond `state:"nosave"`
+
+ // exitStatus is the task's exit status.
+ //
+ // exitStatus is protected by the signal mutex.
+ exitStatus ExitStatus
+
+ // syscallRestartBlock represents a custom restart function to run in
+ // restart_syscall(2) to resume an interrupted syscall.
+ //
+ // syscallRestartBlock is exclusive to the task goroutine.
+ syscallRestartBlock SyscallRestartBlock
+
+ // p provides the mechanism by which the task runs code in userspace. The p
+ // interface object is immutable.
+ p platform.Context `state:"nosave"`
+
+ // k is the Kernel that this task belongs to. The k pointer is immutable.
+ k *Kernel
+
+ // containerID has no equivalent in Linux; it's used by runsc to track all
+ // tasks that belong to a given containers since cgroups aren't implemented.
+ // It's inherited by the children, is immutable, and may be empty.
+ //
+ // NOTE: cgroups can be used to track this when implemented.
+ containerID string
+
+ // mu protects some of the following fields.
+ mu sync.Mutex `state:"nosave"`
+
+ // tc holds task data provided by the ELF loader.
+ //
+ // tc is protected by mu, and is owned by the task goroutine.
+ tc TaskContext
+
+ // fsc is the task's filesystem context.
+ //
+ // fsc is protected by mu, and is owned by the task goroutine.
+ fsc *FSContext
+
+ // fds is the task's file descriptor table.
+ //
+ // fds is protected by mu, and is owned by the task goroutine.
+ fds *FDMap
+
+ // If vforkParent is not nil, it is the task that created this task with
+ // vfork() or clone(CLONE_VFORK), and should have its vforkStop ended when
+ // this TaskContext is released.
+ //
+ // vforkParent is protected by the TaskSet mutex.
+ vforkParent *Task
+
+ // exitState is the task's progress through the exit path.
+ //
+ // exitState is protected by the TaskSet mutex. exitState is owned by the
+ // task goroutine.
+ exitState TaskExitState
+
+ // exitTracerNotified is true if the exit path has either signaled the
+ // task's tracer to indicate the exit, or determined that no such signal is
+ // needed. exitTracerNotified can only be true if exitState is
+ // TaskExitZombie or TaskExitDead.
+ //
+ // exitTracerNotified is protected by the TaskSet mutex.
+ exitTracerNotified bool
+
+ // exitTracerAcked is true if exitTracerNotified is true and either the
+ // task's tracer has acknowledged the exit notification, or the exit path
+ // has determined that no such notification is needed.
+ //
+ // exitTracerAcked is protected by the TaskSet mutex.
+ exitTracerAcked bool
+
+ // exitParentNotified is true if the exit path has either signaled the
+ // task's parent to indicate the exit, or determined that no such signal is
+ // needed. exitParentNotified can only be true if exitState is
+ // TaskExitZombie or TaskExitDead.
+ //
+ // exitParentNotified is protected by the TaskSet mutex.
+ exitParentNotified bool
+
+ // exitParentAcked is true if exitParentNotified is true and either the
+ // task's parent has acknowledged the exit notification, or the exit path
+ // has determined that no such acknowledgment is needed.
+ //
+ // exitParentAcked is protected by the TaskSet mutex.
+ exitParentAcked bool
+
+ // goroutineStopped is a WaitGroup whose counter value is 1 when the task
+ // goroutine is running and 0 when the task goroutine is stopped or has
+ // exited.
+ goroutineStopped sync.WaitGroup `state:"nosave"`
+
+ // ptraceTracer is the task that is ptrace-attached to this one. If
+ // ptraceTracer is nil, this task is not being traced. Note that due to
+ // atomic.Value limitations (atomic.Value.Store(nil) panics), a nil
+ // ptraceTracer is always represented as a typed nil (i.e. (*Task)(nil)).
+ //
+ // ptraceTracer is protected by the TaskSet mutex, and accessed with atomic
+ // operations. This allows paths that wouldn't otherwise lock the TaskSet
+ // mutex, notably the syscall path, to check if ptraceTracer is nil without
+ // additional synchronization.
+ ptraceTracer atomic.Value `state:".(*Task)"`
+
+ // ptraceTracees is the set of tasks that this task is ptrace-attached to.
+ //
+ // ptraceTracees is protected by the TaskSet mutex.
+ ptraceTracees map[*Task]struct{}
+
+ // ptraceSeized is true if ptraceTracer attached to this task with
+ // PTRACE_SEIZE.
+ //
+ // ptraceSeized is protected by the TaskSet mutex.
+ ptraceSeized bool
+
+ // ptraceOpts contains ptrace options explicitly set by the tracer. If
+ // ptraceTracer is nil, ptraceOpts is expected to be the zero value.
+ //
+ // ptraceOpts is protected by the TaskSet mutex.
+ ptraceOpts ptraceOptions
+
+ // ptraceSyscallMode controls ptrace behavior around syscall entry and
+ // exit.
+ //
+ // ptraceSyscallMode is protected by the TaskSet mutex.
+ ptraceSyscallMode ptraceSyscallMode
+
+ // If ptraceSinglestep is true, the next time the task executes application
+ // code, single-stepping should be enabled. ptraceSinglestep is stored
+ // independently of the architecture-specific trap flag because tracer
+ // detaching (which can happen concurrently with the tracee's execution if
+ // the tracer exits) must disable single-stepping, and the task's
+ // architectural state is implicitly exclusive to the task goroutine (no
+ // synchronization occurs before passing registers to SwitchToApp).
+ //
+ // ptraceSinglestep is analogous to Linux's TIF_SINGLESTEP.
+ //
+ // ptraceSinglestep is protected by the TaskSet mutex.
+ ptraceSinglestep bool
+
+ // If t is ptrace-stopped, ptraceCode is a ptrace-defined value set at the
+ // time that t entered the ptrace stop, reset to 0 when the tracer
+ // acknowledges the stop with a wait*() syscall. Otherwise, it is the
+ // signal number passed to the ptrace operation that ended the last ptrace
+ // stop on this task. In the latter case, the effect of ptraceCode depends
+ // on the nature of the ptrace stop; signal-delivery-stop uses it to
+ // conditionally override ptraceSiginfo, syscall-entry/exit-stops send the
+ // signal to the task after leaving the stop, and PTRACE_EVENT stops and
+ // traced group stops ignore it entirely.
+ //
+ // Linux contextually stores the equivalent of ptraceCode in
+ // task_struct::exit_code.
+ //
+ // ptraceCode is protected by the TaskSet mutex.
+ ptraceCode int32
+
+ // ptraceSiginfo is the value returned to the tracer by
+ // ptrace(PTRACE_GETSIGINFO) and modified by ptrace(PTRACE_SETSIGINFO).
+ // (Despite the name, PTRACE_PEEKSIGINFO is completely unrelated.)
+ // ptraceSiginfo is nil if the task is in a ptraced group-stop (this is
+ // required for PTRACE_GETSIGINFO to return EINVAL during such stops, which
+ // is in turn required to distinguish group stops from other ptrace stops,
+ // per subsection "Group-stop" in ptrace(2)).
+ //
+ // ptraceSiginfo is analogous to Linux's task_struct::last_siginfo.
+ //
+ // ptraceSiginfo is protected by the TaskSet mutex.
+ ptraceSiginfo *arch.SignalInfo
+
+ // ptraceEventMsg is the value set by PTRACE_EVENT stops and returned to
+ // the tracer by ptrace(PTRACE_GETEVENTMSG).
+ //
+ // ptraceEventMsg is protected by the TaskSet mutex.
+ ptraceEventMsg uint64
+
+ // The struct that holds the IO-related usage. The ioUsage pointer is
+ // immutable.
+ ioUsage *usage.IO
+
+ // logPrefix is a string containing the task's thread ID in the root PID
+ // namespace, and is prepended to log messages emitted by Task.Infof etc.
+ logPrefix atomic.Value `state:".(string)"`
+
+ // creds is the task's credentials.
+ //
+ // creds is protected by mu, however the value itself is immutable and can
+ // only be changed by a copy. After reading the pointer, access will
+ // proceed outside the scope of mu. creds is owned by the task goroutine.
+ creds *auth.Credentials
+
+ // utsns is the task's UTS namespace.
+ //
+ // utsns is protected by mu. utsns is owned by the task goroutine.
+ utsns *UTSNamespace
+
+ // ipcns is the task's IPC namespace.
+ //
+ // ipcns is protected by mu. ipcns is owned by the task goroutine.
+ ipcns *IPCNamespace
+
+ // abstractSockets tracks abstract sockets that are in use.
+ //
+ // abstractSockets is protected by mu.
+ abstractSockets *AbstractSocketNamespace
+
+ // parentDeathSignal is sent to this task's thread group when its parent exits.
+ //
+ // parentDeathSignal is protected by mu.
+ parentDeathSignal linux.Signal
+
+ // syscallFilters is all seccomp-bpf syscall filters applicable to the
+ // task, in the order in which they were installed. The type of the atomic
+ // is []bpf.Program. Writing needs to be protected by the signal mutex.
+ //
+ // syscallFilters is owned by the task goroutine.
+ syscallFilters atomic.Value `state:".([]bpf.Program)"`
+
+ // If cleartid is non-zero, treat it as a pointer to a ThreadID in the
+ // task's virtual address space; when the task exits, set the pointed-to
+ // ThreadID to 0, and wake any futex waiters.
+ //
+ // cleartid is exclusive to the task goroutine.
+ cleartid usermem.Addr
+
+ // This is mostly a fake cpumask just for sched_set/getaffinity as we
+ // don't really control the affinity.
+ //
+ // Invariant: allowedCPUMask.Size() ==
+ // sched.CPUMaskSize(Kernel.applicationCores).
+ //
+ // allowedCPUMask is protected by mu.
+ allowedCPUMask sched.CPUSet
+
+ // cpu is the fake cpu number returned by getcpu(2). cpu is ignored
+ // entirely if Kernel.useHostCores is true.
+ //
+ // cpu is accessed using atomic memory operations.
+ cpu int32
+
+ // This is used to keep track of changes made to a process' priority/niceness.
+ // It is mostly used to provide some reasonable return value from
+ // getpriority(2) after a call to setpriority(2) has been made.
+ // We currently do not actually modify a process' scheduling priority.
+ // NOTE: This represents the userspace view of priority (nice).
+ // This means that the value should be in the range [-20, 19].
+ //
+ // niceness is protected by mu.
+ niceness int
+
+ // This is used to track the numa policy for the current thread. This can be
+ // modified through a set_mempolicy(2) syscall. Since we always report a
+ // single numa node, all policies are no-ops. We only track this information
+ // so that we can return reasonable values if the application calls
+ // get_mempolicy(2) after setting a non-default policy. Note that in the
+ // real syscall, nodemask can be longer than 4 bytes, but we always report a
+ // single node so never need to save more than a single bit.
+ //
+ // numaPolicy and numaNodeMask are protected by mu.
+ numaPolicy int32
+ numaNodeMask uint32
+
+ // If netns is true, the task is in a non-root network namespace. Network
+ // namespaces aren't currently implemented in full; being in a network
+ // namespace simply prevents the task from observing any network devices
+ // (including loopback) or using abstract socket addresses (see unix(7)).
+ //
+ // netns is protected by mu. netns is owned by the task goroutine.
+ netns bool
+
+ // If rseqPreempted is true, before the next call to p.Switch(), interrupt
+ // RSEQ critical regions as defined by tg.rseq and write the task
+ // goroutine's CPU number to rseqCPUAddr. rseqCPU is the last CPU number
+ // written to rseqCPUAddr.
+ //
+ // If rseqCPUAddr is 0, rseqCPU is -1.
+ //
+ // rseqCPUAddr, rseqCPU, and rseqPreempted are exclusive to the task
+ // goroutine.
+ rseqPreempted bool `state:"nosave"`
+ rseqCPUAddr usermem.Addr
+ rseqCPU int32
+
+ // copyScratchBuffer is a buffer available to CopyIn/CopyOut
+ // implementations that require an intermediate buffer to copy data
+ // into/out of. It prevents these buffers from being allocated/zeroed in
+ // each syscall and eventually garbage collected.
+ //
+ // copyScratchBuffer is exclusive to the task goroutine.
+ copyScratchBuffer [copyScratchBufferLen]byte `state:"nosave"`
+
+ // blockingTimer is used for blocking timeouts. blockingTimerChan is the
+ // channel that is sent to when blockingTimer fires.
+ //
+ // blockingTimer is exclusive to the task goroutine.
+ blockingTimer *ktime.Timer `state:"nosave"`
+ blockingTimerChan <-chan struct{} `state:"nosave"`
+
+ // futexWaiter is used for futex(FUTEX_WAIT) syscalls.
+ //
+ // futexWaiter is exclusive to the task goroutine.
+ futexWaiter *futex.Waiter `state:"nosave"`
+
+ // startTime is the real time at which the task started. It is set when
+ // a Task is created or invokes execve(2).
+ //
+ // startTime is protected by mu.
+ startTime ktime.Time
+}
+
+func (t *Task) savePtraceTracer() *Task {
+ return t.ptraceTracer.Load().(*Task)
+}
+
+func (t *Task) loadPtraceTracer(tracer *Task) {
+ t.ptraceTracer.Store(tracer)
+}
+
+func (t *Task) saveLogPrefix() string {
+ return t.logPrefix.Load().(string)
+}
+
+func (t *Task) loadLogPrefix(prefix string) {
+ t.logPrefix.Store(prefix)
+}
+
+func (t *Task) saveSyscallFilters() []bpf.Program {
+ if f := t.syscallFilters.Load(); f != nil {
+ return f.([]bpf.Program)
+ }
+ return nil
+}
+
+func (t *Task) loadSyscallFilters(filters []bpf.Program) {
+ t.syscallFilters.Store(filters)
+}
+
+// afterLoad is invoked by stateify.
+func (t *Task) afterLoad() {
+ t.interruptChan = make(chan struct{}, 1)
+ t.gosched.State = TaskGoroutineNonexistent
+ if t.stop != nil {
+ t.stopCount = 1
+ }
+ t.endStopCond.L = &t.tg.signalHandlers.mu
+ t.p = t.k.Platform.NewContext()
+ t.rseqPreempted = true
+ t.futexWaiter = futex.NewWaiter()
+}
+
+// copyScratchBufferLen is the length of Task.copyScratchBuffer.
+const copyScratchBufferLen = 144 // sizeof(struct stat)
+
+// CopyScratchBuffer returns a scratch buffer to be used in CopyIn/CopyOut
+// functions. It must only be used within those functions and can only be used
+// by the task goroutine; it exists to improve performance and thus
+// intentionally lacks any synchronization.
+//
+// Callers should pass a constant value as an argument if possible, which will
+// allow the compiler to inline and optimize out the if statement below.
+func (t *Task) CopyScratchBuffer(size int) []byte {
+ if size > copyScratchBufferLen {
+ return make([]byte, size)
+ }
+ return t.copyScratchBuffer[:size]
+}
+
+// FutexWaiter returns the Task's futex.Waiter.
+func (t *Task) FutexWaiter() *futex.Waiter {
+ return t.futexWaiter
+}
+
+// Kernel returns the Kernel containing t.
+func (t *Task) Kernel() *Kernel {
+ return t.k
+}
+
+// Value implements context.Context.Value.
+//
+// Preconditions: The caller must be running on the task goroutine (as implied
+// by the requirements of context.Context).
+func (t *Task) Value(key interface{}) interface{} {
+ switch key {
+ case CtxCanTrace:
+ return t.CanTrace
+ case CtxKernel:
+ return t.k
+ case CtxPIDNamespace:
+ return t.tg.pidns
+ case CtxUTSNamespace:
+ return t.utsns
+ case CtxIPCNamespace:
+ return t.ipcns
+ case CtxTask:
+ return t
+ case auth.CtxCredentials:
+ return t.creds
+ case context.CtxThreadGroupID:
+ return int32(t.ThreadGroup().ID())
+ case fs.CtxRoot:
+ return t.fsc.RootDirectory()
+ case fs.CtxDirentCacheLimiter:
+ return t.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return t.NetworkContext()
+ case ktime.CtxRealtimeClock:
+ return t.k.RealtimeClock()
+ case limits.CtxLimits:
+ return t.tg.limits
+ case pgalloc.CtxMemoryFile:
+ return t.k.mf
+ case pgalloc.CtxMemoryFileProvider:
+ return t.k
+ case platform.CtxPlatform:
+ return t.k
+ case uniqueid.CtxGlobalUniqueID:
+ return t.k.UniqueID()
+ case uniqueid.CtxGlobalUniqueIDProvider:
+ return t.k
+ case uniqueid.CtxInotifyCookie:
+ return t.k.GenerateInotifyCookie()
+ case unimpl.CtxEvents:
+ return t.k
+ default:
+ return nil
+ }
+}
+
+// SetClearTID sets t's cleartid.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) SetClearTID(addr usermem.Addr) {
+ t.cleartid = addr
+}
+
+// SetSyscallRestartBlock sets the restart block for use in
+// restart_syscall(2). After registering a restart block, a syscall should
+// return ERESTART_RESTARTBLOCK to request a restart using the block.
+//
+// Precondition: The caller must be running on the task goroutine.
+func (t *Task) SetSyscallRestartBlock(r SyscallRestartBlock) {
+ t.syscallRestartBlock = r
+}
+
+// SyscallRestartBlock returns the currently registered restart block for use in
+// restart_syscall(2). This function is *not* idempotent and may be called once
+// per syscall. This function must not be called if a restart block has not been
+// registered for the current syscall.
+//
+// Precondition: The caller must be running on the task goroutine.
+func (t *Task) SyscallRestartBlock() SyscallRestartBlock {
+ r := t.syscallRestartBlock
+ // Explicitly set the restart block to nil so that a future syscall can't
+ // accidentally reuse it.
+ t.syscallRestartBlock = nil
+ return r
+}
+
+// IsChrooted returns true if the root directory of t's FSContext is not the
+// root directory of t's MountNamespace.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) IsChrooted() bool {
+ realRoot := t.k.mounts.Root()
+ defer realRoot.DecRef()
+ root := t.fsc.RootDirectory()
+ if root != nil {
+ defer root.DecRef()
+ }
+ return root != realRoot
+}
+
+// TaskContext returns t's TaskContext.
+//
+// Precondition: The caller must be running on the task goroutine, or t.mu must
+// be locked.
+func (t *Task) TaskContext() *TaskContext {
+ return &t.tc
+}
+
+// FSContext returns t's FSContext. FSContext does not take an additional
+// reference on the returned FSContext.
+//
+// Precondition: The caller must be running on the task goroutine, or t.mu must
+// be locked.
+func (t *Task) FSContext() *FSContext {
+ return t.fsc
+}
+
+// FDMap returns t's FDMap. FDMap does not take an additional reference on the
+// returned FDMap.
+//
+// Precondition: The caller must be running on the task goroutine, or t.mu must
+// be locked.
+func (t *Task) FDMap() *FDMap {
+ return t.fds
+}
+
+// WithMuLocked executes f with t.mu locked.
+func (t *Task) WithMuLocked(f func(*Task)) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ f(t)
+}
+
+// MountNamespace returns t's MountNamespace. MountNamespace does not take an
+// additional reference on the returned MountNamespace.
+func (t *Task) MountNamespace() *fs.MountNamespace {
+ return t.k.mounts
+}
+
+// AbstractSockets returns t's AbstractSocketNamespace.
+func (t *Task) AbstractSockets() *AbstractSocketNamespace {
+ return t.abstractSockets
+}
+
+// ContainerID returns t's container ID.
+func (t *Task) ContainerID() string {
+ return t.containerID
+}
diff --git a/pkg/sentry/kernel/task_acct.go b/pkg/sentry/kernel/task_acct.go
new file mode 100644
index 000000000..1ca2a82eb
--- /dev/null
+++ b/pkg/sentry/kernel/task_acct.go
@@ -0,0 +1,196 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// Accounting, limits, timers.
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Getitimer implements getitimer(2).
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) Getitimer(id int32) (linux.ItimerVal, error) {
+ var tm ktime.Time
+ var s ktime.Setting
+ switch id {
+ case linux.ITIMER_REAL:
+ tm, s = t.tg.itimerRealTimer.Get()
+ case linux.ITIMER_VIRTUAL:
+ tm = t.tg.UserCPUClock().Now()
+ t.tg.signalHandlers.mu.Lock()
+ s, _ = t.tg.itimerVirtSetting.At(tm)
+ t.tg.signalHandlers.mu.Unlock()
+ case linux.ITIMER_PROF:
+ tm = t.tg.CPUClock().Now()
+ t.tg.signalHandlers.mu.Lock()
+ s, _ = t.tg.itimerProfSetting.At(tm)
+ t.tg.signalHandlers.mu.Unlock()
+ default:
+ return linux.ItimerVal{}, syserror.EINVAL
+ }
+ val, iv := ktime.SpecFromSetting(tm, s)
+ return linux.ItimerVal{
+ Value: linux.DurationToTimeval(val),
+ Interval: linux.DurationToTimeval(iv),
+ }, nil
+}
+
+// Setitimer implements setitimer(2).
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) Setitimer(id int32, newitv linux.ItimerVal) (linux.ItimerVal, error) {
+ var tm ktime.Time
+ var olds ktime.Setting
+ switch id {
+ case linux.ITIMER_REAL:
+ news, err := ktime.SettingFromSpec(newitv.Value.ToDuration(), newitv.Interval.ToDuration(), t.tg.itimerRealTimer.Clock())
+ if err != nil {
+ return linux.ItimerVal{}, err
+ }
+ tm, olds = t.tg.itimerRealTimer.Swap(news)
+ case linux.ITIMER_VIRTUAL:
+ c := t.tg.UserCPUClock()
+ var err error
+ t.k.cpuClockTicker.Atomically(func() {
+ tm = c.Now()
+ var news ktime.Setting
+ news, err = ktime.SettingFromSpecAt(newitv.Value.ToDuration(), newitv.Interval.ToDuration(), tm)
+ if err != nil {
+ return
+ }
+ t.tg.signalHandlers.mu.Lock()
+ olds = t.tg.itimerVirtSetting
+ t.tg.itimerVirtSetting = news
+ t.tg.updateCPUTimersEnabledLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ })
+ if err != nil {
+ return linux.ItimerVal{}, err
+ }
+ case linux.ITIMER_PROF:
+ c := t.tg.CPUClock()
+ var err error
+ t.k.cpuClockTicker.Atomically(func() {
+ tm = c.Now()
+ var news ktime.Setting
+ news, err = ktime.SettingFromSpecAt(newitv.Value.ToDuration(), newitv.Interval.ToDuration(), tm)
+ if err != nil {
+ return
+ }
+ t.tg.signalHandlers.mu.Lock()
+ olds = t.tg.itimerProfSetting
+ t.tg.itimerProfSetting = news
+ t.tg.updateCPUTimersEnabledLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ })
+ if err != nil {
+ return linux.ItimerVal{}, err
+ }
+ default:
+ return linux.ItimerVal{}, syserror.EINVAL
+ }
+ oldval, oldiv := ktime.SpecFromSetting(tm, olds)
+ return linux.ItimerVal{
+ Value: linux.DurationToTimeval(oldval),
+ Interval: linux.DurationToTimeval(oldiv),
+ }, nil
+}
+
+// IOUsage returns the io usage of the thread.
+func (t *Task) IOUsage() *usage.IO {
+ return t.ioUsage
+}
+
+// IOUsage returns the total io usage of all dead and live threads in the group.
+func (tg *ThreadGroup) IOUsage() *usage.IO {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+
+ io := *tg.ioUsage
+ // Account for active tasks.
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ io.Accumulate(t.IOUsage())
+ }
+ return &io
+}
+
+// Name returns t's name.
+func (t *Task) Name() string {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.tc.Name
+}
+
+// SetName changes t's name.
+func (t *Task) SetName(name string) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.tc.Name = name
+ t.Debugf("Set thread name to %q", name)
+}
+
+// Limits implements context.Context.Limits.
+func (t *Task) Limits() *limits.LimitSet {
+ return t.ThreadGroup().Limits()
+}
+
+// StartTime returns t's start time.
+func (t *Task) StartTime() ktime.Time {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.startTime
+}
+
+// MaxRSS returns the maximum resident set size of the task in bytes. which
+// should be one of RUSAGE_SELF, RUSAGE_CHILDREN, RUSAGE_THREAD, or
+// RUSAGE_BOTH. See getrusage(2) for documentation on the behavior of these
+// flags.
+func (t *Task) MaxRSS(which int32) uint64 {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+
+ switch which {
+ case linux.RUSAGE_SELF, linux.RUSAGE_THREAD:
+ // If there's an active mm we can use its value.
+ if mm := t.MemoryManager(); mm != nil {
+ if mmMaxRSS := mm.MaxResidentSetSize(); mmMaxRSS > t.tg.maxRSS {
+ return mmMaxRSS
+ }
+ }
+ return t.tg.maxRSS
+ case linux.RUSAGE_CHILDREN:
+ return t.tg.childMaxRSS
+ case linux.RUSAGE_BOTH:
+ maxRSS := t.tg.maxRSS
+ if maxRSS < t.tg.childMaxRSS {
+ maxRSS = t.tg.childMaxRSS
+ }
+ if mm := t.MemoryManager(); mm != nil {
+ if mmMaxRSS := mm.MaxResidentSetSize(); mmMaxRSS > maxRSS {
+ return mmMaxRSS
+ }
+ }
+ return maxRSS
+ default:
+ // We'll only get here if which is invalid.
+ return 0
+ }
+}
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
new file mode 100644
index 000000000..1c76c4d84
--- /dev/null
+++ b/pkg/sentry/kernel/task_block.go
@@ -0,0 +1,212 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "time"
+
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// BlockWithTimeout blocks t until an event is received from C, the application
+// monotonic clock indicates that timeout has elapsed (only if haveTimeout is true),
+// or t is interrupted. It returns:
+//
+// - The remaining timeout, which is guaranteed to be 0 if the timeout expired,
+// and is unspecified if haveTimeout is false.
+//
+// - An error which is nil if an event is received from C, ETIMEDOUT if the timeout
+// expired, and syserror.ErrInterrupted if t is interrupted.
+func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.Duration) (time.Duration, error) {
+ if !haveTimeout {
+ return timeout, t.block(C, nil)
+ }
+
+ start := t.Kernel().MonotonicClock().Now()
+ deadline := start.Add(timeout)
+ err := t.BlockWithDeadline(C, true, deadline)
+
+ // Timeout, explicitly return a remaining duration of 0.
+ if err == syserror.ETIMEDOUT {
+ return 0, err
+ }
+
+ // Compute the remaining timeout. Note that even if block() above didn't
+ // return due to a timeout, we may have used up any of the remaining time
+ // since then. We cap the remaining timeout to 0 to make it easier to
+ // directly use the returned duration.
+ end := t.Kernel().MonotonicClock().Now()
+ remainingTimeout := timeout - end.Sub(start)
+ if remainingTimeout < 0 {
+ remainingTimeout = 0
+ }
+
+ return remainingTimeout, err
+}
+
+// BlockWithDeadline blocks t until an event is received from C, the
+// application monotonic clock indicates a time of deadline (only if
+// haveDeadline is true), or t is interrupted. It returns nil if an event is
+// received from C, ETIMEDOUT if the deadline expired, and
+// syserror.ErrInterrupted if t is interrupted.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) BlockWithDeadline(C chan struct{}, haveDeadline bool, deadline ktime.Time) error {
+ if !haveDeadline {
+ return t.block(C, nil)
+ }
+
+ // Start the timeout timer.
+ t.blockingTimer.Swap(ktime.Setting{
+ Enabled: true,
+ Next: deadline,
+ })
+
+ err := t.block(C, t.blockingTimerChan)
+
+ // Stop the timeout timer and drain the channel.
+ t.blockingTimer.Swap(ktime.Setting{})
+ select {
+ case <-t.blockingTimerChan:
+ default:
+ }
+
+ return err
+}
+
+// BlockWithTimer blocks t until an event is received from C or tchan, or t is
+// interrupted. It returns nil if an event is received from C, ETIMEDOUT if an
+// event is received from tchan, and syserror.ErrInterrupted if t is
+// interrupted.
+//
+// Most clients should use BlockWithDeadline or BlockWithTimeout instead.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) BlockWithTimer(C <-chan struct{}, tchan <-chan struct{}) error {
+ return t.block(C, tchan)
+}
+
+// Block blocks t until an event is received from C or t is interrupted. It
+// returns nil if an event is received from C and syserror.ErrInterrupted if t
+// is interrupted.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) Block(C <-chan struct{}) error {
+ return t.block(C, nil)
+}
+
+// block blocks a task on one of many events.
+// N.B. defer is too expensive to be used here.
+func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error {
+ // Fast path if the request is already done.
+ select {
+ case <-C:
+ return nil
+ default:
+ }
+
+ // Deactive our address space, we don't need it.
+ interrupt := t.SleepStart()
+
+ select {
+ case <-C:
+ t.SleepFinish(true)
+ return nil
+
+ case <-interrupt:
+ t.SleepFinish(false)
+ // Return the indicated error on interrupt.
+ return syserror.ErrInterrupted
+
+ case <-timerChan:
+ // We've timed out.
+ t.SleepFinish(true)
+ return syserror.ETIMEDOUT
+ }
+}
+
+// SleepStart implements amutex.Sleeper.SleepStart.
+func (t *Task) SleepStart() <-chan struct{} {
+ t.Deactivate()
+ t.accountTaskGoroutineEnter(TaskGoroutineBlockedInterruptible)
+ return t.interruptChan
+}
+
+// SleepFinish implements amutex.Sleeper.SleepFinish.
+func (t *Task) SleepFinish(success bool) {
+ if !success {
+ // The interrupted notification is consumed only at the top-level
+ // (Run). Therefore we attempt to reset the pending notification.
+ // This will also elide our next entry back into the task, so we
+ // will process signals, state changes, etc.
+ t.interruptSelf()
+ }
+ t.accountTaskGoroutineLeave(TaskGoroutineBlockedInterruptible)
+ t.Activate()
+}
+
+// Interrupted implements amutex.Sleeper.Interrupted
+func (t *Task) Interrupted() bool {
+ return len(t.interruptChan) != 0
+}
+
+// UninterruptibleSleepStart implements context.Context.UninterruptibleSleepStart.
+func (t *Task) UninterruptibleSleepStart(deactivate bool) {
+ if deactivate {
+ t.Deactivate()
+ }
+ t.accountTaskGoroutineEnter(TaskGoroutineBlockedUninterruptible)
+}
+
+// UninterruptibleSleepFinish implements context.Context.UninterruptibleSleepFinish.
+func (t *Task) UninterruptibleSleepFinish(activate bool) {
+ t.accountTaskGoroutineLeave(TaskGoroutineBlockedUninterruptible)
+ if activate {
+ t.Activate()
+ }
+}
+
+// interrupted returns true if interrupt or interruptSelf has been called at
+// least once since the last call to interrupted.
+func (t *Task) interrupted() bool {
+ select {
+ case <-t.interruptChan:
+ return true
+ default:
+ return false
+ }
+}
+
+// interrupt unblocks the task and interrupts it if it's currently running in
+// userspace.
+func (t *Task) interrupt() {
+ t.interruptSelf()
+ t.p.Interrupt()
+}
+
+// interruptSelf is like Interrupt, but can only be called by the task
+// goroutine.
+func (t *Task) interruptSelf() {
+ select {
+ case t.interruptChan <- struct{}{}:
+ t.Debugf("Interrupt queued")
+ default:
+ t.Debugf("Dropping duplicate interrupt")
+ }
+ // platform.Context.Interrupt() is unnecessary since a task goroutine
+ // calling interruptSelf() cannot also be blocked in
+ // platform.Context.Switch().
+}
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
new file mode 100644
index 000000000..bba8ddd39
--- /dev/null
+++ b/pkg/sentry/kernel/task_clone.go
@@ -0,0 +1,516 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/bpf"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// SharingOptions controls what resources are shared by a new task created by
+// Task.Clone, or an existing task affected by Task.Unshare.
+type SharingOptions struct {
+ // If NewAddressSpace is true, the task should have an independent virtual
+ // address space.
+ NewAddressSpace bool
+
+ // If NewSignalHandlers is true, the task should use an independent set of
+ // signal handlers.
+ NewSignalHandlers bool
+
+ // If NewThreadGroup is true, the task should be the leader of its own
+ // thread group. TerminationSignal is the signal that the thread group
+ // will send to its parent when it exits. If NewThreadGroup is false,
+ // TerminationSignal is ignored.
+ NewThreadGroup bool
+ TerminationSignal linux.Signal
+
+ // If NewPIDNamespace is true:
+ //
+ // - In the context of Task.Clone, the new task should be the init task
+ // (TID 1) in a new PID namespace.
+ //
+ // - In the context of Task.Unshare, the task should create a new PID
+ // namespace, and all subsequent clones of the task should be members of
+ // the new PID namespace.
+ NewPIDNamespace bool
+
+ // If NewUserNamespace is true, the task should have an independent user
+ // namespace.
+ NewUserNamespace bool
+
+ // If NewNetworkNamespace is true, the task should have an independent
+ // network namespace. (Note that network namespaces are not really
+ // implemented; see comment on Task.netns for details.)
+ NewNetworkNamespace bool
+
+ // If NewFiles is true, the task should use an independent file descriptor
+ // table.
+ NewFiles bool
+
+ // If NewFSContext is true, the task should have an independent FSContext.
+ NewFSContext bool
+
+ // If NewUTSNamespace is true, the task should have an independent UTS
+ // namespace.
+ NewUTSNamespace bool
+
+ // If NewIPCNamespace is true, the task should have an independent IPC
+ // namespace.
+ NewIPCNamespace bool
+}
+
+// CloneOptions controls the behavior of Task.Clone.
+type CloneOptions struct {
+ // SharingOptions defines the set of resources that the new task will share
+ // with its parent.
+ SharingOptions
+
+ // Stack is the initial stack pointer of the new task. If Stack is 0, the
+ // new task will start with the same stack pointer as its parent.
+ Stack usermem.Addr
+
+ // If SetTLS is true, set the new task's TLS (thread-local storage)
+ // descriptor to TLS. If SetTLS is false, TLS is ignored.
+ SetTLS bool
+ TLS usermem.Addr
+
+ // If ChildClearTID is true, when the child exits, 0 is written to the
+ // address ChildTID in the child's memory, and if the write is successful a
+ // futex wake on the same address is performed.
+ //
+ // If ChildSetTID is true, the child's thread ID (in the child's PID
+ // namespace) is written to address ChildTID in the child's memory. (As in
+ // Linux, failed writes are silently ignored.)
+ ChildClearTID bool
+ ChildSetTID bool
+ ChildTID usermem.Addr
+
+ // If ParentSetTID is true, the child's thread ID (in the parent's PID
+ // namespace) is written to address ParentTID in the parent's memory. (As
+ // in Linux, failed writes are silently ignored.)
+ //
+ // Older versions of the clone(2) man page state that CLONE_PARENT_SETTID
+ // causes the child's thread ID to be written to ptid in both the parent
+ // and child's memory, but this is a documentation error fixed by
+ // 87ab04792ced ("clone.2: Fix description of CLONE_PARENT_SETTID").
+ ParentSetTID bool
+ ParentTID usermem.Addr
+
+ // If Vfork is true, place the parent in vforkStop until the cloned task
+ // releases its TaskContext.
+ Vfork bool
+
+ // If Untraced is true, do not report PTRACE_EVENT_CLONE/FORK/VFORK for
+ // this clone(), and do not ptrace-attach the caller's tracer to the new
+ // task. (PTRACE_EVENT_VFORK_DONE will still be reported if appropriate).
+ Untraced bool
+
+ // If InheritTracer is true, ptrace-attach the caller's tracer to the new
+ // task, even if no PTRACE_EVENT_CLONE/FORK/VFORK event would be reported
+ // for it. If both Untraced and InheritTracer are true, no event will be
+ // reported, but tracer inheritance will still occur.
+ InheritTracer bool
+}
+
+// Clone implements the clone(2) syscall and returns the thread ID of the new
+// task in t's PID namespace. Clone may return both a non-zero thread ID and a
+// non-nil error.
+//
+// Preconditions: The caller must be running Task.doSyscallInvoke on the task
+// goroutine.
+func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
+ // Since signal actions may refer to application signal handlers by virtual
+ // address, any set of signal handlers must refer to the same address
+ // space.
+ if !opts.NewSignalHandlers && opts.NewAddressSpace {
+ return 0, nil, syserror.EINVAL
+ }
+ // In order for the behavior of thread-group-directed signals to be sane,
+ // all tasks in a thread group must share signal handlers.
+ if !opts.NewThreadGroup && opts.NewSignalHandlers {
+ return 0, nil, syserror.EINVAL
+ }
+ // All tasks in a thread group must be in the same PID namespace.
+ if !opts.NewThreadGroup && (opts.NewPIDNamespace || t.childPIDNamespace != nil) {
+ return 0, nil, syserror.EINVAL
+ }
+ // The two different ways of specifying a new PID namespace are
+ // incompatible.
+ if opts.NewPIDNamespace && t.childPIDNamespace != nil {
+ return 0, nil, syserror.EINVAL
+ }
+ // Thread groups and FS contexts cannot span user namespaces.
+ if opts.NewUserNamespace && (!opts.NewThreadGroup || !opts.NewFSContext) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "If CLONE_NEWUSER is specified along with other CLONE_NEW* flags in a
+ // single clone(2) or unshare(2) call, the user namespace is guaranteed to
+ // be created first, giving the child (clone(2)) or caller (unshare(2))
+ // privileges over the remaining namespaces created by the call." -
+ // user_namespaces(7)
+ creds := t.Credentials()
+ userns := creds.UserNamespace
+ if opts.NewUserNamespace {
+ var err error
+ // "EPERM (since Linux 3.9): CLONE_NEWUSER was specified in flags and
+ // the caller is in a chroot environment (i.e., the caller's root
+ // directory does not match the root directory of the mount namespace
+ // in which it resides)." - clone(2). Neither chroot(2) nor
+ // user_namespaces(7) document this.
+ if t.IsChrooted() {
+ return 0, nil, syserror.EPERM
+ }
+ userns, err = creds.NewChildUserNamespace()
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if (opts.NewPIDNamespace || opts.NewNetworkNamespace || opts.NewUTSNamespace) && !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, userns) {
+ return 0, nil, syserror.EPERM
+ }
+
+ utsns := t.UTSNamespace()
+ if opts.NewUTSNamespace {
+ // Note that this must happen after NewUserNamespace so we get
+ // the new userns if there is one.
+ utsns = t.UTSNamespace().Clone(userns)
+ }
+
+ ipcns := t.IPCNamespace()
+ if opts.NewIPCNamespace {
+ // Note that "If CLONE_NEWIPC is set, then create the process in a new IPC
+ // namespace"
+ ipcns = NewIPCNamespace(userns)
+ }
+
+ tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace)
+ if err != nil {
+ return 0, nil, err
+ }
+ // clone() returns 0 in the child.
+ tc.Arch.SetReturn(0)
+ if opts.Stack != 0 {
+ tc.Arch.SetStack(uintptr(opts.Stack))
+ }
+ if opts.SetTLS {
+ if !tc.Arch.SetTLS(uintptr(opts.TLS)) {
+ return 0, nil, syserror.EPERM
+ }
+ }
+
+ var fsc *FSContext
+ if opts.NewFSContext {
+ fsc = t.fsc.Fork()
+ } else {
+ fsc = t.fsc
+ fsc.IncRef()
+ }
+
+ var fds *FDMap
+ if opts.NewFiles {
+ fds = t.fds.Fork()
+ } else {
+ fds = t.fds
+ fds.IncRef()
+ }
+
+ pidns := t.tg.pidns
+ if t.childPIDNamespace != nil {
+ pidns = t.childPIDNamespace
+ } else if opts.NewPIDNamespace {
+ pidns = pidns.NewChild(userns)
+ }
+ tg := t.tg
+ if opts.NewThreadGroup {
+ sh := t.tg.signalHandlers
+ if opts.NewSignalHandlers {
+ sh = sh.Fork()
+ }
+ tg = t.k.newThreadGroup(pidns, sh, opts.TerminationSignal, tg.limits.GetCopy(), t.k.monotonicClock)
+ }
+
+ cfg := &TaskConfig{
+ Kernel: t.k,
+ ThreadGroup: tg,
+ SignalMask: t.SignalMask(),
+ TaskContext: tc,
+ FSContext: fsc,
+ FDMap: fds,
+ Credentials: creds,
+ Niceness: t.Niceness(),
+ NetworkNamespaced: t.netns,
+ AllowedCPUMask: t.CPUMask(),
+ UTSNamespace: utsns,
+ IPCNamespace: ipcns,
+ AbstractSocketNamespace: t.abstractSockets,
+ ContainerID: t.ContainerID(),
+ }
+ if opts.NewThreadGroup {
+ cfg.Parent = t
+ } else {
+ cfg.InheritParent = t
+ }
+ if opts.NewNetworkNamespace {
+ cfg.NetworkNamespaced = true
+ }
+ nt, err := t.tg.pidns.owner.NewTask(cfg)
+ if err != nil {
+ if opts.NewThreadGroup {
+ tg.release()
+ }
+ return 0, nil, err
+ }
+
+ // "A child process created via fork(2) inherits a copy of its parent's
+ // alternate signal stack settings" - sigaltstack(2).
+ //
+ // However kernel/fork.c:copy_process() adds a limitation to this:
+ // "sigaltstack should be cleared when sharing the same VM".
+ if opts.NewAddressSpace || opts.Vfork {
+ nt.SetSignalStack(t.SignalStack())
+ }
+
+ if userns != creds.UserNamespace {
+ if err := nt.SetUserNamespace(userns); err != nil {
+ // This shouldn't be possible: userns was created from nt.creds, so
+ // nt should have CAP_SYS_ADMIN in userns.
+ panic("Task.Clone: SetUserNamespace failed: " + err.Error())
+ }
+ }
+
+ // This has to happen last, because e.g. ptraceClone may send a SIGSTOP to
+ // nt that it must receive before its task goroutine starts running.
+ tid := nt.k.tasks.Root.IDOfTask(nt)
+ defer nt.Start(tid)
+
+ // "If fork/clone and execve are allowed by @prog, any child processes will
+ // be constrained to the same filters and system call ABI as the parent." -
+ // Documentation/prctl/seccomp_filter.txt
+ if f := t.syscallFilters.Load(); f != nil {
+ copiedFilters := append([]bpf.Program(nil), f.([]bpf.Program)...)
+ nt.syscallFilters.Store(copiedFilters)
+ }
+ if opts.Vfork {
+ nt.vforkParent = t
+ }
+
+ if opts.ChildClearTID {
+ nt.SetClearTID(opts.ChildTID)
+ }
+ if opts.ChildSetTID {
+ // Can't use Task.CopyOut, which assumes AddressSpaceActive.
+ usermem.CopyObjectOut(t, nt.MemoryManager(), opts.ChildTID, nt.ThreadID(), usermem.IOOpts{})
+ }
+ ntid := t.tg.pidns.IDOfTask(nt)
+ if opts.ParentSetTID {
+ t.CopyOut(opts.ParentTID, ntid)
+ }
+
+ kind := ptraceCloneKindClone
+ if opts.Vfork {
+ kind = ptraceCloneKindVfork
+ } else if opts.TerminationSignal == linux.SIGCHLD {
+ kind = ptraceCloneKindFork
+ }
+ if t.ptraceClone(kind, nt, opts) {
+ if opts.Vfork {
+ return ntid, &SyscallControl{next: &runSyscallAfterPtraceEventClone{vforkChild: nt, vforkChildTID: ntid}}, nil
+ }
+ return ntid, &SyscallControl{next: &runSyscallAfterPtraceEventClone{}}, nil
+ }
+ if opts.Vfork {
+ t.maybeBeginVforkStop(nt)
+ return ntid, &SyscallControl{next: &runSyscallAfterVforkStop{childTID: ntid}}, nil
+ }
+ return ntid, nil, nil
+}
+
+// maybeBeginVforkStop checks if a previously-started vfork child is still
+// running and has not yet released its MM, such that its parent t should enter
+// a vforkStop.
+//
+// Preconditions: The caller must be running on t's task goroutine.
+func (t *Task) maybeBeginVforkStop(child *Task) {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if t.killedLocked() {
+ child.vforkParent = nil
+ return
+ }
+ if child.vforkParent == t {
+ t.beginInternalStopLocked((*vforkStop)(nil))
+ }
+}
+
+func (t *Task) unstopVforkParent() {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ if p := t.vforkParent; p != nil {
+ p.tg.signalHandlers.mu.Lock()
+ defer p.tg.signalHandlers.mu.Unlock()
+ if _, ok := p.stop.(*vforkStop); ok {
+ p.endInternalStopLocked()
+ }
+ // Parent no longer needs to be unstopped.
+ t.vforkParent = nil
+ }
+}
+
+// +stateify savable
+type runSyscallAfterPtraceEventClone struct {
+ vforkChild *Task
+
+ // If vforkChild is not nil, vforkChildTID is its thread ID in the parent's
+ // PID namespace. vforkChildTID must be stored since the child may exit and
+ // release its TID before the PTRACE_EVENT stop ends.
+ vforkChildTID ThreadID
+}
+
+func (r *runSyscallAfterPtraceEventClone) execute(t *Task) taskRunState {
+ if r.vforkChild != nil {
+ t.maybeBeginVforkStop(r.vforkChild)
+ return &runSyscallAfterVforkStop{r.vforkChildTID}
+ }
+ return (*runSyscallExit)(nil)
+}
+
+// +stateify savable
+type runSyscallAfterVforkStop struct {
+ // childTID has the same meaning as
+ // runSyscallAfterPtraceEventClone.vforkChildTID.
+ childTID ThreadID
+}
+
+func (r *runSyscallAfterVforkStop) execute(t *Task) taskRunState {
+ t.ptraceVforkDone(r.childTID)
+ return (*runSyscallExit)(nil)
+}
+
+// Unshare changes the set of resources t shares with other tasks, as specified
+// by opts.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) Unshare(opts *SharingOptions) error {
+ // In Linux unshare(2), NewThreadGroup implies NewSignalHandlers and
+ // NewSignalHandlers implies NewAddressSpace. All three flags are no-ops if
+ // t is the only task using its MM, which due to clone(2)'s rules imply
+ // that it is also the only task using its signal handlers / in its thread
+ // group, and cause EINVAL to be returned otherwise.
+ //
+ // Since we don't count the number of tasks using each address space or set
+ // of signal handlers, we reject NewSignalHandlers and NewAddressSpace
+ // altogether, and interpret NewThreadGroup as requiring that t be the only
+ // member of its thread group. This seems to be logically coherent, in the
+ // sense that clone(2) allows a task to share signal handlers and address
+ // spaces with tasks in other thread groups.
+ if opts.NewAddressSpace || opts.NewSignalHandlers {
+ return syserror.EINVAL
+ }
+ if opts.NewThreadGroup {
+ t.tg.signalHandlers.mu.Lock()
+ if t.tg.tasksCount != 1 {
+ t.tg.signalHandlers.mu.Unlock()
+ return syserror.EINVAL
+ }
+ t.tg.signalHandlers.mu.Unlock()
+ // This isn't racy because we're the only living task, and therefore
+ // the only task capable of creating new ones, in our thread group.
+ }
+ if opts.NewUserNamespace {
+ if t.IsChrooted() {
+ return syserror.EPERM
+ }
+ // This temporary is needed because Go.
+ creds := t.Credentials()
+ newUserNS, err := creds.NewChildUserNamespace()
+ if err != nil {
+ return err
+ }
+ err = t.SetUserNamespace(newUserNS)
+ if err != nil {
+ return err
+ }
+ }
+ haveCapSysAdmin := t.HasCapability(linux.CAP_SYS_ADMIN)
+ if opts.NewPIDNamespace {
+ if !haveCapSysAdmin {
+ return syserror.EPERM
+ }
+ t.childPIDNamespace = t.tg.pidns.NewChild(t.UserNamespace())
+ }
+ t.mu.Lock()
+ // Can't defer unlock: DecRefs must occur without holding t.mu.
+ if opts.NewNetworkNamespace {
+ if !haveCapSysAdmin {
+ t.mu.Unlock()
+ return syserror.EPERM
+ }
+ t.netns = true
+ }
+ if opts.NewUTSNamespace {
+ if !haveCapSysAdmin {
+ t.mu.Unlock()
+ return syserror.EPERM
+ }
+ // Note that this must happen after NewUserNamespace, so the
+ // new user namespace is used if there is one.
+ t.utsns = t.utsns.Clone(t.creds.UserNamespace)
+ }
+ if opts.NewIPCNamespace {
+ if !haveCapSysAdmin {
+ t.mu.Unlock()
+ return syserror.EPERM
+ }
+ // Note that "If CLONE_NEWIPC is set, then create the process in a new IPC
+ // namespace"
+ t.ipcns = NewIPCNamespace(t.creds.UserNamespace)
+ }
+ var oldfds *FDMap
+ if opts.NewFiles {
+ oldfds = t.fds
+ t.fds = oldfds.Fork()
+ }
+ var oldfsc *FSContext
+ if opts.NewFSContext {
+ oldfsc = t.fsc
+ t.fsc = oldfsc.Fork()
+ }
+ t.mu.Unlock()
+ if oldfds != nil {
+ oldfds.DecRef()
+ }
+ if oldfsc != nil {
+ oldfsc.DecRef()
+ }
+ return nil
+}
+
+// vforkStop is a TaskStop imposed on a task that creates a child with
+// CLONE_VFORK or vfork(2), that ends when the child task ceases to use its
+// current MM. (Normally, CLONE_VFORK is used in conjunction with CLONE_VM, so
+// that the child and parent share mappings until the child execve()s into a
+// new process image or exits.)
+//
+// +stateify savable
+type vforkStop struct{}
+
+// StopIgnoresKill implements TaskStop.Killable.
+func (*vforkStop) Killable() bool { return true }
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
new file mode 100644
index 000000000..bbd294141
--- /dev/null
+++ b/pkg/sentry/kernel/task_context.go
@@ -0,0 +1,174 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/loader"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+)
+
+var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC)
+
+// Auxmap contains miscellaneous data for the task.
+type Auxmap map[string]interface{}
+
+// TaskContext is the subset of a task's data that is provided by the loader.
+//
+// +stateify savable
+type TaskContext struct {
+ // Name is the thread name set by the prctl(PR_SET_NAME) system call.
+ Name string
+
+ // Arch is the architecture-specific context (registers, etc.)
+ Arch arch.Context
+
+ // MemoryManager is the task's address space.
+ MemoryManager *mm.MemoryManager
+
+ // fu implements futexes in the address space.
+ fu *futex.Manager
+
+ // st is the task's syscall table.
+ st *SyscallTable
+}
+
+// release releases all resources held by the TaskContext. release is called by
+// the task when it execs into a new TaskContext or exits.
+func (tc *TaskContext) release() {
+ // Nil out pointers so that if the task is saved after release, it doesn't
+ // follow the pointers to possibly now-invalid objects.
+ if tc.MemoryManager != nil {
+ // TODO(b/38173783)
+ tc.MemoryManager.DecUsers(context.Background())
+ tc.MemoryManager = nil
+ }
+ tc.fu = nil
+}
+
+// Fork returns a duplicate of tc. The copied TaskContext always has an
+// independent arch.Context. If shareAddressSpace is true, the copied
+// TaskContext shares an address space with the original; otherwise, the copied
+// TaskContext has an independent address space that is initially a duplicate
+// of the original's.
+func (tc *TaskContext) Fork(ctx context.Context, k *Kernel, shareAddressSpace bool) (*TaskContext, error) {
+ newTC := &TaskContext{
+ Name: tc.Name,
+ Arch: tc.Arch.Fork(),
+ st: tc.st,
+ }
+ if shareAddressSpace {
+ newTC.MemoryManager = tc.MemoryManager
+ if newTC.MemoryManager != nil {
+ if !newTC.MemoryManager.IncUsers() {
+ // Shouldn't be possible since tc.MemoryManager should be a
+ // counted user.
+ panic(fmt.Sprintf("TaskContext.Fork called with userless TaskContext.MemoryManager"))
+ }
+ }
+ newTC.fu = tc.fu
+ } else {
+ newMM, err := tc.MemoryManager.Fork(ctx)
+ if err != nil {
+ return nil, err
+ }
+ newTC.MemoryManager = newMM
+ newTC.fu = k.futexes.Fork()
+ }
+ return newTC, nil
+}
+
+// Arch returns t's arch.Context.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) Arch() arch.Context {
+ return t.tc.Arch
+}
+
+// MemoryManager returns t's MemoryManager. MemoryManager does not take an
+// additional reference on the returned MM.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) MemoryManager() *mm.MemoryManager {
+ return t.tc.MemoryManager
+}
+
+// SyscallTable returns t's syscall table.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) SyscallTable() *SyscallTable {
+ return t.tc.st
+}
+
+// Stack returns the userspace stack.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) Stack() *arch.Stack {
+ return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())}
+}
+
+// LoadTaskImage loads filename into a new TaskContext.
+//
+// It takes several arguments:
+// * mounts: MountNamespace to lookup filename in
+// * root: Root to lookup filename under
+// * wd: Working directory to lookup filename under
+// * maxTraversals: maximum number of symlinks to follow
+// * filename: path to binary to load
+// * argv: Binary argv
+// * envv: Binary envv
+// * fs: Binary FeatureSet
+func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) {
+ // Prepare a new user address space to load into.
+ m := mm.NewMemoryManager(k, k)
+ defer m.DecUsers(ctx)
+
+ os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, argv, envv, k.extraAuxv, k.vdso)
+ if err != nil {
+ return nil, err
+ }
+
+ // Lookup our new syscall table.
+ st, ok := LookupSyscallTable(os, ac.Arch())
+ if !ok {
+ // No syscall table found. This means that the ELF binary does not match
+ // the architecture.
+ return nil, errNoSyscalls
+ }
+
+ if !m.IncUsers() {
+ panic("Failed to increment users count on new MM")
+ }
+ return &TaskContext{
+ Name: name,
+ Arch: ac,
+ MemoryManager: m,
+ fu: k.futexes.Fork(),
+ st: st,
+ }, nil
+}
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
new file mode 100644
index 000000000..5d1425d5c
--- /dev/null
+++ b/pkg/sentry/kernel/task_exec.go
@@ -0,0 +1,262 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// This file implements the machinery behind the execve() syscall. In brief, a
+// thread executes an execve() by killing all other threads in its thread
+// group, assuming the leader's identity, and then switching process images.
+//
+// This design is effectively mandated by Linux. From ptrace(2):
+//
+// """
+// execve(2) under ptrace
+// When one thread in a multithreaded process calls execve(2), the
+// kernel destroys all other threads in the process, and resets the
+// thread ID of the execing thread to the thread group ID (process ID).
+// (Or, to put things another way, when a multithreaded process does an
+// execve(2), at completion of the call, it appears as though the
+// execve(2) occurred in the thread group leader, regardless of which
+// thread did the execve(2).) This resetting of the thread ID looks
+// very confusing to tracers:
+//
+// * All other threads stop in PTRACE_EVENT_EXIT stop, if the
+// PTRACE_O_TRACEEXIT option was turned on. Then all other threads
+// except the thread group leader report death as if they exited via
+// _exit(2) with exit code 0.
+//
+// * The execing tracee changes its thread ID while it is in the
+// execve(2). (Remember, under ptrace, the "pid" returned from
+// waitpid(2), or fed into ptrace calls, is the tracee's thread ID.)
+// That is, the tracee's thread ID is reset to be the same as its
+// process ID, which is the same as the thread group leader's thread
+// ID.
+//
+// * Then a PTRACE_EVENT_EXEC stop happens, if the PTRACE_O_TRACEEXEC
+// option was turned on.
+//
+// * If the thread group leader has reported its PTRACE_EVENT_EXIT stop
+// by this time, it appears to the tracer that the dead thread leader
+// "reappears from nowhere". (Note: the thread group leader does not
+// report death via WIFEXITED(status) until there is at least one
+// other live thread. This eliminates the possibility that the
+// tracer will see it dying and then reappearing.) If the thread
+// group leader was still alive, for the tracer this may look as if
+// thread group leader returns from a different system call than it
+// entered, or even "returned from a system call even though it was
+// not in any system call". If the thread group leader was not
+// traced (or was traced by a different tracer), then during
+// execve(2) it will appear as if it has become a tracee of the
+// tracer of the execing tracee.
+//
+// All of the above effects are the artifacts of the thread ID change in
+// the tracee.
+// """
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// execStop is a TaskStop that a task sets on itself when it wants to execve
+// and is waiting for the other tasks in its thread group to exit first.
+//
+// +stateify savable
+type execStop struct{}
+
+// Killable implements TaskStop.Killable.
+func (*execStop) Killable() bool { return true }
+
+// Execve implements the execve(2) syscall by killing all other tasks in its
+// thread group and switching to newTC. Execve always takes ownership of newTC.
+//
+// Preconditions: The caller must be running Task.doSyscallInvoke on the task
+// goroutine.
+func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+
+ if t.tg.exiting || t.tg.execing != nil {
+ // We lost to a racing group-exit, kill, or exec from another thread
+ // and should just exit.
+ newTC.release()
+ return nil, syserror.EINTR
+ }
+
+ // Cancel any racing group stops.
+ t.tg.endGroupStopLocked(false)
+
+ // If the task has any siblings, they have to exit before the exec can
+ // continue.
+ t.tg.execing = t
+ if t.tg.tasks.Front() != t.tg.tasks.Back() {
+ // "[All] other threads except the thread group leader report death as
+ // if they exited via _exit(2) with exit code 0." - ptrace(2)
+ for sibling := t.tg.tasks.Front(); sibling != nil; sibling = sibling.Next() {
+ if t != sibling {
+ sibling.killLocked()
+ }
+ }
+ // The last sibling to exit will wake t.
+ t.beginInternalStopLocked((*execStop)(nil))
+ }
+
+ return &SyscallControl{next: &runSyscallAfterExecStop{newTC}, ignoreReturn: true}, nil
+}
+
+// The runSyscallAfterExecStop state continues execve(2) after all siblings of
+// a thread in the execve syscall have exited.
+//
+// +stateify savable
+type runSyscallAfterExecStop struct {
+ tc *TaskContext
+}
+
+func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
+ t.tg.pidns.owner.mu.Lock()
+ t.tg.execing = nil
+ if t.killed() {
+ t.tg.pidns.owner.mu.Unlock()
+ r.tc.release()
+ return (*runInterrupt)(nil)
+ }
+ // We are the thread group leader now. Save our old thread ID for
+ // PTRACE_EVENT_EXEC. This is racy in that if a tracer attaches after this
+ // point it will get a PID of 0, but this is consistent with Linux.
+ oldTID := ThreadID(0)
+ if tracer := t.Tracer(); tracer != nil {
+ oldTID = tracer.tg.pidns.tids[t]
+ }
+ t.promoteLocked()
+ // "POSIX timers are not preserved (timer_create(2))." - execve(2). Handle
+ // this first since POSIX timers are protected by the signal mutex, which
+ // we're about to change. Note that we have to stop and destroy timers
+ // without holding any mutexes to avoid circular lock ordering.
+ var its []*IntervalTimer
+ t.tg.signalHandlers.mu.Lock()
+ for _, it := range t.tg.timers {
+ its = append(its, it)
+ }
+ t.tg.timers = make(map[linux.TimerID]*IntervalTimer)
+ t.tg.signalHandlers.mu.Unlock()
+ t.tg.pidns.owner.mu.Unlock()
+ for _, it := range its {
+ it.DestroyTimer()
+ }
+ t.tg.pidns.owner.mu.Lock()
+ // "During an execve(2), the dispositions of handled signals are reset to
+ // the default; the dispositions of ignored signals are left unchanged. ...
+ // [The] signal mask is preserved across execve(2). ... [The] pending
+ // signal set is preserved across an execve(2)." - signal(7)
+ //
+ // Details:
+ //
+ // - If the thread group is sharing its signal handlers with another thread
+ // group via CLONE_SIGHAND, execve forces the signal handlers to be copied
+ // (see Linux's fs/exec.c:de_thread). We're not reference-counting signal
+ // handlers, so we always make a copy.
+ //
+ // - "Disposition" only means sigaction::sa_handler/sa_sigaction; flags,
+ // restorer (if present), and mask are always reset. (See Linux's
+ // fs/exec.c:setup_new_exec => kernel/signal.c:flush_signal_handlers.)
+ t.tg.signalHandlers = t.tg.signalHandlers.CopyForExec()
+ t.endStopCond.L = &t.tg.signalHandlers.mu
+ // "Any alternate signal stack is not preserved (sigaltstack(2))." - execve(2)
+ t.signalStack = arch.SignalStack{Flags: arch.SignalStackFlagDisable}
+ // "The termination signal is reset to SIGCHLD (see clone(2))."
+ t.tg.terminationSignal = linux.SIGCHLD
+ // execed indicates that the process can no longer join a process group
+ // in some scenarios (namely, the parent call setpgid(2) on the child).
+ // See the JoinProcessGroup function in sessions.go for more context.
+ t.tg.execed = true
+ // Maximum RSS is preserved across execve(2).
+ t.updateRSSLocked()
+ // Restartable sequence state is discarded.
+ t.rseqPreempted = false
+ t.rseqCPUAddr = 0
+ t.rseqCPU = -1
+ t.tg.rscr.Store(&RSEQCriticalRegion{})
+ t.tg.pidns.owner.mu.Unlock()
+
+ // Remove FDs with the CloseOnExec flag set.
+ t.fds.RemoveIf(func(file *fs.File, flags FDFlags) bool {
+ return flags.CloseOnExec
+ })
+
+ // Switch to the new process.
+ t.MemoryManager().Deactivate()
+ t.mu.Lock()
+ // Update credentials to reflect the execve. This should precede switching
+ // MMs to ensure that dumpability has been reset first, if needed.
+ t.updateCredsForExecLocked()
+ t.tc.release()
+ t.tc = *r.tc
+ t.mu.Unlock()
+ t.unstopVforkParent()
+ // NOTE(b/30316266): All locks must be dropped prior to calling Activate.
+ t.MemoryManager().Activate()
+
+ t.ptraceExec(oldTID)
+ return (*runSyscallExit)(nil)
+}
+
+// promoteLocked makes t the leader of its thread group. If t is already the
+// thread group leader, promoteLocked is a no-op.
+//
+// Preconditions: All other tasks in t's thread group, including the existing
+// leader (if it is not t), have reached TaskExitZombie. The TaskSet mutex must
+// be locked for writing.
+func (t *Task) promoteLocked() {
+ oldLeader := t.tg.leader
+ if t == oldLeader {
+ return
+ }
+ // Swap the leader's TIDs with the execing task's. The latter will be
+ // released when the old leader is reaped below.
+ for ns := t.tg.pidns; ns != nil; ns = ns.parent {
+ oldTID, leaderTID := ns.tids[t], ns.tids[oldLeader]
+ ns.tids[oldLeader] = oldTID
+ ns.tids[t] = leaderTID
+ ns.tasks[oldTID] = oldLeader
+ ns.tasks[leaderTID] = t
+ // Neither the ThreadGroup nor TGID change, so no need to
+ // update ns.tgids.
+ }
+
+ // Inherit the old leader's start time.
+ oldStartTime := oldLeader.StartTime()
+ t.mu.Lock()
+ t.startTime = oldStartTime
+ t.mu.Unlock()
+
+ t.tg.leader = t
+ t.Infof("Becoming TID %d (in root PID namespace)", t.tg.pidns.owner.Root.tids[t])
+ t.updateLogPrefixLocked()
+ // Reap the original leader. If it has a tracer, detach it instead of
+ // waiting for it to acknowledge the original leader's death.
+ oldLeader.exitParentNotified = true
+ oldLeader.exitParentAcked = true
+ if tracer := oldLeader.Tracer(); tracer != nil {
+ delete(tracer.ptraceTracees, oldLeader)
+ oldLeader.forgetTracerLocked()
+ // Notify the tracer that it will no longer be receiving these events
+ // from the tracee.
+ tracer.tg.eventQueue.Notify(EventExit | EventTraceeStop | EventGroupContinue)
+ }
+ oldLeader.exitNotifyLocked(false)
+}
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
new file mode 100644
index 000000000..158e665d3
--- /dev/null
+++ b/pkg/sentry/kernel/task_exit.go
@@ -0,0 +1,1159 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// This file implements the task exit cycle:
+//
+// - Tasks are asynchronously requested to exit with Task.Kill.
+//
+// - When able, the task goroutine enters the exit path starting from state
+// runExit.
+//
+// - Other tasks observe completed exits with Task.Wait (which implements the
+// wait*() family of syscalls).
+
+import (
+ "errors"
+ "fmt"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// An ExitStatus is a value communicated from an exiting task or thread group
+// to the party that reaps it.
+//
+// +stateify savable
+type ExitStatus struct {
+ // Code is the numeric value passed to the call to exit or exit_group that
+ // caused the exit. If the exit was not caused by such a call, Code is 0.
+ Code int
+
+ // Signo is the signal that caused the exit. If the exit was not caused by
+ // a signal, Signo is 0.
+ Signo int
+}
+
+// Signaled returns true if the ExitStatus indicates that the exiting task or
+// thread group was killed by a signal.
+func (es ExitStatus) Signaled() bool {
+ return es.Signo != 0
+}
+
+// Status returns the numeric representation of the ExitStatus returned by e.g.
+// the wait4() system call.
+func (es ExitStatus) Status() uint32 {
+ return ((uint32(es.Code) & 0xff) << 8) | (uint32(es.Signo) & 0xff)
+}
+
+// ShellExitCode returns the numeric exit code that Bash would return for an
+// exit status of es.
+func (es ExitStatus) ShellExitCode() int {
+ if es.Signaled() {
+ return 128 + es.Signo
+ }
+ return es.Code
+}
+
+// TaskExitState represents a step in the task exit path.
+//
+// "Exiting" and "exited" are often ambiguous; prefer to name specific states.
+type TaskExitState int
+
+const (
+ // TaskExitNone indicates that the task has not begun exiting.
+ TaskExitNone TaskExitState = iota
+
+ // TaskExitInitiated indicates that the task goroutine has entered the exit
+ // path, and the task is no longer eligible to participate in group stops
+ // or group signal handling. TaskExitInitiated is analogous to Linux's
+ // PF_EXITING.
+ TaskExitInitiated
+
+ // TaskExitZombie indicates that the task has released its resources, and
+ // the task no longer prevents a sibling thread from completing execve.
+ TaskExitZombie
+
+ // TaskExitDead indicates that the task's thread IDs have been released,
+ // and the task no longer prevents its thread group leader from being
+ // reaped. ("Reaping" refers to the transitioning of a task from
+ // TaskExitZombie to TaskExitDead.)
+ TaskExitDead
+)
+
+// String implements fmt.Stringer.
+func (t TaskExitState) String() string {
+ switch t {
+ case TaskExitNone:
+ return "TaskExitNone"
+ case TaskExitInitiated:
+ return "TaskExitInitiated"
+ case TaskExitZombie:
+ return "TaskExitZombie"
+ case TaskExitDead:
+ return "TaskExitDead"
+ default:
+ return strconv.Itoa(int(t))
+ }
+}
+
+// killLocked marks t as killed by enqueueing a SIGKILL, without causing the
+// thread-group-affecting side effects SIGKILL usually has.
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) killLocked() {
+ // Clear killable stops.
+ if t.stop != nil && t.stop.Killable() {
+ t.endInternalStopLocked()
+ }
+ t.pendingSignals.enqueue(&arch.SignalInfo{
+ Signo: int32(linux.SIGKILL),
+ // Linux just sets SIGKILL in the pending signal bitmask without
+ // enqueueing an actual siginfo, such that
+ // kernel/signal.c:collect_signal() initializes si_code to SI_USER.
+ Code: arch.SignalInfoUser,
+ }, nil)
+ t.interrupt()
+}
+
+// killed returns true if t has a SIGKILL pending. killed is analogous to
+// Linux's fatal_signal_pending().
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) killed() bool {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.killedLocked()
+}
+
+func (t *Task) killedLocked() bool {
+ return t.pendingSignals.pendingSet&linux.SignalSetOf(linux.SIGKILL) != 0
+}
+
+// PrepareExit indicates an exit with status es.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) PrepareExit(es ExitStatus) {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.exitStatus = es
+}
+
+// PrepareGroupExit indicates a group exit with status es to t's thread group.
+//
+// PrepareGroupExit is analogous to Linux's do_group_exit(), except that it
+// does not tail-call do_exit(), except that it *does* set Task.exitStatus.
+// (Linux does not do so until within do_exit(), since it reuses exit_code for
+// ptrace.)
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) PrepareGroupExit(es ExitStatus) {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if t.tg.exiting || t.tg.execing != nil {
+ // Note that if t.tg.exiting is false but t.tg.execing is not nil, i.e.
+ // this "group exit" is being executed by the killed sibling of an
+ // execing task, then Task.Execve never set t.tg.exitStatus, so it's
+ // still the zero value. This is consistent with Linux, both in intent
+ // ("all other threads ... report death as if they exited via _exit(2)
+ // with exit code 0" - ptrace(2), "execve under ptrace") and in
+ // implementation (compare fs/exec.c:de_thread() =>
+ // kernel/signal.c:zap_other_threads() and
+ // kernel/exit.c:do_group_exit() =>
+ // include/linux/sched.h:signal_group_exit()).
+ t.exitStatus = t.tg.exitStatus
+ return
+ }
+ t.tg.exiting = true
+ t.tg.exitStatus = es
+ t.exitStatus = es
+ for sibling := t.tg.tasks.Front(); sibling != nil; sibling = sibling.Next() {
+ if sibling != t {
+ sibling.killLocked()
+ }
+ }
+}
+
+// Kill requests that all tasks in ts exit as if group exiting with status es.
+// Kill does not wait for tasks to exit.
+//
+// Kill has no analogue in Linux; it's provided for save/restore only.
+func (ts *TaskSet) Kill(es ExitStatus) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.Root.exiting = true
+ for t := range ts.Root.tids {
+ t.tg.signalHandlers.mu.Lock()
+ if !t.tg.exiting {
+ t.tg.exiting = true
+ t.tg.exitStatus = es
+ }
+ t.killLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ }
+}
+
+// advanceExitStateLocked checks that t's current exit state is oldExit, then
+// sets it to newExit. If t's current exit state is not oldExit,
+// advanceExitStateLocked panics.
+//
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) advanceExitStateLocked(oldExit, newExit TaskExitState) {
+ if t.exitState != oldExit {
+ panic(fmt.Sprintf("Transitioning from exit state %v to %v: unexpected preceding state %v", oldExit, newExit, t.exitState))
+ }
+ t.Debugf("Transitioning from exit state %v to %v", oldExit, newExit)
+ t.exitState = newExit
+}
+
+// runExit is the entry point into the task exit path.
+//
+// +stateify savable
+type runExit struct{}
+
+func (*runExit) execute(t *Task) taskRunState {
+ t.ptraceExit()
+ return (*runExitMain)(nil)
+}
+
+// +stateify savable
+type runExitMain struct{}
+
+func (*runExitMain) execute(t *Task) taskRunState {
+ lastExiter := t.exitThreadGroup()
+
+ // If the task has a cleartid, and the thread group wasn't killed by a
+ // signal, handle that before releasing the MM.
+ if t.cleartid != 0 {
+ t.tg.signalHandlers.mu.Lock()
+ signaled := t.tg.exiting && t.tg.exitStatus.Signaled()
+ t.tg.signalHandlers.mu.Unlock()
+ if !signaled {
+ if _, err := t.CopyOut(t.cleartid, ThreadID(0)); err == nil {
+ t.Futex().Wake(t, t.cleartid, false, ^uint32(0), 1)
+ }
+ // If the CopyOut fails, there's nothing we can do.
+ }
+ }
+
+ // Deactivate the address space and update max RSS before releasing the
+ // task's MM.
+ t.Deactivate()
+ t.tg.pidns.owner.mu.Lock()
+ t.updateRSSLocked()
+ t.tg.pidns.owner.mu.Unlock()
+ t.mu.Lock()
+ t.tc.release()
+ t.mu.Unlock()
+
+ // Releasing the MM unblocks a blocked CLONE_VFORK parent.
+ t.unstopVforkParent()
+
+ t.fsc.DecRef()
+ t.fds.DecRef()
+
+ // If this is the last task to exit from the thread group, release the
+ // thread group's resources.
+ if lastExiter {
+ t.tg.release()
+ }
+
+ // Detach tracees.
+ t.exitPtrace()
+
+ // Reparent the task's children.
+ t.exitChildren()
+
+ // Don't tail-call runExitNotify, as exitChildren may have initiated a stop
+ // to wait for a PID namespace to die.
+ return (*runExitNotify)(nil)
+}
+
+// exitThreadGroup transitions t to TaskExitInitiated, indicating to t's thread
+// group that it is no longer eligible to participate in group activities. It
+// returns true if t is the last task in its thread group to call
+// exitThreadGroup.
+func (t *Task) exitThreadGroup() bool {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ t.tg.signalHandlers.mu.Lock()
+ // Can't defer unlock: see below.
+
+ t.advanceExitStateLocked(TaskExitNone, TaskExitInitiated)
+ t.tg.activeTasks--
+ last := t.tg.activeTasks == 0
+
+ // Ensure that someone will handle the signals we can't.
+ t.setSignalMaskLocked(^linux.SignalSet(0))
+
+ // Check if this task's exit interacts with an initiated group stop.
+ if !t.groupStopPending {
+ t.tg.signalHandlers.mu.Unlock()
+ return last
+ }
+ t.groupStopPending = false
+ sig := t.tg.groupStopSignal
+ notifyParent := t.participateGroupStopLocked()
+ // signalStop must be called with t's signal mutex unlocked.
+ t.tg.signalHandlers.mu.Unlock()
+ if notifyParent && t.tg.leader.parent != nil {
+ t.tg.leader.parent.signalStop(t, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop)
+ }
+ return last
+}
+
+func (t *Task) exitChildren() {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ newParent := t.findReparentTargetLocked()
+ if newParent == nil {
+ // "If the init process of a PID namespace terminates, the kernel
+ // terminates all of the processes in the namespace via a SIGKILL
+ // signal." - pid_namespaces(7)
+ t.Debugf("Init process terminating, killing namespace")
+ t.tg.pidns.exiting = true
+ for other := range t.tg.pidns.tgids {
+ if other == t.tg {
+ continue
+ }
+ other.signalHandlers.mu.Lock()
+ other.leader.sendSignalLocked(&arch.SignalInfo{
+ Signo: int32(linux.SIGKILL),
+ }, true /* group */)
+ other.signalHandlers.mu.Unlock()
+ }
+ // TODO(b/37722272): The init process waits for all processes in the
+ // namespace to exit before completing its own exit
+ // (kernel/pid_namespace.c:zap_pid_ns_processes()). Stop until all
+ // other tasks in the namespace are dead, except possibly for this
+ // thread group's leader (which can't be reaped until this task exits).
+ }
+ // This is correct even if newParent is nil (it ensures that children don't
+ // wait for a parent to reap them.)
+ for c := range t.children {
+ if sig := c.ParentDeathSignal(); sig != 0 {
+ siginfo := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ siginfo.SetPid(int32(c.tg.pidns.tids[t]))
+ siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
+ c.tg.signalHandlers.mu.Lock()
+ c.sendSignalLocked(siginfo, true /* group */)
+ c.tg.signalHandlers.mu.Unlock()
+ }
+ c.reparentLocked(newParent)
+ if newParent != nil {
+ newParent.children[c] = struct{}{}
+ }
+ }
+}
+
+// findReparentTargetLocked returns the task to which t's children should be
+// reparented. If no such task exists, findNewParentLocked returns nil.
+//
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) findReparentTargetLocked() *Task {
+ // Reparent to any sibling in the same thread group that hasn't begun
+ // exiting.
+ if t2 := t.tg.anyNonExitingTaskLocked(); t2 != nil {
+ return t2
+ }
+ // "A child process that is orphaned within the namespace will be
+ // reparented to [the init process for the namespace] ..." -
+ // pid_namespaces(7)
+ if init := t.tg.pidns.tasks[InitTID]; init != nil {
+ return init.tg.anyNonExitingTaskLocked()
+ }
+ return nil
+}
+
+func (tg *ThreadGroup) anyNonExitingTaskLocked() *Task {
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ if t.exitState == TaskExitNone {
+ return t
+ }
+ }
+ return nil
+}
+
+// reparentLocked changes t's parent. The new parent may be nil.
+//
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) reparentLocked(parent *Task) {
+ oldParent := t.parent
+ t.parent = parent
+ // If a thread group leader's parent changes, reset the thread group's
+ // termination signal to SIGCHLD and re-check exit notification. (Compare
+ // kernel/exit.c:reparent_leader().)
+ if t != t.tg.leader {
+ return
+ }
+ if oldParent == nil && parent == nil {
+ return
+ }
+ if oldParent != nil && parent != nil && oldParent.tg == parent.tg {
+ return
+ }
+ t.tg.terminationSignal = linux.SIGCHLD
+ if t.exitParentNotified && !t.exitParentAcked {
+ t.exitParentNotified = false
+ t.exitNotifyLocked(false)
+ }
+}
+
+// When a task exits, other tasks in the system, notably the task's parent and
+// ptracer, may want to be notified. The exit notification system ensures that
+// interested tasks receive signals and/or are woken from blocking calls to
+// wait*() syscalls; these notifications must be resolved before exiting tasks
+// can be reaped and disappear from the system.
+//
+// Each task may have a parent task and/or a tracer task. If both a parent and
+// a tracer exist, they may be the same task, different tasks in the same
+// thread group, or tasks in different thread groups. (In the last case, Linux
+// refers to the task as being ptrace-reparented due to an implementation
+// detail; we avoid this terminology to avoid confusion.)
+//
+// A thread group is *empty* if all non-leader tasks in the thread group are
+// dead, and the leader is either a zombie or dead. The exit of a thread group
+// leader is never waitable - by either the parent or tracer - until the thread
+// group is empty.
+//
+// There are a few ways for an exit notification to be resolved:
+//
+// - The exit notification may be acknowledged by a call to Task.Wait with
+// WaitOptions.ConsumeEvent set (e.g. due to a wait4() syscall).
+//
+// - If the notified party is the parent, and the parent thread group is not
+// also the tracer thread group, and the notification signal is SIGCHLD, the
+// parent may explicitly ignore the notification (see quote in exitNotify).
+// Note that it's possible for the notified party to ignore the signal in other
+// cases, but the notification is only resolved under the above conditions.
+// (Actually, there is one exception; see the last paragraph of the "leader,
+// has tracer, tracer thread group is parent thread group" case below.)
+//
+// - If the notified party is the parent, and the parent does not exist, the
+// notification is resolved as if ignored. (This is only possible in the
+// sentry. In Linux, the only task / thread group without a parent is global
+// init, and killing global init causes a kernel panic.)
+//
+// - If the notified party is a tracer, the tracer may detach the traced task.
+// (Zombie tasks cannot be ptrace-attached, so the reverse is not possible.)
+//
+// In addition, if the notified party is the parent, the parent may exit and
+// cause the notifying task to be reparented to another thread group. This does
+// not resolve the notification; instead, the notification must be resent to
+// the new parent.
+//
+// The series of notifications generated for a given task's exit depend on
+// whether it is a thread group leader; whether the task is ptraced; and, if
+// so, whether the tracer thread group is the same as the parent thread group.
+//
+// - Non-leader, no tracer: No notification is generated; the task is reaped
+// immediately.
+//
+// - Non-leader, has tracer: SIGCHLD is sent to the tracer. When the tracer
+// notification is resolved (by waiting or detaching), the task is reaped. (For
+// non-leaders, whether the tracer and parent thread groups are the same is
+// irrelevant.)
+//
+// - Leader, no tracer: The task remains a zombie, with no notification sent,
+// until all other tasks in the thread group are dead. (In Linux terms, this
+// condition is indicated by include/linux/sched.h:thread_group_empty(); tasks
+// are removed from their thread_group list in kernel/exit.c:release_task() =>
+// __exit_signal() => __unhash_process().) Then the thread group's termination
+// signal is sent to the parent. When the parent notification is resolved (by
+// waiting or ignoring), the task is reaped.
+//
+// - Leader, has tracer, tracer thread group is not parent thread group:
+// SIGCHLD is sent to the tracer. When the tracer notification is resolved (by
+// waiting or detaching), and all other tasks in the thread group are dead, the
+// thread group's termination signal is sent to the parent. (Note that the
+// tracer cannot resolve the exit notification by waiting until the thread
+// group is empty.) When the parent notification is resolved, the task is
+// reaped.
+//
+// - Leader, has tracer, tracer thread group is parent thread group:
+//
+// If all other tasks in the thread group are dead, the thread group's
+// termination signal is sent to the parent. At this point, the notification
+// can only be resolved by waiting. If the parent detaches from the task as a
+// tracer, the notification is not resolved, but the notification can now be
+// resolved by waiting or ignoring. When the parent notification is resolved,
+// the task is reaped.
+//
+// If at least one task in the thread group is not dead, SIGCHLD is sent to the
+// parent. At this point, the notification cannot be resolved at all; once the
+// thread group becomes empty, it can be resolved only by waiting. If the
+// parent detaches from the task as a tracer before all remaining tasks die,
+// then exit notification proceeds as in the case where the leader never had a
+// tracer. If the parent detaches from the task as a tracer after all remaining
+// tasks die, the notification is not resolved, but the notification can now be
+// resolved by waiting or ignoring. When the parent notification is resolved,
+// the task is reaped.
+//
+// In both of the above cases, when the parent detaches from the task as a
+// tracer while the thread group is empty, whether or not the parent resolves
+// the notification by ignoring it is based on the parent's SIGCHLD signal
+// action, whether or not the thread group's termination signal is SIGCHLD
+// (Linux: kernel/ptrace.c:__ptrace_detach() => ignoring_children()).
+//
+// There is one final wrinkle: A leader can become a non-leader due to a
+// sibling execve. In this case, the execing thread detaches the leader's
+// tracer (if one exists) and reaps the leader immediately. In Linux, this is
+// in fs/exec.c:de_thread(); in the sentry, this is in Task.promoteLocked().
+
+// +stateify savable
+type runExitNotify struct{}
+
+func (*runExitNotify) execute(t *Task) taskRunState {
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+ t.advanceExitStateLocked(TaskExitInitiated, TaskExitZombie)
+ t.tg.liveTasks--
+ // Check if this completes a sibling's execve.
+ if t.tg.execing != nil && t.tg.liveTasks == 1 {
+ // execing blocks the addition of new tasks to the thread group, so
+ // the sole living task must be the execing one.
+ e := t.tg.execing
+ e.tg.signalHandlers.mu.Lock()
+ if _, ok := e.stop.(*execStop); ok {
+ e.endInternalStopLocked()
+ }
+ e.tg.signalHandlers.mu.Unlock()
+ }
+ t.exitNotifyLocked(false)
+ // The task goroutine will now exit.
+ return nil
+}
+
+// exitNotifyLocked is called after changes to t's state that affect exit
+// notification.
+//
+// If fromPtraceDetach is true, the caller is ptraceDetach or exitPtrace;
+// thanks to Linux's haphazard implementation of this functionality, such cases
+// determine whether parent notifications are ignored based on the parent's
+// handling of SIGCHLD, regardless of what the exited task's thread group's
+// termination signal is.
+//
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) exitNotifyLocked(fromPtraceDetach bool) {
+ if t.exitState != TaskExitZombie {
+ return
+ }
+ if !t.exitTracerNotified {
+ t.exitTracerNotified = true
+ tracer := t.Tracer()
+ if tracer == nil {
+ t.exitTracerAcked = true
+ } else if t != t.tg.leader || t.parent == nil || tracer.tg != t.parent.tg {
+ // Don't set exitParentNotified if t is non-leader, even if the
+ // tracer is in the parent thread group, so that if the parent
+ // detaches the following call to exitNotifyLocked passes through
+ // the !exitParentNotified case below and causes t to be reaped
+ // immediately.
+ //
+ // Tracer notification doesn't care about about
+ // SIG_IGN/SA_NOCLDWAIT.
+ tracer.tg.signalHandlers.mu.Lock()
+ tracer.sendSignalLocked(t.exitNotificationSignal(linux.SIGCHLD, tracer), true /* group */)
+ tracer.tg.signalHandlers.mu.Unlock()
+ // Wake EventTraceeStop waiters as well since this task will never
+ // ptrace-stop again.
+ tracer.tg.eventQueue.Notify(EventExit | EventTraceeStop)
+ } else {
+ // t is a leader and the tracer is in the parent thread group.
+ t.exitParentNotified = true
+ sig := linux.SIGCHLD
+ if t.tg.tasksCount == 1 {
+ sig = t.tg.terminationSignal
+ }
+ // This notification doesn't care about SIG_IGN/SA_NOCLDWAIT either
+ // (in Linux, the check in do_notify_parent() is gated by
+ // !tsk->ptrace.)
+ t.parent.tg.signalHandlers.mu.Lock()
+ t.parent.sendSignalLocked(t.exitNotificationSignal(sig, t.parent), true /* group */)
+ t.parent.tg.signalHandlers.mu.Unlock()
+ // See below for rationale for this event mask.
+ t.parent.tg.eventQueue.Notify(EventExit | EventChildGroupStop | EventGroupContinue)
+ }
+ }
+ if t.exitTracerAcked && !t.exitParentNotified {
+ if t != t.tg.leader {
+ t.exitParentNotified = true
+ t.exitParentAcked = true
+ } else if t.tg.tasksCount == 1 {
+ t.exitParentNotified = true
+ if t.parent == nil {
+ t.exitParentAcked = true
+ } else {
+ // "POSIX.1-2001 specifies that if the disposition of SIGCHLD is
+ // set to SIG_IGN or the SA_NOCLDWAIT flag is set for SIGCHLD (see
+ // sigaction(2)), then children that terminate do not become
+ // zombies and a call to wait() or waitpid() will block until all
+ // children have terminated, and then fail with errno set to
+ // ECHILD. (The original POSIX standard left the behavior of
+ // setting SIGCHLD to SIG_IGN unspecified. Note that even though
+ // the default disposition of SIGCHLD is "ignore", explicitly
+ // setting the disposition to SIG_IGN results in different
+ // treatment of zombie process children.) Linux 2.6 conforms to
+ // this specification." - wait(2)
+ //
+ // Some undocumented Linux-specific details:
+ //
+ // - All of the above is ignored if the termination signal isn't
+ // SIGCHLD.
+ //
+ // - SA_NOCLDWAIT causes the leader to be immediately reaped, but
+ // does not suppress the SIGCHLD.
+ signalParent := t.tg.terminationSignal.IsValid()
+ t.parent.tg.signalHandlers.mu.Lock()
+ if t.tg.terminationSignal == linux.SIGCHLD || fromPtraceDetach {
+ if act, ok := t.parent.tg.signalHandlers.actions[linux.SIGCHLD]; ok {
+ if act.Handler == arch.SignalActIgnore {
+ t.exitParentAcked = true
+ signalParent = false
+ } else if act.Flags&arch.SignalFlagNoCldWait != 0 {
+ t.exitParentAcked = true
+ }
+ }
+ }
+ if signalParent {
+ t.parent.tg.leader.sendSignalLocked(t.exitNotificationSignal(t.tg.terminationSignal, t.parent), true /* group */)
+ }
+ t.parent.tg.signalHandlers.mu.Unlock()
+ // If a task in the parent was waiting for a child group stop
+ // or continue, it needs to be notified of the exit, because
+ // there may be no remaining eligible tasks (so that wait
+ // should return ECHILD).
+ t.parent.tg.eventQueue.Notify(EventExit | EventChildGroupStop | EventGroupContinue)
+ }
+ }
+ }
+ if t.exitTracerAcked && t.exitParentAcked {
+ t.advanceExitStateLocked(TaskExitZombie, TaskExitDead)
+ for ns := t.tg.pidns; ns != nil; ns = ns.parent {
+ tid := ns.tids[t]
+ delete(ns.tasks, tid)
+ delete(ns.tids, t)
+ if t == t.tg.leader {
+ delete(ns.tgids, t.tg)
+ }
+ }
+ t.tg.exitedCPUStats.Accumulate(t.CPUStats())
+ t.tg.ioUsage.Accumulate(t.ioUsage)
+ t.tg.signalHandlers.mu.Lock()
+ t.tg.tasks.Remove(t)
+ t.tg.tasksCount--
+ tc := t.tg.tasksCount
+ t.tg.signalHandlers.mu.Unlock()
+ if tc == 1 && t != t.tg.leader {
+ // Our fromPtraceDetach doesn't matter here (in Linux terms, this
+ // is via a call to release_task()).
+ t.tg.leader.exitNotifyLocked(false)
+ } else if tc == 0 {
+ t.tg.processGroup.decRefWithParent(t.tg.parentPG())
+ }
+ if t.parent != nil {
+ delete(t.parent.children, t)
+ t.parent = nil
+ }
+ }
+}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.SignalInfo {
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ }
+ info.SetPid(int32(receiver.tg.pidns.tids[t]))
+ info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ if t.exitStatus.Signaled() {
+ info.Code = arch.CLD_KILLED
+ info.SetStatus(int32(t.exitStatus.Signo))
+ } else {
+ info.Code = arch.CLD_EXITED
+ info.SetStatus(int32(t.exitStatus.Code))
+ }
+ // TODO(b/72102453): Set utime, stime.
+ return info
+}
+
+// ExitStatus returns t's exit status, which is only guaranteed to be
+// meaningful if t.ExitState() != TaskExitNone.
+func (t *Task) ExitStatus() ExitStatus {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.exitStatus
+}
+
+// ExitStatus returns the exit status that would be returned by a consuming
+// wait*() on tg.
+func (tg *ThreadGroup) ExitStatus() ExitStatus {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ if tg.exiting {
+ return tg.exitStatus
+ }
+ return tg.leader.exitStatus
+}
+
+// TerminationSignal returns the thread group's termination signal.
+func (tg *ThreadGroup) TerminationSignal() linux.Signal {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.terminationSignal
+}
+
+// Task events that can be waited for.
+const (
+ // EventExit represents an exit notification generated for a child thread
+ // group leader or a tracee under the conditions specified in the comment
+ // above runExitNotify.
+ EventExit waiter.EventMask = 1 << iota
+
+ // EventChildGroupStop occurs when a child thread group completes a group
+ // stop (i.e. all tasks in the child thread group have entered a stopped
+ // state as a result of a group stop).
+ EventChildGroupStop
+
+ // EventTraceeStop occurs when a task that is ptraced by a task in the
+ // notified thread group enters a ptrace stop (see ptrace(2)).
+ EventTraceeStop
+
+ // EventGroupContinue occurs when a child thread group, or a thread group
+ // whose leader is ptraced by a task in the notified thread group, that had
+ // initiated or completed a group stop leaves the group stop, due to the
+ // child thread group or any task in the child thread group being sent
+ // SIGCONT.
+ EventGroupContinue
+)
+
+// WaitOptions controls the behavior of Task.Wait.
+type WaitOptions struct {
+ // If SpecificTID is non-zero, only events from the task with thread ID
+ // SpecificTID are eligible to be waited for. SpecificTID is resolved in
+ // the PID namespace of the waiter (the method receiver of Task.Wait). If
+ // no such task exists, or that task would not otherwise be eligible to be
+ // waited for by the waiting task, then there are no waitable tasks and
+ // Wait will return ECHILD.
+ SpecificTID ThreadID
+
+ // If SpecificPGID is non-zero, only events from ThreadGroups with a
+ // matching ProcessGroupID are eligible to be waited for. (Same
+ // constraints as SpecificTID apply.)
+ SpecificPGID ProcessGroupID
+
+ // Terminology note: Per waitpid(2), "a clone child is one which delivers
+ // no signal, or a signal other than SIGCHLD to its parent upon
+ // termination." In Linux, termination signal is technically a per-task
+ // property rather than a per-thread-group property. However, clone()
+ // forces no termination signal for tasks created with CLONE_THREAD, and
+ // execve() resets the termination signal to SIGCHLD, so all
+ // non-group-leader threads have no termination signal and are therefore
+ // "clone tasks".
+
+ // If NonCloneTasks is true, events from non-clone tasks are eligible to be
+ // waited for.
+ NonCloneTasks bool
+
+ // If CloneTasks is true, events from clone tasks are eligible to be waited
+ // for.
+ CloneTasks bool
+
+ // If SiblingChildren is true, events from children tasks of any task
+ // in the thread group of the waiter are eligible to be waited for.
+ SiblingChildren bool
+
+ // Events is a bitwise combination of the events defined above that specify
+ // what events are of interest to the call to Wait.
+ Events waiter.EventMask
+
+ // If ConsumeEvent is true, the Wait should consume the event such that it
+ // cannot be returned by a future Wait. Note that if a task exit is
+ // consumed in this way, in most cases the task will be reaped.
+ ConsumeEvent bool
+
+ // If BlockInterruptErr is not nil, Wait will block until either an event
+ // is available or there are no tasks that could produce a waitable event;
+ // if that blocking is interrupted, Wait returns BlockInterruptErr. If
+ // BlockInterruptErr is nil, Wait will not block.
+ BlockInterruptErr error
+}
+
+// Preconditions: The TaskSet mutex must be locked (for reading or writing).
+func (o *WaitOptions) matchesTask(t *Task, pidns *PIDNamespace, tracee bool) bool {
+ if o.SpecificTID != 0 && o.SpecificTID != pidns.tids[t] {
+ return false
+ }
+ if o.SpecificPGID != 0 && o.SpecificPGID != pidns.pgids[t.tg.processGroup] {
+ return false
+ }
+ // Tracees are always eligible.
+ if tracee {
+ return true
+ }
+ if t == t.tg.leader && t.tg.terminationSignal == linux.SIGCHLD {
+ return o.NonCloneTasks
+ }
+ return o.CloneTasks
+}
+
+// ErrNoWaitableEvent is returned by non-blocking Task.Waits (e.g.
+// waitpid(WNOHANG)) that find no waitable events, but determine that waitable
+// events may exist in the future. (In contrast, if a non-blocking or blocking
+// Wait determines that there are no tasks that can produce a waitable event,
+// Task.Wait returns ECHILD.)
+var ErrNoWaitableEvent = errors.New("non-blocking Wait found eligible threads but no waitable events")
+
+// WaitResult contains information about a waited-for event.
+type WaitResult struct {
+ // Task is the task that reported the event.
+ Task *Task
+
+ // TID is the thread ID of Task in the PID namespace of the task that
+ // called Wait (that is, the method receiver of the call to Task.Wait). TID
+ // is provided because consuming exit waits cause the thread ID to be
+ // deallocated.
+ TID ThreadID
+
+ // UID is the real UID of Task in the user namespace of the task that
+ // called Wait.
+ UID auth.UID
+
+ // Event is exactly one of the events defined above.
+ Event waiter.EventMask
+
+ // Status is the numeric status associated with the event.
+ Status uint32
+}
+
+// Wait waits for an event from a thread group that is a child of t's thread
+// group, or a task in such a thread group, or a task that is ptraced by t,
+// subject to the options specified in opts.
+func (t *Task) Wait(opts *WaitOptions) (*WaitResult, error) {
+ if opts.BlockInterruptErr == nil {
+ return t.waitOnce(opts)
+ }
+ w, ch := waiter.NewChannelEntry(nil)
+ t.tg.eventQueue.EventRegister(&w, opts.Events)
+ defer t.tg.eventQueue.EventUnregister(&w)
+ for {
+ wr, err := t.waitOnce(opts)
+ if err != ErrNoWaitableEvent {
+ // This includes err == nil.
+ return wr, err
+ }
+ if err := t.Block(ch); err != nil {
+ return wr, syserror.ConvertIntr(err, opts.BlockInterruptErr)
+ }
+ }
+}
+
+func (t *Task) waitOnce(opts *WaitOptions) (*WaitResult, error) {
+ anyWaitableTasks := false
+
+ t.tg.pidns.owner.mu.Lock()
+ defer t.tg.pidns.owner.mu.Unlock()
+
+ if opts.SiblingChildren {
+ // We can wait on the children and tracees of any task in the
+ // same thread group.
+ for parent := t.tg.tasks.Front(); parent != nil; parent = parent.Next() {
+ wr, any := t.waitParentLocked(opts, parent)
+ if wr != nil {
+ return wr, nil
+ }
+ anyWaitableTasks = anyWaitableTasks || any
+ }
+ } else {
+ // We can only wait on this task.
+ var wr *WaitResult
+ wr, anyWaitableTasks = t.waitParentLocked(opts, t)
+ if wr != nil {
+ return wr, nil
+ }
+ }
+
+ if anyWaitableTasks {
+ return nil, ErrNoWaitableEvent
+ }
+ return nil, syserror.ECHILD
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitParentLocked(opts *WaitOptions, parent *Task) (*WaitResult, bool) {
+ anyWaitableTasks := false
+
+ for child := range parent.children {
+ if !opts.matchesTask(child, parent.tg.pidns, false) {
+ continue
+ }
+ // Non-leaders don't notify parents on exit and aren't eligible to
+ // be waited on.
+ if opts.Events&EventExit != 0 && child == child.tg.leader && !child.exitParentAcked {
+ anyWaitableTasks = true
+ if wr := t.waitCollectZombieLocked(child, opts, false); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ // Check for group stops and continues. Tasks that have passed
+ // TaskExitInitiated can no longer participate in group stops.
+ if opts.Events&(EventChildGroupStop|EventGroupContinue) == 0 {
+ continue
+ }
+ if child.exitState >= TaskExitInitiated {
+ continue
+ }
+ // If the waiter is in the same thread group as the task's
+ // tracer, do not report its group stops; they will be reported
+ // as ptrace stops instead. This also skips checking for group
+ // continues, but they'll be checked for when scanning tracees
+ // below. (Per kernel/exit.c:wait_consider_task(): "If a
+ // ptracer wants to distinguish the two events for its own
+ // children, it should create a separate process which takes
+ // the role of real parent.")
+ if tracer := child.Tracer(); tracer != nil && tracer.tg == parent.tg {
+ continue
+ }
+ anyWaitableTasks = true
+ if opts.Events&EventChildGroupStop != 0 {
+ if wr := t.waitCollectChildGroupStopLocked(child, opts); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ if opts.Events&EventGroupContinue != 0 {
+ if wr := t.waitCollectGroupContinueLocked(child, opts); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ }
+ for tracee := range parent.ptraceTracees {
+ if !opts.matchesTask(tracee, parent.tg.pidns, true) {
+ continue
+ }
+ // Non-leaders do notify tracers on exit.
+ if opts.Events&EventExit != 0 && !tracee.exitTracerAcked {
+ anyWaitableTasks = true
+ if wr := t.waitCollectZombieLocked(tracee, opts, true); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ if opts.Events&(EventTraceeStop|EventGroupContinue) == 0 {
+ continue
+ }
+ if tracee.exitState >= TaskExitInitiated {
+ continue
+ }
+ anyWaitableTasks = true
+ if opts.Events&EventTraceeStop != 0 {
+ if wr := t.waitCollectTraceeStopLocked(tracee, opts); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ if opts.Events&EventGroupContinue != 0 {
+ if wr := t.waitCollectGroupContinueLocked(tracee, opts); wr != nil {
+ return wr, anyWaitableTasks
+ }
+ }
+ }
+
+ return nil, anyWaitableTasks
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitCollectZombieLocked(target *Task, opts *WaitOptions, asPtracer bool) *WaitResult {
+ if asPtracer && !target.exitTracerNotified {
+ return nil
+ }
+ if !asPtracer && !target.exitParentNotified {
+ return nil
+ }
+ // Zombied thread group leaders are never waitable until their thread group
+ // is otherwise empty. Usually this is caught by the
+ // target.exitParentNotified check above, but if t is both (in the thread
+ // group of) target's tracer and parent, asPtracer may be true.
+ if target == target.tg.leader && target.tg.tasksCount != 1 {
+ return nil
+ }
+ pid := t.tg.pidns.tids[target]
+ uid := target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()
+ status := target.exitStatus.Status()
+ if !opts.ConsumeEvent {
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventExit,
+ Status: status,
+ }
+ }
+ // Surprisingly, the exit status reported by a non-consuming wait can
+ // differ from that reported by a consuming wait; the latter will return
+ // the group exit code if one is available.
+ if target.tg.exiting {
+ status = target.tg.exitStatus.Status()
+ }
+ // t may be (in the thread group of) target's parent, tracer, or both. We
+ // don't need to check for !exitTracerAcked because tracees are detached
+ // here, and we don't need to check for !exitParentAcked because zombies
+ // will be reaped here.
+ if tracer := target.Tracer(); tracer != nil && tracer.tg == t.tg && target.exitTracerNotified {
+ target.exitTracerAcked = true
+ target.ptraceTracer.Store((*Task)(nil))
+ delete(t.ptraceTracees, target)
+ }
+ if target.parent != nil && target.parent.tg == t.tg && target.exitParentNotified {
+ target.exitParentAcked = true
+ if target == target.tg.leader {
+ // target.tg.exitedCPUStats doesn't include target.CPUStats() yet,
+ // and won't until after target.exitNotifyLocked() (maybe). Include
+ // target.CPUStats() explicitly. This is consistent with Linux,
+ // which accounts an exited task's cputime to its thread group in
+ // kernel/exit.c:release_task() => __exit_signal(), and uses
+ // thread_group_cputime_adjusted() in wait_task_zombie().
+ t.tg.childCPUStats.Accumulate(target.CPUStats())
+ t.tg.childCPUStats.Accumulate(target.tg.exitedCPUStats)
+ t.tg.childCPUStats.Accumulate(target.tg.childCPUStats)
+ // Update t's child max resident set size. The size will be the maximum
+ // of this thread's size and all its childrens' sizes.
+ if t.tg.childMaxRSS < target.tg.maxRSS {
+ t.tg.childMaxRSS = target.tg.maxRSS
+ }
+ if t.tg.childMaxRSS < target.tg.childMaxRSS {
+ t.tg.childMaxRSS = target.tg.childMaxRSS
+ }
+ }
+ }
+ target.exitNotifyLocked(false)
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventExit,
+ Status: status,
+ }
+}
+
+// updateRSSLocked updates t.tg.maxRSS.
+//
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) updateRSSLocked() {
+ if mmMaxRSS := t.MemoryManager().MaxResidentSetSize(); t.tg.maxRSS < mmMaxRSS {
+ t.tg.maxRSS = mmMaxRSS
+ }
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitCollectChildGroupStopLocked(target *Task, opts *WaitOptions) *WaitResult {
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if !target.tg.groupStopWaitable {
+ return nil
+ }
+ pid := t.tg.pidns.tids[target]
+ uid := target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()
+ sig := target.tg.groupStopSignal
+ if opts.ConsumeEvent {
+ target.tg.groupStopWaitable = false
+ }
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventChildGroupStop,
+ // There is no name for these status constants.
+ Status: (uint32(sig)&0xff)<<8 | 0x7f,
+ }
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitCollectGroupContinueLocked(target *Task, opts *WaitOptions) *WaitResult {
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if !target.tg.groupContWaitable {
+ return nil
+ }
+ pid := t.tg.pidns.tids[target]
+ uid := target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()
+ if opts.ConsumeEvent {
+ target.tg.groupContWaitable = false
+ }
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventGroupContinue,
+ Status: 0xffff,
+ }
+}
+
+// Preconditions: The TaskSet mutex must be locked for writing.
+func (t *Task) waitCollectTraceeStopLocked(target *Task, opts *WaitOptions) *WaitResult {
+ target.tg.signalHandlers.mu.Lock()
+ defer target.tg.signalHandlers.mu.Unlock()
+ if target.stop == nil {
+ return nil
+ }
+ if _, ok := target.stop.(*ptraceStop); !ok {
+ return nil
+ }
+ if target.ptraceCode == 0 {
+ return nil
+ }
+ pid := t.tg.pidns.tids[target]
+ uid := target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()
+ code := target.ptraceCode
+ if opts.ConsumeEvent {
+ target.ptraceCode = 0
+ }
+ return &WaitResult{
+ Task: target,
+ TID: pid,
+ UID: uid,
+ Event: EventTraceeStop,
+ Status: uint32(code)<<8 | 0x7f,
+ }
+}
+
+// ExitState returns t's current progress through the exit path.
+func (t *Task) ExitState() TaskExitState {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ return t.exitState
+}
+
+// ParentDeathSignal returns t's parent death signal.
+func (t *Task) ParentDeathSignal() linux.Signal {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.parentDeathSignal
+}
+
+// SetParentDeathSignal sets t's parent death signal.
+func (t *Task) SetParentDeathSignal(sig linux.Signal) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.parentDeathSignal = sig
+}
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
new file mode 100644
index 000000000..f98097c2c
--- /dev/null
+++ b/pkg/sentry/kernel/task_futex.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// Futex returns t's futex manager.
+//
+// Preconditions: The caller must be running on the task goroutine, or t.mu
+// must be locked.
+func (t *Task) Futex() *futex.Manager {
+ return t.tc.fu
+}
+
+// SwapUint32 implements futex.Target.SwapUint32.
+func (t *Task) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
+ return t.MemoryManager().SwapUint32(t, addr, new, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CompareAndSwapUint32 implemets futex.Target.CompareAndSwapUint32.
+func (t *Task) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
+ return t.MemoryManager().CompareAndSwapUint32(t, addr, old, new, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// LoadUint32 implemets futex.Target.LoadUint32.
+func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) {
+ return t.MemoryManager().LoadUint32(t, addr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// GetSharedKey implements futex.Target.GetSharedKey.
+func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) {
+ return t.MemoryManager().GetSharedFutexKey(t, addr)
+}
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
new file mode 100644
index 000000000..17f08729a
--- /dev/null
+++ b/pkg/sentry/kernel/task_identity.go
@@ -0,0 +1,568 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Credentials returns t's credentials.
+//
+// This value must be considered immutable.
+func (t *Task) Credentials() *auth.Credentials {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.creds
+}
+
+// UserNamespace returns the user namespace associated with the task.
+func (t *Task) UserNamespace() *auth.UserNamespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.creds.UserNamespace
+}
+
+// HasCapabilityIn checks if the task has capability cp in user namespace ns.
+func (t *Task) HasCapabilityIn(cp linux.Capability, ns *auth.UserNamespace) bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.creds.HasCapabilityIn(cp, ns)
+}
+
+// HasCapability checks if the task has capability cp in its user namespace.
+func (t *Task) HasCapability(cp linux.Capability) bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.creds.HasCapability(cp)
+}
+
+// SetUID implements the semantics of setuid(2).
+func (t *Task) SetUID(uid auth.UID) error {
+ // setuid considers -1 to be invalid.
+ if !uid.Ok() {
+ return syserror.EINVAL
+ }
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ kuid := t.creds.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ // "setuid() sets the effective user ID of the calling process. If the
+ // effective UID of the caller is root (more precisely: if the caller has
+ // the CAP_SETUID capability), the real UID and saved set-user-ID are also
+ // set." - setuid(2)
+ if t.creds.HasCapability(linux.CAP_SETUID) {
+ t.setKUIDsUncheckedLocked(kuid, kuid, kuid)
+ return nil
+ }
+ // "EPERM: The user is not privileged (Linux: does not have the CAP_SETUID
+ // capability) and uid does not match the real UID or saved set-user-ID of
+ // the calling process."
+ if kuid != t.creds.RealKUID && kuid != t.creds.SavedKUID {
+ return syserror.EPERM
+ }
+ t.setKUIDsUncheckedLocked(t.creds.RealKUID, kuid, t.creds.SavedKUID)
+ return nil
+}
+
+// SetREUID implements the semantics of setreuid(2).
+func (t *Task) SetREUID(r, e auth.UID) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // "Supplying a value of -1 for either the real or effective user ID forces
+ // the system to leave that ID unchanged." - setreuid(2)
+ newR := t.creds.RealKUID
+ if r.Ok() {
+ newR = t.creds.UserNamespace.MapToKUID(r)
+ if !newR.Ok() {
+ return syserror.EINVAL
+ }
+ }
+ newE := t.creds.EffectiveKUID
+ if e.Ok() {
+ newE = t.creds.UserNamespace.MapToKUID(e)
+ if !newE.Ok() {
+ return syserror.EINVAL
+ }
+ }
+ if !t.creds.HasCapability(linux.CAP_SETUID) {
+ // "Unprivileged processes may only set the effective user ID to the
+ // real user ID, the effective user ID, or the saved set-user-ID."
+ if newE != t.creds.RealKUID && newE != t.creds.EffectiveKUID && newE != t.creds.SavedKUID {
+ return syserror.EPERM
+ }
+ // "Unprivileged users may only set the real user ID to the real user
+ // ID or the effective user ID."
+ if newR != t.creds.RealKUID && newR != t.creds.EffectiveKUID {
+ return syserror.EPERM
+ }
+ }
+ // "If the real user ID is set (i.e., ruid is not -1) or the effective user
+ // ID is set to a value not equal to the previous real user ID, the saved
+ // set-user-ID will be set to the new effective user ID."
+ newS := t.creds.SavedKUID
+ if r.Ok() || (e.Ok() && newE != t.creds.EffectiveKUID) {
+ newS = newE
+ }
+ t.setKUIDsUncheckedLocked(newR, newE, newS)
+ return nil
+}
+
+// SetRESUID implements the semantics of the setresuid(2) syscall.
+func (t *Task) SetRESUID(r, e, s auth.UID) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // "Unprivileged user processes may change the real UID, effective UID, and
+ // saved set-user-ID, each to one of: the current real UID, the current
+ // effective UID or the current saved set-user-ID. Privileged processes (on
+ // Linux, those having the CAP_SETUID capability) may set the real UID,
+ // effective UID, and saved set-user-ID to arbitrary values. If one of the
+ // arguments equals -1, the corresponding value is not changed." -
+ // setresuid(2)
+ var err error
+ newR := t.creds.RealKUID
+ if r.Ok() {
+ newR, err = t.creds.UseUID(r)
+ if err != nil {
+ return err
+ }
+ }
+ newE := t.creds.EffectiveKUID
+ if e.Ok() {
+ newE, err = t.creds.UseUID(e)
+ if err != nil {
+ return err
+ }
+ }
+ newS := t.creds.SavedKUID
+ if s.Ok() {
+ newS, err = t.creds.UseUID(s)
+ if err != nil {
+ return err
+ }
+ }
+ t.setKUIDsUncheckedLocked(newR, newE, newS)
+ return nil
+}
+
+// Preconditions: t.mu must be locked.
+func (t *Task) setKUIDsUncheckedLocked(newR, newE, newS auth.KUID) {
+ root := t.creds.UserNamespace.MapToKUID(auth.RootUID)
+ oldR, oldE, oldS := t.creds.RealKUID, t.creds.EffectiveKUID, t.creds.SavedKUID
+ t.creds = t.creds.Fork() // See doc for creds.
+ t.creds.RealKUID, t.creds.EffectiveKUID, t.creds.SavedKUID = newR, newE, newS
+
+ // "1. If one or more of the real, effective or saved set user IDs was
+ // previously 0, and as a result of the UID changes all of these IDs have a
+ // nonzero value, then all capabilities are cleared from the permitted and
+ // effective capability sets." - capabilities(7)
+ if (oldR == root || oldE == root || oldS == root) && (newR != root && newE != root && newS != root) {
+ // prctl(2): "PR_SET_KEEPCAP: Set the state of the calling thread's
+ // "keep capabilities" flag, which determines whether the thread's permitted
+ // capability set is cleared when a change is made to the
+ // thread's user IDs such that the thread's real UID, effective
+ // UID, and saved set-user-ID all become nonzero when at least
+ // one of them previously had the value 0. By default, the
+ // permitted capability set is cleared when such a change is
+ // made; setting the "keep capabilities" flag prevents it from
+ // being cleared." (A thread's effective capability set is always
+ // cleared when such a credential change is made,
+ // regardless of the setting of the "keep capabilities" flag.)
+ if !t.creds.KeepCaps {
+ t.creds.PermittedCaps = 0
+ t.creds.EffectiveCaps = 0
+ }
+ }
+ // """
+ // 2. If the effective user ID is changed from 0 to nonzero, then all
+ // capabilities are cleared from the effective set.
+ //
+ // 3. If the effective user ID is changed from nonzero to 0, then the
+ // permitted set is copied to the effective set.
+ // """
+ if oldE == root && newE != root {
+ t.creds.EffectiveCaps = 0
+ } else if oldE != root && newE == root {
+ t.creds.EffectiveCaps = t.creds.PermittedCaps
+ }
+ // "4. If the filesystem user ID is changed from 0 to nonzero (see
+ // setfsuid(2)), then the following capabilities are cleared from the
+ // effective set: ..."
+ // (filesystem UIDs aren't implemented, nor are any of the capabilities in
+ // question)
+
+ // Not documented, but compare Linux's kernel/cred.c:commit_creds().
+ if oldE != newE {
+ t.parentDeathSignal = 0
+ }
+}
+
+// SetGID implements the semantics of setgid(2).
+func (t *Task) SetGID(gid auth.GID) error {
+ if !gid.Ok() {
+ return syserror.EINVAL
+ }
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ kgid := t.creds.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ if t.creds.HasCapability(linux.CAP_SETGID) {
+ t.setKGIDsUncheckedLocked(kgid, kgid, kgid)
+ return nil
+ }
+ if kgid != t.creds.RealKGID && kgid != t.creds.SavedKGID {
+ return syserror.EPERM
+ }
+ t.setKGIDsUncheckedLocked(t.creds.RealKGID, kgid, t.creds.SavedKGID)
+ return nil
+}
+
+// SetREGID implements the semantics of setregid(2).
+func (t *Task) SetREGID(r, e auth.GID) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ newR := t.creds.RealKGID
+ if r.Ok() {
+ newR = t.creds.UserNamespace.MapToKGID(r)
+ if !newR.Ok() {
+ return syserror.EINVAL
+ }
+ }
+ newE := t.creds.EffectiveKGID
+ if e.Ok() {
+ newE = t.creds.UserNamespace.MapToKGID(e)
+ if !newE.Ok() {
+ return syserror.EINVAL
+ }
+ }
+ if !t.creds.HasCapability(linux.CAP_SETGID) {
+ if newE != t.creds.RealKGID && newE != t.creds.EffectiveKGID && newE != t.creds.SavedKGID {
+ return syserror.EPERM
+ }
+ if newR != t.creds.RealKGID && newR != t.creds.EffectiveKGID {
+ return syserror.EPERM
+ }
+ }
+ newS := t.creds.SavedKGID
+ if r.Ok() || (e.Ok() && newE != t.creds.EffectiveKGID) {
+ newS = newE
+ }
+ t.setKGIDsUncheckedLocked(newR, newE, newS)
+ return nil
+}
+
+// SetRESGID implements the semantics of the setresgid(2) syscall.
+func (t *Task) SetRESGID(r, e, s auth.GID) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ var err error
+ newR := t.creds.RealKGID
+ if r.Ok() {
+ newR, err = t.creds.UseGID(r)
+ if err != nil {
+ return err
+ }
+ }
+ newE := t.creds.EffectiveKGID
+ if e.Ok() {
+ newE, err = t.creds.UseGID(e)
+ if err != nil {
+ return err
+ }
+ }
+ newS := t.creds.SavedKGID
+ if s.Ok() {
+ newS, err = t.creds.UseGID(s)
+ if err != nil {
+ return err
+ }
+ }
+ t.setKGIDsUncheckedLocked(newR, newE, newS)
+ return nil
+}
+
+func (t *Task) setKGIDsUncheckedLocked(newR, newE, newS auth.KGID) {
+ oldE := t.creds.EffectiveKGID
+ t.creds = t.creds.Fork() // See doc for creds.
+ t.creds.RealKGID, t.creds.EffectiveKGID, t.creds.SavedKGID = newR, newE, newS
+
+ // Not documented, but compare Linux's kernel/cred.c:commit_creds().
+ if oldE != newE {
+ t.parentDeathSignal = 0
+ }
+}
+
+// SetExtraGIDs attempts to change t's supplemental groups. All IDs are
+// interpreted as being in t's user namespace.
+func (t *Task) SetExtraGIDs(gids []auth.GID) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if !t.creds.HasCapability(linux.CAP_SETGID) {
+ return syserror.EPERM
+ }
+ kgids := make([]auth.KGID, len(gids))
+ for i, gid := range gids {
+ kgid := t.creds.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ kgids[i] = kgid
+ }
+ t.creds = t.creds.Fork() // See doc for creds.
+ t.creds.ExtraKGIDs = kgids
+ return nil
+}
+
+// SetCapabilitySets attempts to change t's permitted, inheritable, and
+// effective capability sets.
+func (t *Task) SetCapabilitySets(permitted, inheritable, effective auth.CapabilitySet) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // "Permitted: This is a limiting superset for the effective capabilities
+ // that the thread may assume." - capabilities(7)
+ if effective & ^permitted != 0 {
+ return syserror.EPERM
+ }
+ // "It is also a limiting superset for the capabilities that may be added
+ // to the inheritable set by a thread that does not have the CAP_SETPCAP
+ // capability in its effective set."
+ if !t.creds.HasCapability(linux.CAP_SETPCAP) && (inheritable & ^(t.creds.InheritableCaps|t.creds.PermittedCaps) != 0) {
+ return syserror.EPERM
+ }
+ // "If a thread drops a capability from its permitted set, it can never
+ // reacquire that capability (unless it execve(2)s ..."
+ if permitted & ^t.creds.PermittedCaps != 0 {
+ return syserror.EPERM
+ }
+ // "... if a capability is not in the bounding set, then a thread can't add
+ // this capability to its inheritable set, even if it was in its permitted
+ // capabilities ..."
+ if inheritable & ^(t.creds.InheritableCaps|t.creds.BoundingCaps) != 0 {
+ return syserror.EPERM
+ }
+ t.creds = t.creds.Fork() // See doc for creds.
+ t.creds.PermittedCaps = permitted
+ t.creds.InheritableCaps = inheritable
+ t.creds.EffectiveCaps = effective
+ return nil
+}
+
+// DropBoundingCapability attempts to drop capability cp from t's capability
+// bounding set.
+func (t *Task) DropBoundingCapability(cp linux.Capability) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if !t.creds.HasCapability(linux.CAP_SETPCAP) {
+ return syserror.EPERM
+ }
+ t.creds = t.creds.Fork() // See doc for creds.
+ t.creds.BoundingCaps &^= auth.CapabilitySetOf(cp)
+ return nil
+}
+
+// SetUserNamespace attempts to move c into ns.
+func (t *Task) SetUserNamespace(ns *auth.UserNamespace) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // "A process reassociating itself with a user namespace must have the
+ // CAP_SYS_ADMIN capability in the target user namespace." - setns(2)
+ //
+ // If t just created ns, then t.creds is guaranteed to have CAP_SYS_ADMIN
+ // in ns (by rule 3 in auth.Credentials.HasCapability).
+ if !t.creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, ns) {
+ return syserror.EPERM
+ }
+
+ t.creds = t.creds.Fork() // See doc for creds.
+ t.creds.UserNamespace = ns
+ // "The child process created by clone(2) with the CLONE_NEWUSER flag
+ // starts out with a complete set of capabilities in the new user
+ // namespace. Likewise, a process that creates a new user namespace using
+ // unshare(2) or joins an existing user namespace using setns(2) gains a
+ // full set of capabilities in that namespace."
+ t.creds.PermittedCaps = auth.AllCapabilities
+ t.creds.InheritableCaps = 0
+ t.creds.EffectiveCaps = auth.AllCapabilities
+ t.creds.BoundingCaps = auth.AllCapabilities
+ // "A call to clone(2), unshare(2), or setns(2) using the CLONE_NEWUSER
+ // flag sets the "securebits" flags (see capabilities(7)) to their default
+ // values (all flags disabled) in the child (for clone(2)) or caller (for
+ // unshare(2), or setns(2)." - user_namespaces(7)
+ t.creds.KeepCaps = false
+
+ return nil
+}
+
+// SetKeepCaps will set the keep capabilities flag PR_SET_KEEPCAPS.
+func (t *Task) SetKeepCaps(k bool) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.creds = t.creds.Fork() // See doc for creds.
+ t.creds.KeepCaps = k
+}
+
+// updateCredsForExec updates t.creds to reflect an execve().
+//
+// NOTE(b/30815691): We currently do not implement privileged executables
+// (set-user/group-ID bits and file capabilities). This allows us to make a lot
+// of simplifying assumptions:
+//
+// - We assume the no_new_privs bit (set by prctl(SET_NO_NEW_PRIVS)), which
+// disables the features we don't support anyway, is always set. This
+// drastically simplifies this function.
+//
+// - We don't implement AT_SECURE, because no_new_privs always being set means
+// that the conditions that require AT_SECURE never arise. (Compare Linux's
+// security/commoncap.c:cap_bprm_set_creds() and cap_bprm_secureexec().)
+//
+// - We don't check for CAP_SYS_ADMIN in prctl(PR_SET_SECCOMP), since
+// seccomp-bpf is also allowed if the task has no_new_privs set.
+//
+// - Task.ptraceAttach does not serialize with execve as it does in Linux,
+// since no_new_privs being set has the same effect as the presence of an
+// unprivileged tracer.
+//
+// Preconditions: t.mu must be locked.
+func (t *Task) updateCredsForExecLocked() {
+ // """
+ // During an execve(2), the kernel calculates the new capabilities of
+ // the process using the following algorithm:
+ //
+ // P'(permitted) = (P(inheritable) & F(inheritable)) |
+ // (F(permitted) & cap_bset)
+ //
+ // P'(effective) = F(effective) ? P'(permitted) : 0
+ //
+ // P'(inheritable) = P(inheritable) [i.e., unchanged]
+ //
+ // where:
+ //
+ // P denotes the value of a thread capability set before the
+ // execve(2)
+ //
+ // P' denotes the value of a thread capability set after the
+ // execve(2)
+ //
+ // F denotes a file capability set
+ //
+ // cap_bset is the value of the capability bounding set
+ //
+ // ...
+ //
+ // In order to provide an all-powerful root using capability sets, during
+ // an execve(2):
+ //
+ // 1. If a set-user-ID-root program is being executed, or the real user ID
+ // of the process is 0 (root) then the file inheritable and permitted sets
+ // are defined to be all ones (i.e. all capabilities enabled).
+ //
+ // 2. If a set-user-ID-root program is being executed, then the file
+ // effective bit is defined to be one (enabled).
+ //
+ // The upshot of the above rules, combined with the capabilities
+ // transformations described above, is that when a process execve(2)s a
+ // set-user-ID-root program, or when a process with an effective UID of 0
+ // execve(2)s a program, it gains all capabilities in its permitted and
+ // effective capability sets, except those masked out by the capability
+ // bounding set.
+ // """ - capabilities(7)
+ // (ambient capability sets omitted)
+ //
+ // As the last paragraph implies, the case of "a set-user-ID root program
+ // is being executed" also includes the case where (namespace) root is
+ // executing a non-set-user-ID program; the actual check is just based on
+ // the effective user ID.
+ var newPermitted auth.CapabilitySet // since F(inheritable) == F(permitted) == 0
+ fileEffective := false
+ root := t.creds.UserNamespace.MapToKUID(auth.RootUID)
+ if t.creds.EffectiveKUID == root || t.creds.RealKUID == root {
+ newPermitted = t.creds.InheritableCaps | t.creds.BoundingCaps
+ if t.creds.EffectiveKUID == root {
+ fileEffective = true
+ }
+ }
+
+ t.creds = t.creds.Fork() // See doc for creds.
+
+ // Now we enter poorly-documented, somewhat confusing territory. (The
+ // accompanying comment in Linux's security/commoncap.c:cap_bprm_set_creds
+ // is not very helpful.) My reading of it is:
+ //
+ // If at least one of the following is true:
+ //
+ // A1. The execing task is ptraced, and the tracer did not have
+ // CAP_SYS_PTRACE in the execing task's user namespace at the time of
+ // PTRACE_ATTACH.
+ //
+ // A2. The execing task shares its FS context with at least one task in
+ // another thread group.
+ //
+ // A3. The execing task has no_new_privs set.
+ //
+ // AND at least one of the following is true:
+ //
+ // B1. The new effective user ID (which may come from set-user-ID, or be the
+ // execing task's existing effective user ID) is not equal to the task's
+ // real UID.
+ //
+ // B2. The new effective group ID (which may come from set-group-ID, or be
+ // the execing task's existing effective group ID) is not equal to the
+ // task's real GID.
+ //
+ // B3. The new permitted capability set contains capabilities not in the
+ // task's permitted capability set.
+ //
+ // Then:
+ //
+ // C1. Limit the new permitted capability set to the task's permitted
+ // capability set.
+ //
+ // C2. If either the task does not have CAP_SETUID in its user namespace, or
+ // the task has no_new_privs set, force the new effective UID and GID to
+ // the task's real UID and GID.
+ //
+ // But since no_new_privs is always set (A3 is always true), this becomes
+ // much simpler. If B1 and B2 are false, C2 is a no-op. If B3 is false, C1
+ // is a no-op. So we can just do C1 and C2 unconditionally.
+ if t.creds.EffectiveKUID != t.creds.RealKUID || t.creds.EffectiveKGID != t.creds.RealKGID {
+ t.creds.EffectiveKUID = t.creds.RealKUID
+ t.creds.EffectiveKGID = t.creds.RealKGID
+ t.parentDeathSignal = 0
+ }
+ // (Saved set-user-ID is always set to the new effective user ID, and saved
+ // set-group-ID is always set to the new effective group ID, regardless of
+ // the above.)
+ t.creds.SavedKUID = t.creds.RealKUID
+ t.creds.SavedKGID = t.creds.RealKGID
+ t.creds.PermittedCaps &= newPermitted
+ if fileEffective {
+ t.creds.EffectiveCaps = t.creds.PermittedCaps
+ } else {
+ t.creds.EffectiveCaps = 0
+ }
+
+ // prctl(2): The "keep capabilities" value will be reset to 0 on subsequent
+ // calls to execve(2).
+ t.creds.KeepCaps = false
+
+ // "The bounding set is inherited at fork(2) from the thread's parent, and
+ // is preserved across an execve(2)". So we're done.
+}
diff --git a/pkg/sentry/kernel/task_list.go b/pkg/sentry/kernel/task_list.go
new file mode 100755
index 000000000..57d3f098d
--- /dev/null
+++ b/pkg/sentry/kernel/task_list.go
@@ -0,0 +1,173 @@
+package kernel
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type taskElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (taskElementMapper) linkerFor(elem *Task) *Task { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type taskList struct {
+ head *Task
+ tail *Task
+}
+
+// Reset resets list l to the empty state.
+func (l *taskList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *taskList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *taskList) Front() *Task {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *taskList) Back() *Task {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *taskList) PushFront(e *Task) {
+ taskElementMapper{}.linkerFor(e).SetNext(l.head)
+ taskElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ taskElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *taskList) PushBack(e *Task) {
+ taskElementMapper{}.linkerFor(e).SetNext(nil)
+ taskElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ taskElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *taskList) PushBackList(m *taskList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ taskElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ taskElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *taskList) InsertAfter(b, e *Task) {
+ a := taskElementMapper{}.linkerFor(b).Next()
+ taskElementMapper{}.linkerFor(e).SetNext(a)
+ taskElementMapper{}.linkerFor(e).SetPrev(b)
+ taskElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ taskElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *taskList) InsertBefore(a, e *Task) {
+ b := taskElementMapper{}.linkerFor(a).Prev()
+ taskElementMapper{}.linkerFor(e).SetNext(a)
+ taskElementMapper{}.linkerFor(e).SetPrev(b)
+ taskElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ taskElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *taskList) Remove(e *Task) {
+ prev := taskElementMapper{}.linkerFor(e).Prev()
+ next := taskElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ taskElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ taskElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type taskEntry struct {
+ next *Task
+ prev *Task
+}
+
+// Next returns the entry that follows e in the list.
+func (e *taskEntry) Next() *Task {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *taskEntry) Prev() *Task {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *taskEntry) SetNext(elem *Task) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *taskEntry) SetPrev(elem *Task) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
new file mode 100644
index 000000000..e0e57e8bd
--- /dev/null
+++ b/pkg/sentry/kernel/task_log.go
@@ -0,0 +1,137 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "sort"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+const (
+ // maxStackDebugBytes is the maximum number of user stack bytes that may be
+ // printed by debugDumpStack.
+ maxStackDebugBytes = 1024
+)
+
+// Infof logs an formatted info message by calling log.Infof.
+func (t *Task) Infof(fmt string, v ...interface{}) {
+ if log.IsLogging(log.Info) {
+ log.Infof(t.logPrefix.Load().(string)+fmt, v...)
+ }
+}
+
+// Warningf logs a warning string by calling log.Warningf.
+func (t *Task) Warningf(fmt string, v ...interface{}) {
+ if log.IsLogging(log.Warning) {
+ log.Warningf(t.logPrefix.Load().(string)+fmt, v...)
+ }
+}
+
+// Debugf creates a debug string that includes the task ID.
+func (t *Task) Debugf(fmt string, v ...interface{}) {
+ if log.IsLogging(log.Debug) {
+ log.Debugf(t.logPrefix.Load().(string)+fmt, v...)
+ }
+}
+
+// IsLogging returns true iff this level is being logged.
+func (t *Task) IsLogging(level log.Level) bool {
+ return log.IsLogging(level)
+}
+
+// DebugDumpState logs task state at log level debug.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) DebugDumpState() {
+ t.debugDumpRegisters()
+ t.debugDumpStack()
+ if mm := t.MemoryManager(); mm != nil {
+ t.Debugf("Mappings:\n%s", mm)
+ }
+ t.Debugf("FDMap:\n%s", t.fds)
+}
+
+// debugDumpRegisters logs register state at log level debug.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) debugDumpRegisters() {
+ if !t.IsLogging(log.Debug) {
+ return
+ }
+ regmap, err := t.Arch().RegisterMap()
+ if err != nil {
+ t.Debugf("Registers: %v", err)
+ } else {
+ t.Debugf("Registers:")
+ var regs []string
+ for reg := range regmap {
+ regs = append(regs, reg)
+ }
+ sort.Strings(regs)
+ for _, reg := range regs {
+ t.Debugf("%-8s = %016x", reg, regmap[reg])
+ }
+ }
+}
+
+// debugDumpStack logs user stack contents at log level debug.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) debugDumpStack() {
+ if !t.IsLogging(log.Debug) {
+ return
+ }
+ m := t.MemoryManager()
+ if m == nil {
+ t.Debugf("Memory manager for task is gone, skipping application stack dump.")
+ return
+ }
+ t.Debugf("Stack:")
+ start := usermem.Addr(t.Arch().Stack())
+ // Round addr down to a 16-byte boundary.
+ start &= ^usermem.Addr(15)
+ // Print 16 bytes per line, one byte at a time.
+ for offset := uint64(0); offset < maxStackDebugBytes; offset += 16 {
+ addr, ok := start.AddLength(offset)
+ if !ok {
+ break
+ }
+ var data [16]byte
+ n, err := m.CopyIn(t, addr, data[:], usermem.IOOpts{
+ IgnorePermissions: true,
+ })
+ // Print as much of the line as we can, even if an error was
+ // encountered.
+ if n > 0 {
+ t.Debugf("%x: % x", addr, data[:n])
+ }
+ if err != nil {
+ t.Debugf("Error reading stack at address %x: %v", addr+usermem.Addr(n), err)
+ break
+ }
+ }
+}
+
+// updateLogPrefix updates the task's cached log prefix to reflect its
+// current thread ID.
+//
+// Preconditions: The task's owning TaskSet.mu must be locked.
+func (t *Task) updateLogPrefixLocked() {
+ // Use the task's TID in the root PID namespace for logging.
+ t.logPrefix.Store(fmt.Sprintf("[% 4d] ", t.tg.pidns.owner.Root.tids[t]))
+}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
new file mode 100644
index 000000000..04c684c1a
--- /dev/null
+++ b/pkg/sentry/kernel/task_net.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+)
+
+// IsNetworkNamespaced returns true if t is in a non-root network namespace.
+func (t *Task) IsNetworkNamespaced() bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns
+}
+
+// NetworkContext returns the network stack used by the task. NetworkContext
+// may return nil if no network stack is available.
+func (t *Task) NetworkContext() inet.Stack {
+ if t.IsNetworkNamespaced() {
+ return nil
+ }
+ return t.k.networkStack
+}
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
new file mode 100644
index 000000000..a79101a18
--- /dev/null
+++ b/pkg/sentry/kernel/task_run.go
@@ -0,0 +1,340 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "bytes"
+ "runtime"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/hostcpu"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// A taskRunState is a reified state in the task state machine. See README.md
+// for details. The canonical list of all run states, as well as transitions
+// between them, is given in run_states.dot.
+//
+// The set of possible states is enumerable and completely defined by the
+// kernel package, so taskRunState would ideally be represented by a
+// discriminated union. However, Go does not support sum types.
+//
+// Hence, as with TaskStop, data-free taskRunStates should be represented as
+// typecast nils to avoid unnecessary allocation.
+type taskRunState interface {
+ // execute executes the code associated with this state over the given task
+ // and returns the following state. If execute returns nil, the task
+ // goroutine should exit.
+ //
+ // It is valid to tail-call a following state's execute to avoid the
+ // overhead of converting the following state to an interface object and
+ // checking for stops, provided that the tail-call cannot recurse.
+ execute(*Task) taskRunState
+}
+
+// run runs the task goroutine.
+//
+// threadID a dummy value set to the task's TID in the root PID namespace to
+// make it visible in stack dumps. A goroutine for a given task can be identified
+// searching for Task.run()'s argument value.
+func (t *Task) run(threadID uintptr) {
+ // Construct t.blockingTimer here. We do this here because we can't
+ // reconstruct t.blockingTimer during restore in Task.afterLoad(), because
+ // kernel.timekeeper.SetClocks() hasn't been called yet.
+ blockingTimerNotifier, blockingTimerChan := ktime.NewChannelNotifier()
+ t.blockingTimer = ktime.NewTimer(t.k.MonotonicClock(), blockingTimerNotifier)
+ defer t.blockingTimer.Destroy()
+ t.blockingTimerChan = blockingTimerChan
+
+ // Activate our address space.
+ t.Activate()
+ // The corresponding t.Deactivate occurs in the exit path
+ // (runExitMain.execute) so that when
+ // Platform.CooperativelySharesAddressSpace() == true, we give up the
+ // AddressSpace before the task goroutine finishes executing.
+
+ // If this is a newly-started task, it should check for participation in
+ // group stops. If this is a task resuming after restore, it was
+ // interrupted by saving. In either case, the task is initially
+ // interrupted.
+ t.interruptSelf()
+
+ for {
+ // Explanation for this ordering:
+ //
+ // - A freshly-started task that is stopped should not do anything
+ // before it enters the stop.
+ //
+ // - If taskRunState.execute returns nil, the task goroutine should
+ // exit without checking for a stop.
+ //
+ // - Task.Start won't start Task.run if t.runState is nil, so this
+ // ordering is safe.
+ t.doStop()
+ t.runState = t.runState.execute(t)
+ if t.runState == nil {
+ t.accountTaskGoroutineEnter(TaskGoroutineNonexistent)
+ t.goroutineStopped.Done()
+ t.tg.liveGoroutines.Done()
+ t.tg.pidns.owner.liveGoroutines.Done()
+ t.tg.pidns.owner.runningGoroutines.Done()
+
+ // Keep argument alive because stack trace for dead variables may not be correct.
+ runtime.KeepAlive(threadID)
+ return
+ }
+ }
+}
+
+// doStop is called by Task.run to block until the task is not stopped.
+func (t *Task) doStop() {
+ if atomic.LoadInt32(&t.stopCount) == 0 {
+ return
+ }
+ t.Deactivate()
+ // NOTE(b/30316266): t.Activate() must be called without any locks held, so
+ // this defer must precede the defer for unlocking the signal mutex.
+ defer t.Activate()
+ t.accountTaskGoroutineEnter(TaskGoroutineStopped)
+ defer t.accountTaskGoroutineLeave(TaskGoroutineStopped)
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.tg.pidns.owner.runningGoroutines.Add(-1)
+ defer t.tg.pidns.owner.runningGoroutines.Add(1)
+ t.goroutineStopped.Add(-1)
+ defer t.goroutineStopped.Add(1)
+ for t.stopCount > 0 {
+ t.endStopCond.Wait()
+ }
+}
+
+// The runApp state checks for interrupts before executing untrusted
+// application code.
+//
+// +stateify savable
+type runApp struct{}
+
+func (*runApp) execute(t *Task) taskRunState {
+ if t.interrupted() {
+ // Checkpointing instructs tasks to stop by sending an interrupt, so we
+ // must check for stops before entering runInterrupt (instead of
+ // tail-calling it).
+ return (*runInterrupt)(nil)
+ }
+
+ // We're about to switch to the application again. If there's still a
+ // unhandled SyscallRestartErrno that wasn't translated to an EINTR,
+ // restart the syscall that was interrupted. If there's a saved signal
+ // mask, restore it. (Note that restoring the saved signal mask may unblock
+ // a pending signal, causing another interruption, but that signal should
+ // not interact with the interrupted syscall.)
+ if t.haveSyscallReturn {
+ if sre, ok := SyscallRestartErrnoFromReturn(t.Arch().Return()); ok {
+ if sre == ERESTART_RESTARTBLOCK {
+ t.Debugf("Restarting syscall %d with restart block after errno %d: not interrupted by handled signal", t.Arch().SyscallNo(), sre)
+ t.Arch().RestartSyscallWithRestartBlock()
+ } else {
+ t.Debugf("Restarting syscall %d after errno %d: not interrupted by handled signal", t.Arch().SyscallNo(), sre)
+ t.Arch().RestartSyscall()
+ }
+ }
+ t.haveSyscallReturn = false
+ }
+ if t.haveSavedSignalMask {
+ t.SetSignalMask(t.savedSignalMask)
+ t.haveSavedSignalMask = false
+ if t.interrupted() {
+ return (*runInterrupt)(nil)
+ }
+ }
+
+ // Apply restartable sequences.
+ if t.rseqPreempted {
+ t.rseqPreempted = false
+ if t.rseqCPUAddr != 0 {
+ cpu := int32(hostcpu.GetCPU())
+ if t.rseqCPU != cpu {
+ t.rseqCPU = cpu
+ if err := t.rseqCopyOutCPU(); err != nil {
+ t.Warningf("Failed to copy CPU to %#x for RSEQ: %v", t.rseqCPUAddr, err)
+ t.forceSignal(linux.SIGSEGV, false)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ // Re-enter the task run loop for signal delivery.
+ return (*runApp)(nil)
+ }
+ }
+ }
+ t.rseqInterrupt()
+ }
+
+ // Check if we need to enable single-stepping. Tracers expect that the
+ // kernel preserves the value of the single-step flag set by PTRACE_SETREGS
+ // whether or not PTRACE_SINGLESTEP/PTRACE_SYSEMU_SINGLESTEP is used (this
+ // includes our ptrace platform, by the way), so we should only clear the
+ // single-step flag if we're responsible for setting it. (clearSinglestep
+ // is therefore analogous to Linux's TIF_FORCED_TF.)
+ //
+ // Strictly speaking, we should also not clear the single-step flag if we
+ // single-step through an instruction that sets the single-step flag
+ // (arch/x86/kernel/step.c:is_setting_trap_flag()). But nobody sets their
+ // own TF. (Famous last words, I know.)
+ clearSinglestep := false
+ if t.hasTracer() {
+ t.tg.pidns.owner.mu.RLock()
+ if t.ptraceSinglestep {
+ clearSinglestep = !t.Arch().SingleStep()
+ t.Arch().SetSingleStep()
+ }
+ t.tg.pidns.owner.mu.RUnlock()
+ }
+
+ t.accountTaskGoroutineEnter(TaskGoroutineRunningApp)
+ info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU)
+ t.accountTaskGoroutineLeave(TaskGoroutineRunningApp)
+
+ if clearSinglestep {
+ t.Arch().ClearSingleStep()
+ }
+
+ switch err {
+ case nil:
+ // Handle application system call.
+ return t.doSyscall()
+
+ case platform.ErrContextInterrupt:
+ // Interrupted by platform.Context.Interrupt(). Re-enter the run
+ // loop to figure out why.
+ return (*runApp)(nil)
+
+ case platform.ErrContextSignalCPUID:
+ // Is this a CPUID instruction?
+ expected := arch.CPUIDInstruction[:]
+ found := make([]byte, len(expected))
+ _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
+ if err == nil && bytes.Equal(expected, found) {
+ // Skip the cpuid instruction.
+ t.Arch().CPUIDEmulate(t)
+ t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
+
+ // Resume execution.
+ return (*runApp)(nil)
+ }
+
+ // The instruction at the given RIP was not a CPUID, and we
+ // fallthrough to the default signal deliver behavior below.
+ fallthrough
+
+ case platform.ErrContextSignal:
+ // Looks like a signal has been delivered to us. If it's a synchronous
+ // signal (SEGV, SIGBUS, etc.), it should be sent to the application
+ // thread that received it.
+ sig := linux.Signal(info.Signo)
+
+ // Was it a fault that we should handle internally? If so, this wasn't
+ // an application-generated signal and we should continue execution
+ // normally.
+ if at.Any() {
+ addr := usermem.Addr(info.Addr())
+ err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack()))
+ if err == nil {
+ // The fault was handled appropriately.
+ // We can resume running the application.
+ return (*runApp)(nil)
+ }
+
+ // Is this a vsyscall that we need emulate?
+ if at.Execute {
+ if sysno, ok := t.tc.st.LookupEmulate(addr); ok {
+ return t.doVsyscall(addr, sysno)
+ }
+ }
+
+ // Faults are common, log only at debug level.
+ t.Debugf("Unhandled user fault: addr=%x ip=%x access=%v err=%v", addr, t.Arch().IP(), at, err)
+ t.DebugDumpState()
+
+ // Continue to signal handling.
+ //
+ // Convert a BusError error to a SIGBUS from a SIGSEGV. All
+ // other info bits stay the same (address, etc.).
+ if _, ok := err.(*memmap.BusError); ok {
+ sig = linux.SIGBUS
+ info.Signo = int32(linux.SIGBUS)
+ }
+ }
+
+ switch sig {
+ case linux.SIGILL, linux.SIGSEGV, linux.SIGBUS, linux.SIGFPE, linux.SIGTRAP:
+ // Synchronous signal. Send it to ourselves. Assume the signal is
+ // legitimate and force it (work around the signal being ignored or
+ // blocked) like Linux does. Conveniently, this is even the correct
+ // behavior for SIGTRAP from single-stepping.
+ t.forceSignal(linux.Signal(sig), false /* unconditional */)
+ t.SendSignal(info)
+
+ case platform.SignalInterrupt:
+ // Assume that a call to platform.Context.Interrupt() misfired.
+
+ case linux.SIGPROF:
+ // It's a profiling interrupt: there's not much
+ // we can do. We've already paid a decent cost
+ // by intercepting the signal, at this point we
+ // simply ignore it.
+
+ default:
+ // Asynchronous signal. Let the system deal with it.
+ t.k.sendExternalSignal(info, "application")
+ }
+
+ return (*runApp)(nil)
+
+ case platform.ErrContextCPUPreempted:
+ // Ensure that RSEQ critical sections are interrupted and per-thread
+ // CPU values are updated before the next platform.Context.Switch().
+ t.rseqPreempted = true
+ return (*runApp)(nil)
+
+ default:
+ // What happened? Can't continue.
+ t.Warningf("Unexpected SwitchToApp error: %v", err)
+ t.PrepareExit(ExitStatus{Code: t.ExtractErrno(err, -1)})
+ return (*runExit)(nil)
+ }
+}
+
+// waitGoroutineStoppedOrExited blocks until t's task goroutine stops or exits.
+func (t *Task) waitGoroutineStoppedOrExited() {
+ t.goroutineStopped.Wait()
+}
+
+// WaitExited blocks until all task goroutines in tg have exited.
+//
+// WaitExited does not correspond to anything in Linux; it's provided so that
+// external callers of Kernel.CreateProcess can wait for the created thread
+// group to terminate.
+func (tg *ThreadGroup) WaitExited() {
+ tg.liveGoroutines.Wait()
+}
+
+// Yield yields the processor for the calling task.
+func (t *Task) Yield() {
+ atomic.AddUint64(&t.yieldCount, 1)
+ runtime.Gosched()
+}
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
new file mode 100644
index 000000000..5455f6ea9
--- /dev/null
+++ b/pkg/sentry/kernel/task_sched.go
@@ -0,0 +1,637 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// CPU scheduling, real and fake.
+
+import (
+ "fmt"
+ "math/rand"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/hostcpu"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/sched"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// TaskGoroutineState is a coarse representation of the current execution
+// status of a kernel.Task goroutine.
+type TaskGoroutineState int
+
+const (
+ // TaskGoroutineNonexistent indicates that the task goroutine has either
+ // not yet been created by Task.Start() or has returned from Task.run().
+ // This must be the zero value for TaskGoroutineState.
+ TaskGoroutineNonexistent TaskGoroutineState = iota
+
+ // TaskGoroutineRunningSys indicates that the task goroutine is executing
+ // sentry code.
+ TaskGoroutineRunningSys
+
+ // TaskGoroutineRunningApp indicates that the task goroutine is executing
+ // application code.
+ TaskGoroutineRunningApp
+
+ // TaskGoroutineBlockedInterruptible indicates that the task goroutine is
+ // blocked in Task.block(), and hence may be woken by Task.interrupt()
+ // (e.g. due to signal delivery).
+ TaskGoroutineBlockedInterruptible
+
+ // TaskGoroutineBlockedUninterruptible indicates that the task goroutine is
+ // stopped outside of Task.block() and Task.doStop(), and hence cannot be
+ // woken by Task.interrupt().
+ TaskGoroutineBlockedUninterruptible
+
+ // TaskGoroutineStopped indicates that the task goroutine is blocked in
+ // Task.doStop(). TaskGoroutineStopped is similar to
+ // TaskGoroutineBlockedUninterruptible, but is a separate state to make it
+ // possible to determine when Task.stop is meaningful.
+ TaskGoroutineStopped
+)
+
+// TaskGoroutineSchedInfo contains task goroutine scheduling state which must
+// be read and updated atomically.
+//
+// +stateify savable
+type TaskGoroutineSchedInfo struct {
+ // Timestamp was the value of Kernel.cpuClock when this
+ // TaskGoroutineSchedInfo was last updated.
+ Timestamp uint64
+
+ // State is the current state of the task goroutine.
+ State TaskGoroutineState
+
+ // UserTicks is the amount of time the task goroutine has spent executing
+ // its associated Task's application code, in units of linux.ClockTick.
+ UserTicks uint64
+
+ // SysTicks is the amount of time the task goroutine has spent executing in
+ // the sentry, in units of linux.ClockTick.
+ SysTicks uint64
+}
+
+// userTicksAt returns the extrapolated value of ts.UserTicks after
+// Kernel.CPUClockNow() indicates a time of now.
+//
+// Preconditions: now <= Kernel.CPUClockNow(). (Since Kernel.cpuClock is
+// monotonic, this is satisfied if now is the result of a previous call to
+// Kernel.CPUClockNow().) This requirement exists because otherwise a racing
+// change to t.gosched can cause userTicksAt to adjust stats by too much,
+// making the observed stats non-monotonic.
+func (ts *TaskGoroutineSchedInfo) userTicksAt(now uint64) uint64 {
+ if ts.Timestamp < now && ts.State == TaskGoroutineRunningApp {
+ // Update stats to reflect execution since the last update.
+ return ts.UserTicks + (now - ts.Timestamp)
+ }
+ return ts.UserTicks
+}
+
+// sysTicksAt returns the extrapolated value of ts.SysTicks after
+// Kernel.CPUClockNow() indicates a time of now.
+//
+// Preconditions: As for userTicksAt.
+func (ts *TaskGoroutineSchedInfo) sysTicksAt(now uint64) uint64 {
+ if ts.Timestamp < now && ts.State == TaskGoroutineRunningSys {
+ return ts.SysTicks + (now - ts.Timestamp)
+ }
+ return ts.SysTicks
+}
+
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) {
+ now := t.k.CPUClockNow()
+ if t.gosched.State != TaskGoroutineRunningSys {
+ panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, TaskGoroutineRunningSys, state))
+ }
+ t.goschedSeq.BeginWrite()
+ // This function is very hot; avoid defer.
+ t.gosched.SysTicks += now - t.gosched.Timestamp
+ t.gosched.Timestamp = now
+ t.gosched.State = state
+ t.goschedSeq.EndWrite()
+}
+
+// Preconditions: The caller must be running on the task goroutine, and leaving
+// a state indicated by a previous call to
+// t.accountTaskGoroutineEnter(state).
+func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) {
+ now := t.k.CPUClockNow()
+ if t.gosched.State != state {
+ panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, state, TaskGoroutineRunningSys))
+ }
+ t.goschedSeq.BeginWrite()
+ // This function is very hot; avoid defer.
+ if state == TaskGoroutineRunningApp {
+ t.gosched.UserTicks += now - t.gosched.Timestamp
+ }
+ t.gosched.Timestamp = now
+ t.gosched.State = TaskGoroutineRunningSys
+ t.goschedSeq.EndWrite()
+}
+
+// TaskGoroutineSchedInfo returns a copy of t's task goroutine scheduling info.
+// Most clients should use t.CPUStats() instead.
+func (t *Task) TaskGoroutineSchedInfo() TaskGoroutineSchedInfo {
+ return SeqAtomicLoadTaskGoroutineSchedInfo(&t.goschedSeq, &t.gosched)
+}
+
+// CPUStats returns the CPU usage statistics of t.
+func (t *Task) CPUStats() usage.CPUStats {
+ return t.cpuStatsAt(t.k.CPUClockNow())
+}
+
+// Preconditions: As for TaskGoroutineSchedInfo.userTicksAt.
+func (t *Task) cpuStatsAt(now uint64) usage.CPUStats {
+ tsched := t.TaskGoroutineSchedInfo()
+ return usage.CPUStats{
+ UserTime: time.Duration(tsched.userTicksAt(now) * uint64(linux.ClockTick)),
+ SysTime: time.Duration(tsched.sysTicksAt(now) * uint64(linux.ClockTick)),
+ VoluntarySwitches: atomic.LoadUint64(&t.yieldCount),
+ }
+}
+
+// CPUStats returns the combined CPU usage statistics of all past and present
+// threads in tg.
+func (tg *ThreadGroup) CPUStats() usage.CPUStats {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ // Hack to get a pointer to the Kernel.
+ if tg.leader == nil {
+ // Per comment on tg.leader, this is only possible if nothing in the
+ // ThreadGroup has ever executed anyway.
+ return usage.CPUStats{}
+ }
+ return tg.cpuStatsAtLocked(tg.leader.k.CPUClockNow())
+}
+
+// Preconditions: As for TaskGoroutineSchedInfo.userTicksAt. The TaskSet mutex
+// must be locked.
+func (tg *ThreadGroup) cpuStatsAtLocked(now uint64) usage.CPUStats {
+ stats := tg.exitedCPUStats
+ // Account for live tasks.
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ stats.Accumulate(t.cpuStatsAt(now))
+ }
+ return stats
+}
+
+// JoinedChildCPUStats implements the semantics of RUSAGE_CHILDREN: "Return
+// resource usage statistics for all children of [tg] that have terminated and
+// been waited for. These statistics will include the resources used by
+// grandchildren, and further removed descendants, if all of the intervening
+// descendants waited on their terminated children."
+func (tg *ThreadGroup) JoinedChildCPUStats() usage.CPUStats {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.childCPUStats
+}
+
+// taskClock is a ktime.Clock that measures the time that a task has spent
+// executing. taskClock is primarily used to implement CLOCK_THREAD_CPUTIME_ID.
+//
+// +stateify savable
+type taskClock struct {
+ t *Task
+
+ // If includeSys is true, the taskClock includes both time spent executing
+ // application code as well as time spent in the sentry. Otherwise, the
+ // taskClock includes only time spent executing application code.
+ includeSys bool
+
+ // Implements waiter.Waitable. TimeUntil wouldn't change its estimation
+ // based on either of the clock events, so there's no event to be
+ // notified for.
+ ktime.NoClockEvents `state:"nosave"`
+
+ // Implements ktime.Clock.WallTimeUntil.
+ //
+ // As an upper bound, a task's clock cannot advance faster than CPU
+ // time. It would have to execute at a rate of more than 1 task-second
+ // per 1 CPU-second, which isn't possible.
+ ktime.WallRateClock `state:"nosave"`
+}
+
+// UserCPUClock returns a clock measuring the CPU time the task has spent
+// executing application code.
+func (t *Task) UserCPUClock() ktime.Clock {
+ return &taskClock{t: t, includeSys: false}
+}
+
+// CPUClock returns a clock measuring the CPU time the task has spent executing
+// application and "kernel" code.
+func (t *Task) CPUClock() ktime.Clock {
+ return &taskClock{t: t, includeSys: true}
+}
+
+// Now implements ktime.Clock.Now.
+func (tc *taskClock) Now() ktime.Time {
+ stats := tc.t.CPUStats()
+ if tc.includeSys {
+ return ktime.FromNanoseconds((stats.UserTime + stats.SysTime).Nanoseconds())
+ }
+ return ktime.FromNanoseconds(stats.UserTime.Nanoseconds())
+}
+
+// tgClock is a ktime.Clock that measures the time a thread group has spent
+// executing. tgClock is primarily used to implement CLOCK_PROCESS_CPUTIME_ID.
+//
+// +stateify savable
+type tgClock struct {
+ tg *ThreadGroup
+
+ // If includeSys is true, the tgClock includes both time spent executing
+ // application code as well as time spent in the sentry. Otherwise, the
+ // tgClock includes only time spent executing application code.
+ includeSys bool
+
+ // Implements waiter.Waitable.
+ ktime.ClockEventsQueue `state:"nosave"`
+}
+
+// Now implements ktime.Clock.Now.
+func (tgc *tgClock) Now() ktime.Time {
+ stats := tgc.tg.CPUStats()
+ if tgc.includeSys {
+ return ktime.FromNanoseconds((stats.UserTime + stats.SysTime).Nanoseconds())
+ }
+ return ktime.FromNanoseconds(stats.UserTime.Nanoseconds())
+}
+
+// WallTimeUntil implements ktime.Clock.WallTimeUntil.
+func (tgc *tgClock) WallTimeUntil(t, now ktime.Time) time.Duration {
+ // Thread group CPU time should not exceed wall time * live tasks, since
+ // task goroutines exit after the transition to TaskExitZombie in
+ // runExitNotify.
+ tgc.tg.pidns.owner.mu.RLock()
+ n := tgc.tg.liveTasks
+ tgc.tg.pidns.owner.mu.RUnlock()
+ if n == 0 {
+ if t.Before(now) {
+ return 0
+ }
+ // The timer tick raced with thread group exit, after which no more
+ // tasks can enter the thread group. So tgc.Now() will never advance
+ // again. Return a large delay; the timer should be stopped long before
+ // it comes again anyway.
+ return time.Hour
+ }
+ // This is a lower bound on the amount of time that can elapse before an
+ // associated timer expires, so returning this value tends to result in a
+ // sequence of closely-spaced ticks just before timer expiry. To avoid
+ // this, round up to the nearest ClockTick; CPU usage measurements are
+ // limited to this resolution anyway.
+ remaining := time.Duration(t.Sub(now).Nanoseconds()/int64(n)) * time.Nanosecond
+ return ((remaining + (linux.ClockTick - time.Nanosecond)) / linux.ClockTick) * linux.ClockTick
+}
+
+// UserCPUClock returns a ktime.Clock that measures the time that a thread
+// group has spent executing.
+func (tg *ThreadGroup) UserCPUClock() ktime.Clock {
+ return &tgClock{tg: tg, includeSys: false}
+}
+
+// CPUClock returns a ktime.Clock that measures the time that a thread group
+// has spent executing, including sentry time.
+func (tg *ThreadGroup) CPUClock() ktime.Clock {
+ return &tgClock{tg: tg, includeSys: true}
+}
+
+type kernelCPUClockTicker struct {
+ k *Kernel
+
+ // These are essentially kernelCPUClockTicker.Notify local variables that
+ // are cached between calls to reduce allocations.
+ rng *rand.Rand
+ tgs []*ThreadGroup
+}
+
+func newKernelCPUClockTicker(k *Kernel) *kernelCPUClockTicker {
+ return &kernelCPUClockTicker{
+ k: k,
+ rng: rand.New(rand.NewSource(rand.Int63())),
+ }
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (ticker *kernelCPUClockTicker) Notify(exp uint64) {
+ // Only increment cpuClock by 1 regardless of the number of expirations.
+ // This approximately compensates for cases where thread throttling or bad
+ // Go runtime scheduling prevents the kernelCPUClockTicker goroutine, and
+ // presumably task goroutines as well, from executing for a long period of
+ // time. It's also necessary to prevent CPU clocks from seeing large
+ // discontinuous jumps.
+ now := atomic.AddUint64(&ticker.k.cpuClock, 1)
+
+ // Check thread group CPU timers.
+ tgs := ticker.k.tasks.Root.ThreadGroupsAppend(ticker.tgs)
+ for _, tg := range tgs {
+ if atomic.LoadUint32(&tg.cpuTimersEnabled) == 0 {
+ continue
+ }
+
+ ticker.k.tasks.mu.RLock()
+ if tg.leader == nil {
+ // No tasks have ever run in this thread group.
+ ticker.k.tasks.mu.RUnlock()
+ continue
+ }
+ // Accumulate thread group CPU stats, and randomly select running tasks
+ // using reservoir sampling to receive CPU timer signals.
+ var virtReceiver *Task
+ nrVirtCandidates := 0
+ var profReceiver *Task
+ nrProfCandidates := 0
+ tgUserTime := tg.exitedCPUStats.UserTime
+ tgSysTime := tg.exitedCPUStats.SysTime
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ tsched := t.TaskGoroutineSchedInfo()
+ tgUserTime += time.Duration(tsched.userTicksAt(now) * uint64(linux.ClockTick))
+ tgSysTime += time.Duration(tsched.sysTicksAt(now) * uint64(linux.ClockTick))
+ switch tsched.State {
+ case TaskGoroutineRunningApp:
+ // Considered by ITIMER_VIRT, ITIMER_PROF, and RLIMIT_CPU
+ // timers.
+ nrVirtCandidates++
+ if int(randInt31n(ticker.rng, int32(nrVirtCandidates))) == 0 {
+ virtReceiver = t
+ }
+ fallthrough
+ case TaskGoroutineRunningSys:
+ // Considered by ITIMER_PROF and RLIMIT_CPU timers.
+ nrProfCandidates++
+ if int(randInt31n(ticker.rng, int32(nrProfCandidates))) == 0 {
+ profReceiver = t
+ }
+ }
+ }
+ tgVirtNow := ktime.FromNanoseconds(tgUserTime.Nanoseconds())
+ tgProfNow := ktime.FromNanoseconds((tgUserTime + tgSysTime).Nanoseconds())
+
+ // All of the following are standard (not real-time) signals, which are
+ // automatically deduplicated, so we ignore the number of expirations.
+ tg.signalHandlers.mu.Lock()
+ // It should only be possible for these timers to advance if we found
+ // at least one running task.
+ if virtReceiver != nil {
+ // ITIMER_VIRTUAL
+ newItimerVirtSetting, exp := tg.itimerVirtSetting.At(tgVirtNow)
+ tg.itimerVirtSetting = newItimerVirtSetting
+ if exp != 0 {
+ virtReceiver.sendSignalLocked(SignalInfoPriv(linux.SIGVTALRM), true)
+ }
+ }
+ if profReceiver != nil {
+ // ITIMER_PROF
+ newItimerProfSetting, exp := tg.itimerProfSetting.At(tgProfNow)
+ tg.itimerProfSetting = newItimerProfSetting
+ if exp != 0 {
+ profReceiver.sendSignalLocked(SignalInfoPriv(linux.SIGPROF), true)
+ }
+ // RLIMIT_CPU soft limit
+ newRlimitCPUSoftSetting, exp := tg.rlimitCPUSoftSetting.At(tgProfNow)
+ tg.rlimitCPUSoftSetting = newRlimitCPUSoftSetting
+ if exp != 0 {
+ profReceiver.sendSignalLocked(SignalInfoPriv(linux.SIGXCPU), true)
+ }
+ // RLIMIT_CPU hard limit
+ rlimitCPUMax := tg.limits.Get(limits.CPU).Max
+ if rlimitCPUMax != limits.Infinity && !tgProfNow.Before(ktime.FromSeconds(int64(rlimitCPUMax))) {
+ profReceiver.sendSignalLocked(SignalInfoPriv(linux.SIGKILL), true)
+ }
+ }
+ tg.signalHandlers.mu.Unlock()
+
+ ticker.k.tasks.mu.RUnlock()
+ }
+
+ // Retain tgs between calls to Notify to reduce allocations.
+ for i := range tgs {
+ tgs[i] = nil
+ }
+ ticker.tgs = tgs[:0]
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (ticker *kernelCPUClockTicker) Destroy() {
+}
+
+// randInt31n returns a random integer in [0, n).
+//
+// randInt31n is equivalent to math/rand.Rand.int31n(), which is unexported.
+// See that function for details.
+func randInt31n(rng *rand.Rand, n int32) int32 {
+ v := rng.Uint32()
+ prod := uint64(v) * uint64(n)
+ low := uint32(prod)
+ if low < uint32(n) {
+ thresh := uint32(-n) % uint32(n)
+ for low < thresh {
+ v = rng.Uint32()
+ prod = uint64(v) * uint64(n)
+ low = uint32(prod)
+ }
+ }
+ return int32(prod >> 32)
+}
+
+// NotifyRlimitCPUUpdated is called by setrlimit.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) NotifyRlimitCPUUpdated() {
+ t.k.cpuClockTicker.Atomically(func() {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ rlimitCPU := t.tg.limits.Get(limits.CPU)
+ t.tg.rlimitCPUSoftSetting = ktime.Setting{
+ Enabled: rlimitCPU.Cur != limits.Infinity,
+ Next: ktime.FromNanoseconds((time.Duration(rlimitCPU.Cur) * time.Second).Nanoseconds()),
+ Period: time.Second,
+ }
+ if rlimitCPU.Max != limits.Infinity {
+ // Check if tg is already over the hard limit.
+ tgcpu := t.tg.cpuStatsAtLocked(t.k.CPUClockNow())
+ tgProfNow := ktime.FromNanoseconds((tgcpu.UserTime + tgcpu.SysTime).Nanoseconds())
+ if !tgProfNow.Before(ktime.FromSeconds(int64(rlimitCPU.Max))) {
+ t.sendSignalLocked(SignalInfoPriv(linux.SIGKILL), true)
+ }
+ }
+ t.tg.updateCPUTimersEnabledLocked()
+ })
+}
+
+// Preconditions: The signal mutex must be locked.
+func (tg *ThreadGroup) updateCPUTimersEnabledLocked() {
+ rlimitCPU := tg.limits.Get(limits.CPU)
+ if tg.itimerVirtSetting.Enabled || tg.itimerProfSetting.Enabled || tg.rlimitCPUSoftSetting.Enabled || rlimitCPU.Max != limits.Infinity {
+ atomic.StoreUint32(&tg.cpuTimersEnabled, 1)
+ } else {
+ atomic.StoreUint32(&tg.cpuTimersEnabled, 0)
+ }
+}
+
+// StateStatus returns a string representation of the task's current state,
+// appropriate for /proc/[pid]/status.
+func (t *Task) StateStatus() string {
+ switch s := t.TaskGoroutineSchedInfo().State; s {
+ case TaskGoroutineNonexistent:
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ switch t.exitState {
+ case TaskExitZombie:
+ return "Z (zombie)"
+ case TaskExitDead:
+ return "X (dead)"
+ default:
+ // The task goroutine can't exit before passing through
+ // runExitNotify, so this indicates that the task has been created,
+ // but the task goroutine hasn't yet started. The Linux equivalent
+ // is struct task_struct::state == TASK_NEW
+ // (kernel/fork.c:copy_process() =>
+ // kernel/sched/core.c:sched_fork()), but the TASK_NEW bit is
+ // masked out by TASK_REPORT for /proc/[pid]/status, leaving only
+ // TASK_RUNNING.
+ return "R (running)"
+ }
+ case TaskGoroutineRunningSys, TaskGoroutineRunningApp:
+ return "R (running)"
+ case TaskGoroutineBlockedInterruptible:
+ return "S (sleeping)"
+ case TaskGoroutineStopped:
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ switch t.stop.(type) {
+ case *groupStop:
+ return "T (stopped)"
+ case *ptraceStop:
+ return "t (tracing stop)"
+ }
+ fallthrough
+ case TaskGoroutineBlockedUninterruptible:
+ // This is the name Linux uses for TASK_UNINTERRUPTIBLE and
+ // TASK_KILLABLE (= TASK_UNINTERRUPTIBLE | TASK_WAKEKILL):
+ // fs/proc/array.c:task_state_array.
+ return "D (disk sleep)"
+ default:
+ panic(fmt.Sprintf("Invalid TaskGoroutineState: %v", s))
+ }
+}
+
+// CPUMask returns a copy of t's allowed CPU mask.
+func (t *Task) CPUMask() sched.CPUSet {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.allowedCPUMask.Copy()
+}
+
+// SetCPUMask sets t's allowed CPU mask based on mask. It takes ownership of
+// mask.
+//
+// Preconditions: mask.Size() ==
+// sched.CPUSetSize(t.Kernel().ApplicationCores()).
+func (t *Task) SetCPUMask(mask sched.CPUSet) error {
+ if want := sched.CPUSetSize(t.k.applicationCores); mask.Size() != want {
+ panic(fmt.Sprintf("Invalid CPUSet %v (expected %d bytes)", mask, want))
+ }
+
+ // Remove CPUs in mask above Kernel.applicationCores.
+ mask.ClearAbove(t.k.applicationCores)
+
+ // Ensure that at least 1 CPU is still allowed.
+ if mask.NumCPUs() == 0 {
+ return syserror.EINVAL
+ }
+
+ if t.k.useHostCores {
+ // No-op; pretend the mask was immediately changed back.
+ return nil
+ }
+
+ t.tg.pidns.owner.mu.RLock()
+ rootTID := t.tg.pidns.owner.Root.tids[t]
+ t.tg.pidns.owner.mu.RUnlock()
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.allowedCPUMask = mask
+ atomic.StoreInt32(&t.cpu, assignCPU(mask, rootTID))
+ return nil
+}
+
+// CPU returns the cpu id for a given task.
+func (t *Task) CPU() int32 {
+ if t.k.useHostCores {
+ return int32(hostcpu.GetCPU())
+ }
+
+ return atomic.LoadInt32(&t.cpu)
+}
+
+// assignCPU returns the virtualized CPU number for the task with global TID
+// tid and allowedCPUMask allowed.
+func assignCPU(allowed sched.CPUSet, tid ThreadID) (cpu int32) {
+ // To pretend that threads are evenly distributed to allowed CPUs, choose n
+ // to be less than the number of CPUs in allowed ...
+ n := int(tid) % int(allowed.NumCPUs())
+ // ... then pick the nth CPU in allowed.
+ allowed.ForEachCPU(func(c uint) {
+ if n--; n == 0 {
+ cpu = int32(c)
+ }
+ })
+ return cpu
+}
+
+// Niceness returns t's niceness.
+func (t *Task) Niceness() int {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.niceness
+}
+
+// Priority returns t's priority.
+func (t *Task) Priority() int {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.niceness + 20
+}
+
+// SetNiceness sets t's niceness to n.
+func (t *Task) SetNiceness(n int) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.niceness = n
+}
+
+// NumaPolicy returns t's current numa policy.
+func (t *Task) NumaPolicy() (policy int32, nodeMask uint32) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.numaPolicy, t.numaNodeMask
+}
+
+// SetNumaPolicy sets t's numa policy.
+func (t *Task) SetNumaPolicy(policy int32, nodeMask uint32) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.numaPolicy = policy
+ t.numaNodeMask = nodeMask
+}
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
new file mode 100644
index 000000000..654cf7525
--- /dev/null
+++ b/pkg/sentry/kernel/task_signals.go
@@ -0,0 +1,1110 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// This file defines the behavior of task signal handling.
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/eventchannel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ ucspb "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// SignalAction is an internal signal action.
+type SignalAction int
+
+// Available signal actions.
+// Note that although we refer the complete set internally,
+// the application is only capable of using the Default and
+// Ignore actions from the system call interface.
+const (
+ SignalActionTerm SignalAction = iota
+ SignalActionCore
+ SignalActionStop
+ SignalActionIgnore
+ SignalActionHandler
+)
+
+// Default signal handler actions. Note that for most signals,
+// (except SIGKILL and SIGSTOP) these can be overridden by the app.
+var defaultActions = map[linux.Signal]SignalAction{
+ // POSIX.1-1990 standard.
+ linux.SIGHUP: SignalActionTerm,
+ linux.SIGINT: SignalActionTerm,
+ linux.SIGQUIT: SignalActionCore,
+ linux.SIGILL: SignalActionCore,
+ linux.SIGABRT: SignalActionCore,
+ linux.SIGFPE: SignalActionCore,
+ linux.SIGKILL: SignalActionTerm, // but see ThreadGroup.applySignalSideEffects
+ linux.SIGSEGV: SignalActionCore,
+ linux.SIGPIPE: SignalActionTerm,
+ linux.SIGALRM: SignalActionTerm,
+ linux.SIGTERM: SignalActionTerm,
+ linux.SIGUSR1: SignalActionTerm,
+ linux.SIGUSR2: SignalActionTerm,
+ linux.SIGCHLD: SignalActionIgnore,
+ linux.SIGCONT: SignalActionIgnore, // but see ThreadGroup.applySignalSideEffects
+ linux.SIGSTOP: SignalActionStop,
+ linux.SIGTSTP: SignalActionStop,
+ linux.SIGTTIN: SignalActionStop,
+ linux.SIGTTOU: SignalActionStop,
+ // POSIX.1-2001 standard.
+ linux.SIGBUS: SignalActionCore,
+ linux.SIGPROF: SignalActionTerm,
+ linux.SIGSYS: SignalActionCore,
+ linux.SIGTRAP: SignalActionCore,
+ linux.SIGURG: SignalActionIgnore,
+ linux.SIGVTALRM: SignalActionTerm,
+ linux.SIGXCPU: SignalActionCore,
+ linux.SIGXFSZ: SignalActionCore,
+ // The rest on linux.
+ linux.SIGSTKFLT: SignalActionTerm,
+ linux.SIGIO: SignalActionTerm,
+ linux.SIGPWR: SignalActionTerm,
+ linux.SIGWINCH: SignalActionIgnore,
+}
+
+// computeAction figures out what to do given a signal number
+// and an arch.SignalAct. SIGSTOP always results in a SignalActionStop,
+// and SIGKILL always results in a SignalActionTerm.
+// Signal 0 is always ignored as many programs use it for various internal functions
+// and don't expect it to do anything.
+//
+// In the event the signal is not one of these, act.Handler determines what
+// happens next.
+// If act.Handler is:
+// 0, the default action is taken;
+// 1, the signal is ignored;
+// anything else, the function returns SignalActionHandler.
+func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction {
+ switch sig {
+ case linux.SIGSTOP:
+ return SignalActionStop
+ case linux.SIGKILL:
+ return SignalActionTerm
+ case linux.Signal(0):
+ return SignalActionIgnore
+ }
+
+ switch act.Handler {
+ case arch.SignalActDefault:
+ return defaultActions[sig]
+ case arch.SignalActIgnore:
+ return SignalActionIgnore
+ default:
+ return SignalActionHandler
+ }
+}
+
+// UnblockableSignals contains the set of signals which cannot be blocked.
+var UnblockableSignals = linux.MakeSignalSet(linux.SIGKILL, linux.SIGSTOP)
+
+// StopSignals is the set of signals whose default action is SignalActionStop.
+var StopSignals = linux.MakeSignalSet(linux.SIGSTOP, linux.SIGTSTP, linux.SIGTTIN, linux.SIGTTOU)
+
+// dequeueSignalLocked returns a pending signal that is *not* included in mask.
+// If there are no pending unmasked signals, dequeueSignalLocked returns nil.
+//
+// Preconditions: t.tg.signalHandlers.mu must be locked.
+func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *arch.SignalInfo {
+ if info := t.pendingSignals.dequeue(mask); info != nil {
+ return info
+ }
+ return t.tg.pendingSignals.dequeue(mask)
+}
+
+// discardSpecificLocked removes all instances of the given signal from all
+// signal queues in tg.
+//
+// Preconditions: The signal mutex must be locked.
+func (tg *ThreadGroup) discardSpecificLocked(sig linux.Signal) {
+ tg.pendingSignals.discardSpecific(sig)
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ t.pendingSignals.discardSpecific(sig)
+ }
+}
+
+// PendingSignals returns the set of pending signals.
+func (t *Task) PendingSignals() linux.SignalSet {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.pendingSignals.pendingSet | t.tg.pendingSignals.pendingSet
+}
+
+// deliverSignal delivers the given signal and returns the following run state.
+func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunState {
+ sigact := computeAction(linux.Signal(info.Signo), act)
+
+ if t.haveSyscallReturn {
+ if sre, ok := SyscallRestartErrnoFromReturn(t.Arch().Return()); ok {
+ // Signals that are ignored, cause a thread group stop, or
+ // terminate the thread group do not interact with interrupted
+ // syscalls; in Linux terms, they are never returned to the signal
+ // handling path from get_signal => get_signal_to_deliver. The
+ // behavior of an interrupted syscall is determined by the first
+ // signal that is actually handled (by userspace).
+ if sigact == SignalActionHandler {
+ switch {
+ case sre == ERESTARTNOHAND:
+ fallthrough
+ case sre == ERESTART_RESTARTBLOCK:
+ fallthrough
+ case (sre == ERESTARTSYS && !act.IsRestart()):
+ t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
+ t.Arch().SetReturn(uintptr(-t.ExtractErrno(syserror.EINTR, -1)))
+ default:
+ t.Debugf("Restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
+ t.Arch().RestartSyscall()
+ }
+ }
+ }
+ }
+
+ switch sigact {
+ case SignalActionTerm, SignalActionCore:
+ // "Default action is to terminate the process." - signal(7)
+ t.Debugf("Signal %d: terminating thread group", info.Signo)
+
+ // Emit an event channel messages related to this uncaught signal.
+ ucs := &ucspb.UncaughtSignal{
+ Tid: int32(t.Kernel().TaskSet().Root.IDOfTask(t)),
+ Pid: int32(t.Kernel().TaskSet().Root.IDOfThreadGroup(t.ThreadGroup())),
+ Registers: t.Arch().StateData().Proto(),
+ SignalNumber: info.Signo,
+ }
+
+ // Attach an fault address if appropriate.
+ switch linux.Signal(info.Signo) {
+ case linux.SIGSEGV, linux.SIGFPE, linux.SIGILL, linux.SIGTRAP, linux.SIGBUS:
+ ucs.FaultAddr = info.Addr()
+ }
+
+ eventchannel.Emit(ucs)
+
+ t.PrepareGroupExit(ExitStatus{Signo: int(info.Signo)})
+ return (*runExit)(nil)
+
+ case SignalActionStop:
+ // "Default action is to stop the process."
+ t.initiateGroupStop(info)
+
+ case SignalActionIgnore:
+ // "Default action is to ignore the signal."
+ t.Debugf("Signal %d: ignored", info.Signo)
+
+ case SignalActionHandler:
+ // Try to deliver the signal to the user-configured handler.
+ t.Debugf("Signal %d: delivering to handler", info.Signo)
+ if err := t.deliverSignalToHandler(info, act); err != nil {
+ // This is not a warning, it can occur during normal operation.
+ t.Debugf("Failed to deliver signal %+v to user handler: %v", info, err)
+
+ // Send a forced SIGSEGV. If the signal that couldn't be delivered
+ // was a SIGSEGV, force the handler to SIG_DFL.
+ t.forceSignal(linux.SIGSEGV, linux.Signal(info.Signo) == linux.SIGSEGV /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ }
+
+ default:
+ panic(fmt.Sprintf("Unknown signal action %+v, %d?", info, computeAction(linux.Signal(info.Signo), act)))
+ }
+ return (*runInterrupt)(nil)
+}
+
+// deliverSignalToHandler changes the task's userspace state to enter the given
+// user-configured handler for the given signal.
+func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) error {
+ // Signal delivery to an application handler interrupts restartable
+ // sequences.
+ t.rseqInterrupt()
+
+ // Are executing on the main stack,
+ // or the provided alternate stack?
+ sp := usermem.Addr(t.Arch().Stack())
+
+ // N.B. This is a *copy* of the alternate stack that the user's signal
+ // handler expects to see in its ucontext (even if it's not in use).
+ alt := t.signalStack
+ if act.IsOnStack() && alt.IsEnabled() {
+ alt.SetOnStack()
+ if !alt.Contains(sp) {
+ sp = usermem.Addr(alt.Top())
+ }
+ }
+
+ // Set up the signal handler. If we have a saved signal mask, the signal
+ // handler should run with the current mask, but sigreturn should restore
+ // the saved one.
+ st := &arch.Stack{t.Arch(), t.MemoryManager(), sp}
+ mask := t.signalMask
+ if t.haveSavedSignalMask {
+ mask = t.savedSignalMask
+ }
+ if err := t.Arch().SignalSetup(st, &act, info, &alt, mask); err != nil {
+ return err
+ }
+ t.haveSavedSignalMask = false
+
+ // Add our signal mask.
+ newMask := t.signalMask | act.Mask
+ if !act.IsNoDefer() {
+ newMask |= linux.SignalSetOf(linux.Signal(info.Signo))
+ }
+ t.SetSignalMask(newMask)
+
+ return nil
+}
+
+var ctrlResume = &SyscallControl{ignoreReturn: true}
+
+// SignalReturn implements sigreturn(2) (if rt is false) or rt_sigreturn(2) (if
+// rt is true).
+func (t *Task) SignalReturn(rt bool) (*SyscallControl, error) {
+ st := t.Stack()
+ sigset, alt, err := t.Arch().SignalRestore(st, rt)
+ if err != nil {
+ return nil, err
+ }
+
+ // Attempt to record the given signal stack. Note that we silently
+ // ignore failures here, as does Linux. Only an EFAULT may be
+ // generated, but SignalRestore has already deserialized the entire
+ // frame successfully.
+ t.SetSignalStack(alt)
+
+ // Restore our signal mask. SIGKILL and SIGSTOP should not be blocked.
+ t.SetSignalMask(sigset &^ UnblockableSignals)
+
+ return ctrlResume, nil
+}
+
+// Sigtimedwait implements the semantics of sigtimedwait(2).
+//
+// Preconditions: The caller must be running on the task goroutine. t.exitState
+// < TaskExitZombie.
+func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.SignalInfo, error) {
+ // set is the set of signals we're interested in; invert it to get the set
+ // of signals to block.
+ mask := ^(set &^ UnblockableSignals)
+
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if info := t.dequeueSignalLocked(mask); info != nil {
+ return info, nil
+ }
+
+ if timeout == 0 {
+ return nil, syserror.EAGAIN
+ }
+
+ // Unblock signals we're waiting for. Remember the original signal mask so
+ // that Task.sendSignalTimerLocked doesn't discard ignored signals that
+ // we're temporarily unblocking.
+ t.realSignalMask = t.signalMask
+ t.setSignalMaskLocked(t.signalMask & mask)
+
+ // Wait for a timeout or new signal.
+ t.tg.signalHandlers.mu.Unlock()
+ _, err := t.BlockWithTimeout(nil, true, timeout)
+ t.tg.signalHandlers.mu.Lock()
+
+ // Restore the original signal mask.
+ t.setSignalMaskLocked(t.realSignalMask)
+ t.realSignalMask = 0
+
+ if info := t.dequeueSignalLocked(mask); info != nil {
+ return info, nil
+ }
+ if err == syserror.ETIMEDOUT {
+ return nil, syserror.EAGAIN
+ }
+ return nil, err
+}
+
+// SendSignal sends the given signal to t.
+//
+// The following errors may be returned:
+//
+// syserror.ESRCH - The task has exited.
+// syserror.EINVAL - The signal is not valid.
+// syserror.EAGAIN - THe signal is realtime, and cannot be queued.
+//
+func (t *Task) SendSignal(info *arch.SignalInfo) error {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.sendSignalLocked(info, false /* group */)
+}
+
+// SendGroupSignal sends the given signal to t's thread group.
+func (t *Task) SendGroupSignal(info *arch.SignalInfo) error {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ return t.sendSignalLocked(info, true /* group */)
+}
+
+// SendSignal sends the given signal to tg, using tg's leader to determine if
+// the signal is blocked.
+func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ return tg.leader.sendSignalLocked(info, true /* group */)
+}
+
+func (t *Task) sendSignalLocked(info *arch.SignalInfo, group bool) error {
+ return t.sendSignalTimerLocked(info, group, nil)
+}
+
+func (t *Task) sendSignalTimerLocked(info *arch.SignalInfo, group bool, timer *IntervalTimer) error {
+ if t.exitState == TaskExitDead {
+ return syserror.ESRCH
+ }
+ sig := linux.Signal(info.Signo)
+ if sig == 0 {
+ return nil
+ }
+ if !sig.IsValid() {
+ return syserror.EINVAL
+ }
+
+ // Signal side effects apply even if the signal is ultimately discarded.
+ t.tg.applySignalSideEffectsLocked(sig)
+
+ // TODO: "Only signals for which the "init" process has established a
+ // signal handler can be sent to the "init" process by other members of the
+ // PID namespace. This restriction applies even to privileged processes,
+ // and prevents other members of the PID namespace from accidentally
+ // killing the "init" process." - pid_namespaces(7). We don't currently do
+ // this for child namespaces, though we should; we also don't do this for
+ // the root namespace (the same restriction applies to global init on
+ // Linux), where whether or not we should is much murkier. In practice,
+ // most sandboxed applications are not prepared to function as an init
+ // process.
+
+ // Unmasked, ignored signals are discarded without being queued, unless
+ // they will be visible to a tracer. Even for group signals, it's the
+ // originally-targeted task's signal mask and tracer that matter; compare
+ // Linux's kernel/signal.c:__send_signal() => prepare_signal() =>
+ // sig_ignored().
+ ignored := computeAction(sig, t.tg.signalHandlers.actions[sig]) == SignalActionIgnore
+ if sigset := linux.SignalSetOf(sig); sigset&t.signalMask == 0 && sigset&t.realSignalMask == 0 && ignored && !t.hasTracer() {
+ t.Debugf("Discarding ignored signal %d", sig)
+ if timer != nil {
+ timer.signalRejectedLocked()
+ }
+ return nil
+ }
+
+ q := &t.pendingSignals
+ if group {
+ q = &t.tg.pendingSignals
+ }
+ if !q.enqueue(info, timer) {
+ if sig.IsRealtime() {
+ return syserror.EAGAIN
+ }
+ t.Debugf("Discarding duplicate signal %d", sig)
+ if timer != nil {
+ timer.signalRejectedLocked()
+ }
+ return nil
+ }
+
+ // Find a receiver to notify. Note that the task we choose to notify, if
+ // any, may not be the task that actually dequeues and handles the signal;
+ // e.g. a racing signal mask change may cause the notified task to become
+ // ineligible, or a racing sibling task may dequeue the signal first.
+ if t.canReceiveSignalLocked(sig) {
+ t.Debugf("Notified of signal %d", sig)
+ t.interrupt()
+ return nil
+ }
+ if group {
+ if nt := t.tg.findSignalReceiverLocked(sig); nt != nil {
+ nt.Debugf("Notified of group signal %d", sig)
+ nt.interrupt()
+ return nil
+ }
+ }
+ t.Debugf("No task notified of signal %d", sig)
+ return nil
+}
+
+func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) {
+ switch {
+ case linux.SignalSetOf(sig)&StopSignals != 0:
+ // Stop signals cause all prior SIGCONT to be discarded. (This is
+ // despite the fact this has little effect since SIGCONT's most
+ // important effect is applied when the signal is sent in the branch
+ // below, not when the signal is delivered.)
+ tg.discardSpecificLocked(linux.SIGCONT)
+ case sig == linux.SIGCONT:
+ // "The SIGCONT signal has a side effect of waking up (all threads of)
+ // a group-stopped process. This side effect happens before
+ // signal-delivery-stop. The tracer can't suppress this side effect (it
+ // can only suppress signal injection, which only causes the SIGCONT
+ // handler to not be executed in the tracee, if such a handler is
+ // installed." - ptrace(2)
+ tg.endGroupStopLocked(true)
+ case sig == linux.SIGKILL:
+ // "SIGKILL does not generate signal-delivery-stop and therefore the
+ // tracer can't suppress it. SIGKILL kills even within system calls
+ // (syscall-exit-stop is not generated prior to death by SIGKILL)." -
+ // ptrace(2)
+ //
+ // Note that this differs from ThreadGroup.requestExit in that it
+ // ignores tg.execing.
+ if !tg.exiting {
+ tg.exiting = true
+ tg.exitStatus = ExitStatus{Signo: int(linux.SIGKILL)}
+ }
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ t.killLocked()
+ }
+ }
+}
+
+// canReceiveSignalLocked returns true if t should be interrupted to receive
+// the given signal. canReceiveSignalLocked is analogous to Linux's
+// kernel/signal.c:wants_signal(), but see below for divergences.
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool {
+ // - Do not choose tasks that are blocking the signal.
+ if linux.SignalSetOf(sig)&t.signalMask != 0 {
+ return false
+ }
+ // - No need to check Task.exitState, as the exit path sets every bit in the
+ // signal mask when it transitions from TaskExitNone to TaskExitInitiated.
+ // - No special case for SIGKILL: SIGKILL already interrupted all tasks in the
+ // task group via applySignalSideEffects => killLocked.
+ // - Do not choose stopped tasks, which cannot handle signals.
+ if t.stop != nil {
+ return false
+ }
+ // - TODO(b/38173783): No special case for when t is also the sending task,
+ // because the identity of the sender is unknown.
+ // - Do not choose tasks that have already been interrupted, as they may be
+ // busy handling another signal.
+ if len(t.interruptChan) != 0 {
+ return false
+ }
+ return true
+}
+
+// findSignalReceiverLocked returns a task in tg that should be interrupted to
+// receive the given signal. If no such task exists, findSignalReceiverLocked
+// returns nil.
+//
+// Linux actually records curr_target to balance the group signal targets.
+//
+// Preconditions: The signal mutex must be locked.
+func (tg *ThreadGroup) findSignalReceiverLocked(sig linux.Signal) *Task {
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ if t.canReceiveSignalLocked(sig) {
+ return t
+ }
+ }
+ return nil
+}
+
+// forceSignal ensures that the task is not ignoring or blocking the given
+// signal. If unconditional is true, forceSignal takes action even if the
+// signal isn't being ignored or blocked.
+func (t *Task) forceSignal(sig linux.Signal, unconditional bool) {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.forceSignalLocked(sig, unconditional)
+}
+
+func (t *Task) forceSignalLocked(sig linux.Signal, unconditional bool) {
+ blocked := linux.SignalSetOf(sig)&t.signalMask != 0
+ act := t.tg.signalHandlers.actions[sig]
+ ignored := act.Handler == arch.SignalActIgnore
+ if blocked || ignored || unconditional {
+ act.Handler = arch.SignalActDefault
+ t.tg.signalHandlers.actions[sig] = act
+ if blocked {
+ t.setSignalMaskLocked(t.signalMask &^ linux.SignalSetOf(sig))
+ }
+ }
+}
+
+// SignalMask returns a copy of t's signal mask.
+func (t *Task) SignalMask() linux.SignalSet {
+ return linux.SignalSet(atomic.LoadUint64((*uint64)(&t.signalMask)))
+}
+
+// SetSignalMask sets t's signal mask.
+//
+// Preconditions: SetSignalMask can only be called by the task goroutine.
+// t.exitState < TaskExitZombie.
+func (t *Task) SetSignalMask(mask linux.SignalSet) {
+ // By precondition, t prevents t.tg from completing an execve and mutating
+ // t.tg.signalHandlers, so we can skip the TaskSet mutex.
+ t.tg.signalHandlers.mu.Lock()
+ t.setSignalMaskLocked(mask)
+ t.tg.signalHandlers.mu.Unlock()
+}
+
+// Preconditions: The signal mutex must be locked.
+func (t *Task) setSignalMaskLocked(mask linux.SignalSet) {
+ oldMask := t.signalMask
+ atomic.StoreUint64((*uint64)(&t.signalMask), uint64(mask))
+
+ // If the new mask blocks any signals that were not blocked by the old
+ // mask, and at least one such signal is pending in tg.pendingSignals, and
+ // t has been woken, it could be the case that t was woken to handle that
+ // signal, but will no longer do so as a result of its new signal mask, so
+ // we have to pick a replacement.
+ blocked := mask &^ oldMask
+ blockedGroupPending := blocked & t.tg.pendingSignals.pendingSet
+ if blockedGroupPending != 0 && t.interrupted() {
+ linux.ForEachSignal(blockedGroupPending, func(sig linux.Signal) {
+ if nt := t.tg.findSignalReceiverLocked(sig); nt != nil {
+ nt.interrupt()
+ return
+ }
+ })
+ // We have to re-issue the interrupt consumed by t.interrupted() since
+ // it might have been for a different reason.
+ t.interruptSelf()
+ }
+
+ // Conversely, if the new mask unblocks any signals that were blocked by
+ // the old mask, and at least one such signal is pending, we may now need
+ // to handle that signal.
+ unblocked := oldMask &^ mask
+ unblockedPending := unblocked & (t.pendingSignals.pendingSet | t.tg.pendingSignals.pendingSet)
+ if unblockedPending != 0 {
+ t.interruptSelf()
+ }
+}
+
+// SetSavedSignalMask sets the saved signal mask (see Task.savedSignalMask's
+// comment).
+//
+// Preconditions: SetSavedSignalMask can only be called by the task goroutine.
+func (t *Task) SetSavedSignalMask(mask linux.SignalSet) {
+ t.savedSignalMask = mask
+ t.haveSavedSignalMask = true
+}
+
+// SignalStack returns the task-private signal stack.
+func (t *Task) SignalStack() arch.SignalStack {
+ alt := t.signalStack
+ if t.onSignalStack(alt) {
+ alt.Flags |= arch.SignalStackFlagOnStack
+ }
+ return alt
+}
+
+// onSignalStack returns true if the task is executing on the given signal stack.
+func (t *Task) onSignalStack(alt arch.SignalStack) bool {
+ sp := usermem.Addr(t.Arch().Stack())
+ return alt.Contains(sp)
+}
+
+// SetSignalStack sets the task-private signal stack.
+//
+// This value may not be changed if the task is currently executing on the
+// signal stack, i.e. if t.onSignalStack returns true. In this case, this
+// function will return false. Otherwise, true is returned.
+func (t *Task) SetSignalStack(alt arch.SignalStack) bool {
+ // Check that we're not executing on the stack.
+ if t.onSignalStack(t.signalStack) {
+ return false
+ }
+
+ if alt.Flags&arch.SignalStackFlagDisable != 0 {
+ // Don't record anything beyond the flags.
+ t.signalStack = arch.SignalStack{
+ Flags: arch.SignalStackFlagDisable,
+ }
+ } else {
+ // Mask out irrelevant parts: only disable matters.
+ alt.Flags &= arch.SignalStackFlagDisable
+ t.signalStack = alt
+ }
+ return true
+}
+
+// SetSignalAct atomically sets the thread group's signal action for signal sig
+// to *actptr (if actptr is not nil) and returns the old signal action.
+func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (arch.SignalAct, error) {
+ if !sig.IsValid() {
+ return arch.SignalAct{}, syserror.EINVAL
+ }
+
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ sh := tg.signalHandlers
+ sh.mu.Lock()
+ defer sh.mu.Unlock()
+ oldact := sh.actions[sig]
+ if actptr != nil {
+ if sig == linux.SIGKILL || sig == linux.SIGSTOP {
+ return oldact, syserror.EINVAL
+ }
+
+ act := *actptr
+ act.Mask &^= UnblockableSignals
+ sh.actions[sig] = act
+ // From POSIX, by way of Linux:
+ //
+ // "Setting a signal action to SIG_IGN for a signal that is pending
+ // shall cause the pending signal to be discarded, whether or not it is
+ // blocked."
+ //
+ // "Setting a signal action to SIG_DFL for a signal that is pending and
+ // whose default action is to ignore the signal (for example, SIGCHLD),
+ // shall cause the pending signal to be discarded, whether or not it is
+ // blocked."
+ if computeAction(sig, act) == SignalActionIgnore {
+ tg.discardSpecificLocked(sig)
+ }
+ }
+ return oldact, nil
+}
+
+// CopyOutSignalAct converts the given SignalAct into an architecture-specific
+// type and then copies it out to task memory.
+func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error {
+ n := t.Arch().NewSignalAct()
+ n.SerializeFrom(s)
+ _, err := t.CopyOut(addr, n)
+ return err
+}
+
+// CopyInSignalAct copies an architecture-specific sigaction type from task
+// memory and then converts it into a SignalAct.
+func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) {
+ n := t.Arch().NewSignalAct()
+ var s arch.SignalAct
+ if _, err := t.CopyIn(addr, n); err != nil {
+ return s, err
+ }
+ n.DeserializeTo(&s)
+ return s, nil
+}
+
+// CopyOutSignalStack converts the given SignalStack into an
+// architecture-specific type and then copies it out to task memory.
+func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error {
+ n := t.Arch().NewSignalStack()
+ n.SerializeFrom(s)
+ _, err := t.CopyOut(addr, n)
+ return err
+}
+
+// CopyInSignalStack copies an architecture-specific stack_t from task memory
+// and then converts it into a SignalStack.
+func (t *Task) CopyInSignalStack(addr usermem.Addr) (arch.SignalStack, error) {
+ n := t.Arch().NewSignalStack()
+ var s arch.SignalStack
+ if _, err := t.CopyIn(addr, n); err != nil {
+ return s, err
+ }
+ n.DeserializeTo(&s)
+ return s, nil
+}
+
+// groupStop is a TaskStop placed on tasks that have received a stop signal
+// (SIGSTOP, SIGTSTP, SIGTTIN, SIGTTOU). (The term "group-stop" originates from
+// the ptrace man page.)
+//
+// +stateify savable
+type groupStop struct{}
+
+// Killable implements TaskStop.Killable.
+func (*groupStop) Killable() bool { return true }
+
+// initiateGroupStop attempts to initiate a group stop based on a
+// previously-dequeued stop signal.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) initiateGroupStop(info *arch.SignalInfo) {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ if t.groupStopPending {
+ t.Debugf("Signal %d: not stopping thread group: lost to racing stop signal", info.Signo)
+ return
+ }
+ if !t.tg.groupStopDequeued {
+ t.Debugf("Signal %d: not stopping thread group: lost to racing SIGCONT", info.Signo)
+ return
+ }
+ if t.tg.exiting {
+ t.Debugf("Signal %d: not stopping thread group: lost to racing group exit", info.Signo)
+ return
+ }
+ if t.tg.execing != nil {
+ t.Debugf("Signal %d: not stopping thread group: lost to racing execve", info.Signo)
+ return
+ }
+ if !t.tg.groupStopComplete {
+ t.tg.groupStopSignal = linux.Signal(info.Signo)
+ }
+ t.tg.groupStopPendingCount = 0
+ for t2 := t.tg.tasks.Front(); t2 != nil; t2 = t2.Next() {
+ if t2.killedLocked() || t2.exitState >= TaskExitInitiated {
+ t2.groupStopPending = false
+ continue
+ }
+ t2.groupStopPending = true
+ t2.groupStopAcknowledged = false
+ if t2.ptraceSeized {
+ t2.trapNotifyPending = true
+ if s, ok := t2.stop.(*ptraceStop); ok && s.listen {
+ t2.endInternalStopLocked()
+ }
+ }
+ t2.interrupt()
+ t.tg.groupStopPendingCount++
+ }
+ t.Debugf("Signal %d: stopping %d threads in thread group", info.Signo, t.tg.groupStopPendingCount)
+}
+
+// endGroupStopLocked ensures that all prior stop signals received by tg are
+// not stopping tg and will not stop tg in the future. If broadcast is true,
+// parent and tracer notification will be scheduled if appropriate.
+//
+// Preconditions: The signal mutex must be locked.
+func (tg *ThreadGroup) endGroupStopLocked(broadcast bool) {
+ // Discard all previously-queued stop signals.
+ linux.ForEachSignal(StopSignals, tg.discardSpecificLocked)
+
+ if tg.groupStopPendingCount == 0 && !tg.groupStopComplete {
+ return
+ }
+
+ completeStr := "incomplete"
+ if tg.groupStopComplete {
+ completeStr = "complete"
+ }
+ tg.leader.Debugf("Ending %s group stop with %d threads pending", completeStr, tg.groupStopPendingCount)
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ t.groupStopPending = false
+ if t.ptraceSeized {
+ t.trapNotifyPending = true
+ if s, ok := t.stop.(*ptraceStop); ok && s.listen {
+ t.endInternalStopLocked()
+ }
+ } else {
+ if _, ok := t.stop.(*groupStop); ok {
+ t.endInternalStopLocked()
+ }
+ }
+ }
+ if broadcast {
+ // Instead of notifying the parent here, set groupContNotify so that
+ // one of the continuing tasks does so. (Linux does something similar.)
+ // The reason we do this is to keep locking sane. In order to send a
+ // signal to the parent, we need to lock its signal mutex, but we're
+ // already holding tg's signal mutex, and the TaskSet mutex must be
+ // locked for writing for us to hold two signal mutexes. Since we don't
+ // want to require this for endGroupStopLocked (which is called from
+ // signal-sending paths), nor do we want to lose atomicity by releasing
+ // the mutexes we're already holding, just let the continuing thread
+ // group deal with it.
+ tg.groupContNotify = true
+ tg.groupContInterrupted = !tg.groupStopComplete
+ tg.groupContWaitable = true
+ }
+ // Unsetting groupStopDequeued will cause racing calls to initiateGroupStop
+ // to recognize that the group stop has been cancelled.
+ tg.groupStopDequeued = false
+ tg.groupStopSignal = 0
+ tg.groupStopPendingCount = 0
+ tg.groupStopComplete = false
+ tg.groupStopWaitable = false
+}
+
+// participateGroupStopLocked is called to handle thread group side effects
+// after t unsets t.groupStopPending. The caller must handle task side effects
+// (e.g. placing the task goroutine into the group stop). It returns true if
+// the caller must notify t.tg.leader's parent of a completed group stop (which
+// participateGroupStopLocked cannot do due to holding the wrong locks).
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) participateGroupStopLocked() bool {
+ if t.groupStopAcknowledged {
+ return false
+ }
+ t.groupStopAcknowledged = true
+ t.tg.groupStopPendingCount--
+ if t.tg.groupStopPendingCount != 0 {
+ return false
+ }
+ if t.tg.groupStopComplete {
+ return false
+ }
+ t.Debugf("Completing group stop")
+ t.tg.groupStopComplete = true
+ t.tg.groupStopWaitable = true
+ t.tg.groupContNotify = false
+ t.tg.groupContWaitable = false
+ return true
+}
+
+// signalStop sends a signal to t's thread group of a new group stop, group
+// continue, or ptrace stop, if appropriate. code and status are set in the
+// signal sent to tg, if any.
+//
+// Preconditions: The TaskSet mutex must be locked (for reading or writing).
+func (t *Task) signalStop(target *Task, code int32, status int32) {
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ act, ok := t.tg.signalHandlers.actions[linux.SIGCHLD]
+ if !ok || (act.Handler != arch.SignalActIgnore && act.Flags&arch.SignalFlagNoCldStop == 0) {
+ sigchld := &arch.SignalInfo{
+ Signo: int32(linux.SIGCHLD),
+ Code: code,
+ }
+ sigchld.SetPid(int32(t.tg.pidns.tids[target]))
+ sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ sigchld.SetStatus(status)
+ // TODO(b/72102453): Set utime, stime.
+ t.sendSignalLocked(sigchld, true /* group */)
+ }
+}
+
+// The runInterrupt state handles conditions indicated by interrupts.
+//
+// +stateify savable
+type runInterrupt struct{}
+
+func (*runInterrupt) execute(t *Task) taskRunState {
+ // Interrupts are de-duplicated (if t is interrupted twice before
+ // t.interrupted() is called, t.interrupted() will only return true once),
+ // so early exits from this function must re-enter the runInterrupt state
+ // to check for more interrupt-signaled conditions.
+
+ t.tg.signalHandlers.mu.Lock()
+
+ // Did we just leave a group stop?
+ if t.tg.groupContNotify {
+ t.tg.groupContNotify = false
+ sig := t.tg.groupStopSignal
+ intr := t.tg.groupContInterrupted
+ t.tg.signalHandlers.mu.Unlock()
+ t.tg.pidns.owner.mu.RLock()
+ // For consistency with Linux, if the parent and (thread group
+ // leader's) tracer are in the same thread group, deduplicate
+ // notifications.
+ notifyParent := t.tg.leader.parent != nil
+ if tracer := t.tg.leader.Tracer(); tracer != nil {
+ if notifyParent && tracer.tg == t.tg.leader.parent.tg {
+ notifyParent = false
+ }
+ // Sending CLD_STOPPED to the tracer doesn't really make any sense;
+ // the thread group leader may have already entered the stop and
+ // notified its tracer accordingly. But it's consistent with
+ // Linux...
+ if intr {
+ tracer.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ if !notifyParent {
+ tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop | EventChildGroupStop)
+ } else {
+ tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop)
+ }
+ } else {
+ tracer.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig))
+ tracer.tg.eventQueue.Notify(EventGroupContinue)
+ }
+ }
+ if notifyParent {
+ // If groupContInterrupted, do as Linux does and pretend the group
+ // stop completed just before it ended. The theoretical behavior in
+ // this case would be to send a SIGCHLD indicating the completed
+ // stop, followed by a SIGCHLD indicating the continue. However,
+ // SIGCHLD is a standard signal, so the latter would always be
+ // dropped. Hence sending only the former is equivalent.
+ if intr {
+ t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue | EventChildGroupStop)
+ } else {
+ t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig))
+ t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue)
+ }
+ }
+ t.tg.pidns.owner.mu.RUnlock()
+ return (*runInterrupt)(nil)
+ }
+
+ // Do we need to enter a group stop or related ptrace stop? This path is
+ // analogous to Linux's kernel/signal.c:get_signal() => do_signal_stop()
+ // (with ptrace enabled) and do_jobctl_trap().
+ if t.groupStopPending || t.trapStopPending || t.trapNotifyPending {
+ sig := t.tg.groupStopSignal
+ notifyParent := false
+ if t.groupStopPending {
+ t.groupStopPending = false
+ // We care about t.tg.groupStopSignal (for tracer notification)
+ // even if this doesn't complete a group stop, so keep the
+ // value of sig we've already read.
+ notifyParent = t.participateGroupStopLocked()
+ }
+ t.trapStopPending = false
+ t.trapNotifyPending = false
+ // Drop the signal mutex so we can take the TaskSet mutex.
+ t.tg.signalHandlers.mu.Unlock()
+
+ t.tg.pidns.owner.mu.RLock()
+ if t.tg.leader.parent == nil {
+ notifyParent = false
+ }
+ if tracer := t.Tracer(); tracer != nil {
+ if t.ptraceSeized {
+ if sig == 0 {
+ sig = linux.SIGTRAP
+ }
+ // "If tracee was attached using PTRACE_SEIZE, group-stop is
+ // indicated by PTRACE_EVENT_STOP: status>>16 ==
+ // PTRACE_EVENT_STOP. This allows detection of group-stops
+ // without requiring an extra PTRACE_GETSIGINFO call." -
+ // "Group-stop", ptrace(2)
+ t.ptraceCode = int32(sig) | linux.PTRACE_EVENT_STOP<<8
+ t.ptraceSiginfo = &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: t.ptraceCode,
+ }
+ t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ } else {
+ t.ptraceCode = int32(sig)
+ t.ptraceSiginfo = nil
+ }
+ if t.beginPtraceStopLocked() {
+ tracer.signalStop(t, arch.CLD_STOPPED, int32(sig))
+ // For consistency with Linux, if the parent and tracer are in the
+ // same thread group, deduplicate notification signals.
+ if notifyParent && tracer.tg == t.tg.leader.parent.tg {
+ notifyParent = false
+ tracer.tg.eventQueue.Notify(EventChildGroupStop | EventTraceeStop)
+ } else {
+ tracer.tg.eventQueue.Notify(EventTraceeStop)
+ }
+ }
+ } else {
+ t.tg.signalHandlers.mu.Lock()
+ if !t.killedLocked() {
+ t.beginInternalStopLocked((*groupStop)(nil))
+ }
+ t.tg.signalHandlers.mu.Unlock()
+ }
+ if notifyParent {
+ t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop)
+ }
+ t.tg.pidns.owner.mu.RUnlock()
+
+ return (*runInterrupt)(nil)
+ }
+
+ // Are there signals pending?
+ if info := t.dequeueSignalLocked(t.signalMask); info != nil {
+ if linux.SignalSetOf(linux.Signal(info.Signo))&StopSignals != 0 {
+ // Indicate that we've dequeued a stop signal before unlocking the
+ // signal mutex; initiateGroupStop will check for races with
+ // endGroupStopLocked after relocking it.
+ t.tg.groupStopDequeued = true
+ }
+ if t.ptraceSignalLocked(info) {
+ // Dequeueing the signal action must wait until after the
+ // signal-delivery-stop ends since the tracer can change or
+ // suppress the signal.
+ t.tg.signalHandlers.mu.Unlock()
+ return (*runInterruptAfterSignalDeliveryStop)(nil)
+ }
+ act := t.tg.signalHandlers.dequeueAction(linux.Signal(info.Signo))
+ t.tg.signalHandlers.mu.Unlock()
+ return t.deliverSignal(info, act)
+ }
+
+ t.tg.signalHandlers.mu.Unlock()
+ return (*runApp)(nil)
+}
+
+// +stateify savable
+type runInterruptAfterSignalDeliveryStop struct{}
+
+func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
+ t.tg.pidns.owner.mu.Lock()
+ // Can't defer unlock: deliverSignal must be called without holding TaskSet
+ // mutex.
+ sig := linux.Signal(t.ptraceCode)
+ defer func() {
+ t.ptraceSiginfo = nil
+ }()
+ if !sig.IsValid() {
+ t.tg.pidns.owner.mu.Unlock()
+ return (*runInterrupt)(nil)
+ }
+ info := t.ptraceSiginfo
+ if sig != linux.Signal(info.Signo) {
+ info.Signo = int32(sig)
+ info.Errno = 0
+ info.Code = arch.SignalInfoUser
+ // pid isn't a valid field for all signal numbers, but Linux
+ // doesn't care (kernel/signal.c:ptrace_signal()).
+ //
+ // Linux uses t->parent for the tid and uid here, which is the tracer
+ // if it hasn't detached or the real parent otherwise.
+ parent := t.parent
+ if tracer := t.Tracer(); tracer != nil {
+ parent = tracer
+ }
+ if parent == nil {
+ // Tracer has detached and t was created by Kernel.CreateProcess().
+ // Pretend the parent is in an ancestor PID + user namespace.
+ info.SetPid(0)
+ info.SetUid(int32(auth.OverflowUID))
+ } else {
+ info.SetPid(int32(t.tg.pidns.tids[parent]))
+ info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ }
+ }
+ t.tg.signalHandlers.mu.Lock()
+ t.tg.pidns.owner.mu.Unlock()
+ // If the signal is masked, re-queue it.
+ if linux.SignalSetOf(sig)&t.signalMask != 0 {
+ t.sendSignalLocked(info, false /* group */)
+ t.tg.signalHandlers.mu.Unlock()
+ return (*runInterrupt)(nil)
+ }
+ act := t.tg.signalHandlers.dequeueAction(linux.Signal(info.Signo))
+ t.tg.signalHandlers.mu.Unlock()
+ return t.deliverSignal(info, act)
+}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
new file mode 100644
index 000000000..b42531e57
--- /dev/null
+++ b/pkg/sentry/kernel/task_start.go
@@ -0,0 +1,287 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/sched"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// TaskConfig defines the configuration of a new Task (see below).
+type TaskConfig struct {
+ // Kernel is the owning Kernel.
+ Kernel *Kernel
+
+ // Parent is the new task's parent. Parent may be nil.
+ Parent *Task
+
+ // If InheritParent is not nil, use InheritParent's parent as the new
+ // task's parent.
+ InheritParent *Task
+
+ // ThreadGroup is the ThreadGroup the new task belongs to.
+ ThreadGroup *ThreadGroup
+
+ // SignalMask is the new task's initial signal mask.
+ SignalMask linux.SignalSet
+
+ // TaskContext is the TaskContext of the new task. Ownership of the
+ // TaskContext is transferred to TaskSet.NewTask, whether or not it
+ // succeeds.
+ TaskContext *TaskContext
+
+ // FSContext is the FSContext of the new task. A reference must be held on
+ // FSContext, which is transferred to TaskSet.NewTask whether or not it
+ // succeeds.
+ FSContext *FSContext
+
+ // FDMap is the FDMap of the new task. A reference must be held on FDMap,
+ // which is transferred to TaskSet.NewTask whether or not it succeeds.
+ FDMap *FDMap
+
+ // Credentials is the Credentials of the new task.
+ Credentials *auth.Credentials
+
+ // Niceness is the niceness of the new task.
+ Niceness int
+
+ // If NetworkNamespaced is true, the new task should observe a non-root
+ // network namespace.
+ NetworkNamespaced bool
+
+ // AllowedCPUMask contains the cpus that this task can run on.
+ AllowedCPUMask sched.CPUSet
+
+ // UTSNamespace is the UTSNamespace of the new task.
+ UTSNamespace *UTSNamespace
+
+ // IPCNamespace is the IPCNamespace of the new task.
+ IPCNamespace *IPCNamespace
+
+ // AbstractSocketNamespace is the AbstractSocketNamespace of the new task.
+ AbstractSocketNamespace *AbstractSocketNamespace
+
+ // ContainerID is the container the new task belongs to.
+ ContainerID string
+}
+
+// NewTask creates a new task defined by cfg.
+//
+// NewTask does not start the returned task; the caller must call Task.Start.
+func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) {
+ t, err := ts.newTask(cfg)
+ if err != nil {
+ cfg.TaskContext.release()
+ cfg.FSContext.DecRef()
+ cfg.FDMap.DecRef()
+ return nil, err
+ }
+ return t, nil
+}
+
+// newTask is a helper for TaskSet.NewTask that only takes ownership of parts
+// of cfg if it succeeds.
+func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
+ tg := cfg.ThreadGroup
+ tc := cfg.TaskContext
+ t := &Task{
+ taskNode: taskNode{
+ tg: tg,
+ parent: cfg.Parent,
+ children: make(map[*Task]struct{}),
+ },
+ runState: (*runApp)(nil),
+ interruptChan: make(chan struct{}, 1),
+ signalMask: cfg.SignalMask,
+ signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
+ tc: *tc,
+ fsc: cfg.FSContext,
+ fds: cfg.FDMap,
+ p: cfg.Kernel.Platform.NewContext(),
+ k: cfg.Kernel,
+ ptraceTracees: make(map[*Task]struct{}),
+ allowedCPUMask: cfg.AllowedCPUMask.Copy(),
+ ioUsage: &usage.IO{},
+ creds: cfg.Credentials,
+ niceness: cfg.Niceness,
+ netns: cfg.NetworkNamespaced,
+ utsns: cfg.UTSNamespace,
+ ipcns: cfg.IPCNamespace,
+ abstractSockets: cfg.AbstractSocketNamespace,
+ rseqCPU: -1,
+ futexWaiter: futex.NewWaiter(),
+ containerID: cfg.ContainerID,
+ }
+ t.endStopCond.L = &t.tg.signalHandlers.mu
+ t.ptraceTracer.Store((*Task)(nil))
+ // We don't construct t.blockingTimer until Task.run(); see that function
+ // for justification.
+
+ // Make the new task (and possibly thread group) visible to the rest of
+ // the system atomically.
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ if tg.exiting || tg.execing != nil {
+ // If the caller is in the same thread group, then what we return
+ // doesn't matter too much since the caller will exit before it returns
+ // to userspace. If the caller isn't in the same thread group, then
+ // we're in uncharted territory and can return whatever we want.
+ return nil, syserror.EINTR
+ }
+ if err := ts.assignTIDsLocked(t); err != nil {
+ return nil, err
+ }
+ // Below this point, newTask is expected not to fail (there is no rollback
+ // of assignTIDsLocked or any of the following).
+
+ // Logging on t's behalf will panic if t.logPrefix hasn't been initialized.
+ // This is the earliest point at which we can do so (since t now has thread
+ // IDs).
+ t.updateLogPrefixLocked()
+
+ if cfg.InheritParent != nil {
+ t.parent = cfg.InheritParent.parent
+ }
+ if t.parent != nil {
+ t.parent.children[t] = struct{}{}
+ }
+
+ if tg.leader == nil {
+ // New thread group.
+ tg.leader = t
+ if parentPG := tg.parentPG(); parentPG == nil {
+ tg.createSession()
+ } else {
+ // Inherit the process group.
+ parentPG.incRefWithParent(parentPG)
+ tg.processGroup = parentPG
+ }
+ }
+ tg.tasks.PushBack(t)
+ tg.tasksCount++
+ tg.liveTasks++
+ tg.activeTasks++
+
+ // Propagate external TaskSet stops to the new task.
+ t.stopCount = ts.stopCount
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ t.cpu = assignCPU(t.allowedCPUMask, ts.Root.tids[t])
+
+ t.startTime = t.k.RealtimeClock().Now()
+
+ return t, nil
+}
+
+// assignTIDsLocked ensures that new task t is visible in all PID namespaces in
+// which it should be visible.
+//
+// Preconditions: ts.mu must be locked for writing.
+func (ts *TaskSet) assignTIDsLocked(t *Task) error {
+ type allocatedTID struct {
+ ns *PIDNamespace
+ tid ThreadID
+ }
+ var allocatedTIDs []allocatedTID
+ for ns := t.tg.pidns; ns != nil; ns = ns.parent {
+ tid, err := ns.allocateTID()
+ if err != nil {
+ // Failure. Remove the tids we already allocated in descendant
+ // namespaces.
+ for _, a := range allocatedTIDs {
+ delete(a.ns.tasks, a.tid)
+ delete(a.ns.tids, t)
+ if t.tg.leader == nil {
+ delete(a.ns.tgids, t.tg)
+ }
+ }
+ return err
+ }
+ ns.tasks[tid] = t
+ ns.tids[t] = tid
+ if t.tg.leader == nil {
+ // New thread group.
+ ns.tgids[t.tg] = tid
+ }
+ allocatedTIDs = append(allocatedTIDs, allocatedTID{ns, tid})
+ }
+ return nil
+}
+
+// allocateTID returns an unused ThreadID from ns.
+//
+// Preconditions: ns.owner.mu must be locked for writing.
+func (ns *PIDNamespace) allocateTID() (ThreadID, error) {
+ if ns.exiting {
+ // "In this case, a subsequent fork(2) into this PID namespace will
+ // fail with the error ENOMEM; it is not possible to create a new
+ // processes [sic] in a PID namespace whose init process has
+ // terminated." - pid_namespaces(7)
+ return 0, syserror.ENOMEM
+ }
+ tid := ns.last
+ for {
+ // Next.
+ tid++
+ if tid > TasksLimit {
+ tid = InitTID + 1
+ }
+
+ // Is it available?
+ _, ok := ns.tasks[tid]
+ if !ok {
+ ns.last = tid
+ return tid, nil
+ }
+
+ // Did we do a full cycle?
+ if tid == ns.last {
+ // No tid available.
+ return 0, syserror.EAGAIN
+ }
+ }
+}
+
+// Start starts the task goroutine. Start must be called exactly once for each
+// task returned by NewTask.
+//
+// 'tid' must be the task's TID in the root PID namespace and it's used for
+// debugging purposes only (set as parameter to Task.run to make it visible
+// in stack dumps).
+func (t *Task) Start(tid ThreadID) {
+ // If the task was restored, it may be "starting" after having already exited.
+ if t.runState == nil {
+ return
+ }
+ t.goroutineStopped.Add(1)
+ t.tg.liveGoroutines.Add(1)
+ t.tg.pidns.owner.liveGoroutines.Add(1)
+ t.tg.pidns.owner.runningGoroutines.Add(1)
+
+ // Task is now running in system mode.
+ t.accountTaskGoroutineLeave(TaskGoroutineNonexistent)
+
+ // Use the task's TID in the root PID namespace to make it visible in stack dumps.
+ go t.run(uintptr(tid)) // S/R-SAFE: synchronizes with saving through stops
+}
diff --git a/pkg/sentry/kernel/task_stop.go b/pkg/sentry/kernel/task_stop.go
new file mode 100644
index 000000000..e735a5dd0
--- /dev/null
+++ b/pkg/sentry/kernel/task_stop.go
@@ -0,0 +1,226 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// This file implements task stops, which represent the equivalent of Linux's
+// uninterruptible sleep states in a way that is compatible with save/restore.
+// Task stops comprise both internal stops (which form part of the task's
+// "normal" control flow) and external stops (which do not); see README.md for
+// details.
+//
+// There are multiple interfaces for interacting with stops because there are
+// multiple cases to consider:
+//
+// - A task goroutine can begin a stop on its associated task (e.g. a
+// vfork() syscall stopping the calling task until the child task releases its
+// MM). In this case, calling Task.interrupt is both unnecessary (the task
+// goroutine obviously cannot be blocked in Task.block or executing application
+// code) and undesirable (as it may spuriously interrupt a in-progress
+// syscall).
+//
+// Beginning internal stops in this case is implemented by
+// Task.beginInternalStop / Task.beginInternalStopLocked. As of this writing,
+// there are no instances of this case that begin external stops, except for
+// autosave; however, autosave terminates the sentry without ending the
+// external stop, so the spurious interrupt is moot.
+//
+// - An arbitrary goroutine can begin a stop on an unrelated task (e.g. all
+// tasks being stopped in preparation for state checkpointing). If the task
+// goroutine may be in Task.block or executing application code, it must be
+// interrupted by Task.interrupt for it to actually enter the stop; since,
+// strictly speaking, we have no way of determining this, we call
+// Task.interrupt unconditionally.
+//
+// Beginning external stops in this case is implemented by
+// Task.BeginExternalStop. As of this writing, there are no instances of this
+// case that begin internal stops.
+//
+// - An arbitrary goroutine can end a stop on an unrelated task (e.g. an
+// exiting task resuming a sibling task that has been blocked in an execve()
+// syscall waiting for other tasks to exit). In this case, Task.endStopCond
+// must be notified to kick the task goroutine out of Task.doStop.
+//
+// Ending internal stops in this case is implemented by
+// Task.endInternalStopLocked. Ending external stops in this case is
+// implemented by Task.EndExternalStop.
+//
+// - Hypothetically, a task goroutine can end an internal stop on its
+// associated task. As of this writing, there are no instances of this case.
+// However, any instances of this case could still use the above functions,
+// since notifying Task.endStopCond would be unnecessary but harmless.
+
+import (
+ "fmt"
+ "sync/atomic"
+)
+
+// A TaskStop is a condition visible to the task control flow graph that
+// prevents a task goroutine from running or exiting, i.e. an internal stop.
+//
+// NOTE(b/30793614): Most TaskStops don't contain any data; they're
+// distinguished by their type. The obvious way to implement such a TaskStop
+// is:
+//
+// type groupStop struct{}
+// func (groupStop) Killable() bool { return true }
+// ...
+// t.beginInternalStop(groupStop{})
+//
+// However, this doesn't work because the state package can't serialize values,
+// only pointers. Furthermore, the correctness of save/restore depends on the
+// ability to pass a TaskStop to endInternalStop that will compare equal to the
+// TaskStop that was passed to beginInternalStop, even if a save/restore cycle
+// occurred between the two. As a result, the current idiom is to always use a
+// typecast nil for data-free TaskStops:
+//
+// type groupStop struct{}
+// func (*groupStop) Killable() bool { return true }
+// ...
+// t.beginInternalStop((*groupStop)(nil))
+//
+// This is pretty gross, but the alternatives seem grosser.
+type TaskStop interface {
+ // Killable returns true if Task.Kill should end the stop prematurely.
+ // Killable is analogous to Linux's TASK_WAKEKILL.
+ Killable() bool
+}
+
+// beginInternalStop indicates the start of an internal stop that applies to t.
+//
+// Preconditions: The task must not already be in an internal stop (i.e. t.stop
+// == nil). The caller must be running on the task goroutine.
+func (t *Task) beginInternalStop(s TaskStop) {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.beginInternalStopLocked(s)
+}
+
+// Preconditions: The signal mutex must be locked. All preconditions for
+// Task.beginInternalStop also apply.
+func (t *Task) beginInternalStopLocked(s TaskStop) {
+ if t.stop != nil {
+ panic(fmt.Sprintf("Attempting to enter internal stop %#v when already in internal stop %#v", s, t.stop))
+ }
+ t.Debugf("Entering internal stop %#v", s)
+ t.stop = s
+ t.beginStopLocked()
+}
+
+// endInternalStopLocked indicates the end of an internal stop that applies to
+// t. endInternalStopLocked does not wait for the task to resume.
+//
+// The caller is responsible for ensuring that the internal stop they expect
+// actually applies to t; this requires holding the signal mutex which protects
+// t.stop, which is why there is no endInternalStop that locks the signal mutex
+// for you.
+//
+// Preconditions: The signal mutex must be locked. The task must be in an
+// internal stop (i.e. t.stop != nil).
+func (t *Task) endInternalStopLocked() {
+ if t.stop == nil {
+ panic("Attempting to leave non-existent internal stop")
+ }
+ t.Debugf("Leaving internal stop %#v", t.stop)
+ t.stop = nil
+ t.endStopLocked()
+}
+
+// BeginExternalStop indicates the start of an external stop that applies to t.
+// BeginExternalStop does not wait for t's task goroutine to stop.
+func (t *Task) BeginExternalStop() {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.beginStopLocked()
+ t.interrupt()
+}
+
+// EndExternalStop indicates the end of an external stop started by a previous
+// call to Task.BeginExternalStop. EndExternalStop does not wait for t's task
+// goroutine to resume.
+func (t *Task) EndExternalStop() {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ t.tg.signalHandlers.mu.Lock()
+ defer t.tg.signalHandlers.mu.Unlock()
+ t.endStopLocked()
+}
+
+// beginStopLocked increments t.stopCount to indicate that a new internal or
+// external stop applies to t.
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) beginStopLocked() {
+ if newval := atomic.AddInt32(&t.stopCount, 1); newval <= 0 {
+ // Most likely overflow.
+ panic(fmt.Sprintf("Invalid stopCount: %d", newval))
+ }
+}
+
+// endStopLocked decerements t.stopCount to indicate that an existing internal
+// or external stop no longer applies to t.
+//
+// Preconditions: The signal mutex must be locked.
+func (t *Task) endStopLocked() {
+ if newval := atomic.AddInt32(&t.stopCount, -1); newval < 0 {
+ panic(fmt.Sprintf("Invalid stopCount: %d", newval))
+ } else if newval == 0 {
+ t.endStopCond.Signal()
+ }
+}
+
+// BeginExternalStop indicates the start of an external stop that applies to
+// all current and future tasks in ts. BeginExternalStop does not wait for
+// task goroutines to stop.
+func (ts *TaskSet) BeginExternalStop() {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.stopCount++
+ if ts.stopCount <= 0 {
+ panic(fmt.Sprintf("Invalid stopCount: %d", ts.stopCount))
+ }
+ if ts.Root == nil {
+ return
+ }
+ for t := range ts.Root.tids {
+ t.tg.signalHandlers.mu.Lock()
+ t.beginStopLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ t.interrupt()
+ }
+}
+
+// EndExternalStop indicates the end of an external stop started by a previous
+// call to TaskSet.BeginExternalStop. EndExternalStop does not wait for task
+// goroutines to resume.
+func (ts *TaskSet) EndExternalStop() {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.stopCount--
+ if ts.stopCount < 0 {
+ panic(fmt.Sprintf("Invalid stopCount: %d", ts.stopCount))
+ }
+ if ts.Root == nil {
+ return
+ }
+ for t := range ts.Root.tids {
+ t.tg.signalHandlers.mu.Lock()
+ t.endStopLocked()
+ t.tg.signalHandlers.mu.Unlock()
+ }
+}
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
new file mode 100644
index 000000000..a9283d0df
--- /dev/null
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -0,0 +1,447 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "os"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/bits"
+ "gvisor.googlesource.com/gvisor/pkg/metric"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// SyscallRestartErrno represents a ERESTART* errno defined in the Linux's kernel
+// include/linux/errno.h. These errnos are never returned to userspace
+// directly, but are used to communicate the expected behavior of an
+// interrupted syscall from the syscall to signal handling.
+type SyscallRestartErrno int
+
+// These numeric values are significant because ptrace syscall exit tracing can
+// observe them.
+//
+// For all of the following errnos, if the syscall is not interrupted by a
+// signal delivered to a user handler, the syscall is restarted.
+const (
+ // ERESTARTSYS is returned by an interrupted syscall to indicate that it
+ // should be converted to EINTR if interrupted by a signal delivered to a
+ // user handler without SA_RESTART set, and restarted otherwise.
+ ERESTARTSYS = SyscallRestartErrno(512)
+
+ // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it
+ // should always be restarted.
+ ERESTARTNOINTR = SyscallRestartErrno(513)
+
+ // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it
+ // should be converted to EINTR if interrupted by a signal delivered to a
+ // user handler, and restarted otherwise.
+ ERESTARTNOHAND = SyscallRestartErrno(514)
+
+ // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate
+ // that it should be restarted using a custom function. The interrupted
+ // syscall must register a custom restart function by calling
+ // Task.SetRestartSyscallFn.
+ ERESTART_RESTARTBLOCK = SyscallRestartErrno(516)
+)
+
+var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application")
+
+// Error implements error.Error.
+func (e SyscallRestartErrno) Error() string {
+ // Descriptions are borrowed from strace.
+ switch e {
+ case ERESTARTSYS:
+ return "to be restarted if SA_RESTART is set"
+ case ERESTARTNOINTR:
+ return "to be restarted"
+ case ERESTARTNOHAND:
+ return "to be restarted if no handler"
+ case ERESTART_RESTARTBLOCK:
+ return "interrupted by signal"
+ default:
+ return "(unknown interrupt error)"
+ }
+}
+
+// SyscallRestartErrnoFromReturn returns the SyscallRestartErrno represented by
+// rv, the value in a syscall return register.
+func SyscallRestartErrnoFromReturn(rv uintptr) (SyscallRestartErrno, bool) {
+ switch int(rv) {
+ case -int(ERESTARTSYS):
+ return ERESTARTSYS, true
+ case -int(ERESTARTNOINTR):
+ return ERESTARTNOINTR, true
+ case -int(ERESTARTNOHAND):
+ return ERESTARTNOHAND, true
+ case -int(ERESTART_RESTARTBLOCK):
+ return ERESTART_RESTARTBLOCK, true
+ default:
+ return 0, false
+ }
+}
+
+// SyscallRestartBlock represents the restart block for a syscall restartable
+// with a custom function. It encapsulates the state required to restart a
+// syscall across a S/R.
+type SyscallRestartBlock interface {
+ Restart(t *Task) (uintptr, error)
+}
+
+// SyscallControl is returned by syscalls to control the behavior of
+// Task.doSyscallInvoke.
+type SyscallControl struct {
+ // next is the state that the task goroutine should switch to. If next is
+ // nil, the task goroutine should continue to syscall exit as usual.
+ next taskRunState
+
+ // If ignoreReturn is true, Task.doSyscallInvoke should not store any value
+ // in the task's syscall return value register.
+ ignoreReturn bool
+}
+
+var (
+ // CtrlDoExit is returned by the implementations of the exit and exit_group
+ // syscalls to enter the task exit path directly, skipping syscall exit
+ // tracing.
+ CtrlDoExit = &SyscallControl{next: (*runExit)(nil), ignoreReturn: true}
+
+ // ctrlStopAndReinvokeSyscall is returned by syscalls using the external
+ // feature before syscall execution. This causes Task.doSyscallInvoke
+ // to return runSyscallReinvoke, allowing Task.run to check for stops
+ // before immediately re-invoking the syscall (skipping the re-checking
+ // of seccomp filters and ptrace which would confuse userspace
+ // tracing).
+ ctrlStopAndReinvokeSyscall = &SyscallControl{next: (*runSyscallReinvoke)(nil), ignoreReturn: true}
+
+ // ctrlStopBeforeSyscallExit is returned by syscalls that initiate a stop at
+ // their end. This causes Task.doSyscallInvoke to return runSyscallExit, rather
+ // than tail-calling it, allowing stops to be checked before syscall exit.
+ ctrlStopBeforeSyscallExit = &SyscallControl{next: (*runSyscallExit)(nil)}
+)
+
+func (t *Task) invokeExternal() {
+ t.BeginExternalStop()
+ go func() { // S/R-SAFE: External control flow.
+ defer t.EndExternalStop()
+ t.SyscallTable().External(t.Kernel())
+ }()
+}
+
+func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval uintptr, ctrl *SyscallControl, err error) {
+ s := t.SyscallTable()
+
+ fe := s.FeatureEnable.Word(sysno)
+
+ var straceContext interface{}
+ if bits.IsAnyOn32(fe, StraceEnableBits) {
+ straceContext = s.Stracer.SyscallEnter(t, sysno, args, fe)
+ }
+
+ if bits.IsOn32(fe, ExternalBeforeEnable) && (s.ExternalFilterBefore == nil || s.ExternalFilterBefore(t, sysno, args)) {
+ t.invokeExternal()
+ // Ensure we check for stops, then invoke the syscall again.
+ ctrl = ctrlStopAndReinvokeSyscall
+ } else {
+ fn := s.Lookup(sysno)
+ if fn != nil {
+ // Call our syscall implementation.
+ rval, ctrl, err = fn(t, args)
+ } else {
+ // Use the missing function if not found.
+ rval, err = t.SyscallTable().Missing(t, sysno, args)
+ }
+ }
+
+ if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
+ t.invokeExternal()
+ // Don't reinvoke the syscall.
+ }
+
+ if bits.IsAnyOn32(fe, StraceEnableBits) {
+ s.Stracer.SyscallExit(straceContext, t, sysno, rval, err)
+ }
+
+ return
+}
+
+// doSyscall is the entry point for an invocation of a system call specified by
+// the current state of t's registers.
+//
+// The syscall path is very hot; avoid defer.
+func (t *Task) doSyscall() taskRunState {
+ sysno := t.Arch().SyscallNo()
+ args := t.Arch().SyscallArgs()
+
+ // Tracers expect to see this between when the task traps into the kernel
+ // to perform a syscall and when the syscall is actually invoked.
+ // This useless-looking temporary is needed because Go.
+ tmp := uintptr(syscall.ENOSYS)
+ t.Arch().SetReturn(-tmp)
+
+ // Check seccomp filters. The nil check is for performance (as seccomp use
+ // is rare), not needed for correctness.
+ if t.syscallFilters.Load() != nil {
+ switch r := t.checkSeccompSyscall(int32(sysno), args, usermem.Addr(t.Arch().IP())); r {
+ case linux.SECCOMP_RET_ERRNO, linux.SECCOMP_RET_TRAP:
+ t.Debugf("Syscall %d: denied by seccomp", sysno)
+ return (*runSyscallExit)(nil)
+ case linux.SECCOMP_RET_ALLOW:
+ // ok
+ case linux.SECCOMP_RET_KILL_THREAD:
+ t.Debugf("Syscall %d: killed by seccomp", sysno)
+ t.PrepareExit(ExitStatus{Signo: int(linux.SIGSYS)})
+ return (*runExit)(nil)
+ case linux.SECCOMP_RET_TRACE:
+ t.Debugf("Syscall %d: stopping for PTRACE_EVENT_SECCOMP", sysno)
+ return (*runSyscallAfterPtraceEventSeccomp)(nil)
+ default:
+ panic(fmt.Sprintf("Unknown seccomp result %d", r))
+ }
+ }
+
+ return t.doSyscallEnter(sysno, args)
+}
+
+type runSyscallAfterPtraceEventSeccomp struct{}
+
+func (*runSyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState {
+ if t.killed() {
+ // "[S]yscall-exit-stop is not generated prior to death by SIGKILL." -
+ // ptrace(2)
+ return (*runInterrupt)(nil)
+ }
+ sysno := t.Arch().SyscallNo()
+ // "The tracer can skip the system call by changing the syscall number to
+ // -1." - Documentation/prctl/seccomp_filter.txt
+ if sysno == ^uintptr(0) {
+ return (*runSyscallExit)(nil).execute(t)
+ }
+ args := t.Arch().SyscallArgs()
+ return t.doSyscallEnter(sysno, args)
+}
+
+func (t *Task) doSyscallEnter(sysno uintptr, args arch.SyscallArguments) taskRunState {
+ if next, ok := t.ptraceSyscallEnter(); ok {
+ return next
+ }
+ return t.doSyscallInvoke(sysno, args)
+}
+
+// +stateify savable
+type runSyscallAfterSyscallEnterStop struct{}
+
+func (*runSyscallAfterSyscallEnterStop) execute(t *Task) taskRunState {
+ if sig := linux.Signal(t.ptraceCode); sig.IsValid() {
+ t.tg.signalHandlers.mu.Lock()
+ t.sendSignalLocked(SignalInfoPriv(sig), false /* group */)
+ t.tg.signalHandlers.mu.Unlock()
+ }
+ if t.killed() {
+ return (*runInterrupt)(nil)
+ }
+ sysno := t.Arch().SyscallNo()
+ if sysno == ^uintptr(0) {
+ return (*runSyscallExit)(nil)
+ }
+ args := t.Arch().SyscallArgs()
+ return t.doSyscallInvoke(sysno, args)
+}
+
+// +stateify savable
+type runSyscallAfterSysemuStop struct{}
+
+func (*runSyscallAfterSysemuStop) execute(t *Task) taskRunState {
+ if sig := linux.Signal(t.ptraceCode); sig.IsValid() {
+ t.tg.signalHandlers.mu.Lock()
+ t.sendSignalLocked(SignalInfoPriv(sig), false /* group */)
+ t.tg.signalHandlers.mu.Unlock()
+ }
+ if t.killed() {
+ return (*runInterrupt)(nil)
+ }
+ return (*runSyscallExit)(nil).execute(t)
+}
+
+func (t *Task) doSyscallInvoke(sysno uintptr, args arch.SyscallArguments) taskRunState {
+ rval, ctrl, err := t.executeSyscall(sysno, args)
+
+ if ctrl != nil {
+ if !ctrl.ignoreReturn {
+ t.Arch().SetReturn(rval)
+ }
+ if ctrl.next != nil {
+ return ctrl.next
+ }
+ } else if err != nil {
+ t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno))))
+ t.haveSyscallReturn = true
+ } else {
+ t.Arch().SetReturn(rval)
+ }
+
+ return (*runSyscallExit)(nil).execute(t)
+}
+
+// +stateify savable
+type runSyscallReinvoke struct{}
+
+func (*runSyscallReinvoke) execute(t *Task) taskRunState {
+ if t.killed() {
+ // It's possible that since the last execution, the task has
+ // been forcible killed. Invoking the system call here could
+ // result in an infinite loop if it is again preempted by an
+ // external stop and reinvoked.
+ return (*runInterrupt)(nil)
+ }
+
+ sysno := t.Arch().SyscallNo()
+ args := t.Arch().SyscallArgs()
+ return t.doSyscallInvoke(sysno, args)
+}
+
+// +stateify savable
+type runSyscallExit struct{}
+
+func (*runSyscallExit) execute(t *Task) taskRunState {
+ t.ptraceSyscallExit()
+ return (*runApp)(nil)
+}
+
+// doVsyscall is the entry point for a vsyscall invocation of syscall sysno, as
+// indicated by an execution fault at address addr. doVsyscall returns the
+// task's next run state.
+func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState {
+ vsyscallCount.Increment()
+
+ // Grab the caller up front, to make sure there's a sensible stack.
+ caller := t.Arch().Native(uintptr(0))
+ if _, err := t.CopyIn(usermem.Addr(t.Arch().Stack()), caller); err != nil {
+ t.Debugf("vsyscall %d: error reading return address from stack: %v", sysno, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return (*runApp)(nil)
+ }
+
+ // For _vsyscalls_, there is no need to translate System V calling convention
+ // to syscall ABI because they both use RDI, RSI, and RDX for the first three
+ // arguments and none of the vsyscalls uses more than two arguments.
+ args := t.Arch().SyscallArgs()
+ if t.syscallFilters.Load() != nil {
+ switch r := t.checkSeccompSyscall(int32(sysno), args, addr); r {
+ case linux.SECCOMP_RET_ERRNO, linux.SECCOMP_RET_TRAP:
+ t.Debugf("vsyscall %d, caller %x: denied by seccomp", sysno, t.Arch().Value(caller))
+ return (*runApp)(nil)
+ case linux.SECCOMP_RET_ALLOW:
+ // ok
+ case linux.SECCOMP_RET_TRACE:
+ t.Debugf("vsyscall %d, caller %x: stopping for PTRACE_EVENT_SECCOMP", sysno, t.Arch().Value(caller))
+ return &runVsyscallAfterPtraceEventSeccomp{addr, sysno, caller}
+ case linux.SECCOMP_RET_KILL_THREAD:
+ t.Debugf("vsyscall %d: killed by seccomp", sysno)
+ t.PrepareExit(ExitStatus{Signo: int(linux.SIGSYS)})
+ return (*runExit)(nil)
+ default:
+ panic(fmt.Sprintf("Unknown seccomp result %d", r))
+ }
+ }
+
+ return t.doVsyscallInvoke(sysno, args, caller)
+}
+
+type runVsyscallAfterPtraceEventSeccomp struct {
+ addr usermem.Addr
+ sysno uintptr
+ caller interface{}
+}
+
+func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState {
+ if t.killed() {
+ return (*runInterrupt)(nil)
+ }
+ sysno := t.Arch().SyscallNo()
+ // "... the syscall may not be changed to another system call using the
+ // orig_rax register. It may only be changed to -1 order [sic] to skip the
+ // currently emulated call. ... The tracer MUST NOT modify rip or rsp." -
+ // Documentation/prctl/seccomp_filter.txt. On Linux, changing orig_ax or ip
+ // causes do_exit(SIGSYS), and changing sp is ignored.
+ if (sysno != ^uintptr(0) && sysno != r.sysno) || usermem.Addr(t.Arch().IP()) != r.addr {
+ t.PrepareExit(ExitStatus{Signo: int(linux.SIGSYS)})
+ return (*runExit)(nil)
+ }
+ if sysno == ^uintptr(0) {
+ return (*runApp)(nil)
+ }
+ return t.doVsyscallInvoke(sysno, t.Arch().SyscallArgs(), r.caller)
+}
+
+func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller interface{}) taskRunState {
+ rval, ctrl, err := t.executeSyscall(sysno, args)
+ if ctrl != nil {
+ t.Debugf("vsyscall %d, caller %x: syscall control: %v", sysno, t.Arch().Value(caller), ctrl)
+ // Set the return value. The stack has already been adjusted.
+ t.Arch().SetReturn(0)
+ } else if err == nil {
+ t.Debugf("vsyscall %d, caller %x: successfully emulated syscall", sysno, t.Arch().Value(caller))
+ // Set the return value. The stack has already been adjusted.
+ t.Arch().SetReturn(uintptr(rval))
+ } else {
+ t.Debugf("vsyscall %d, caller %x: emulated syscall returned error: %v", sysno, t.Arch().Value(caller), err)
+ if err == syserror.EFAULT {
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ // A return is not emulated in this case.
+ return (*runApp)(nil)
+ }
+ t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno))))
+ }
+ t.Arch().SetIP(t.Arch().Value(caller))
+ t.Arch().SetStack(t.Arch().Stack() + uintptr(t.Arch().Width()))
+ return (*runApp)(nil)
+}
+
+// ExtractErrno extracts an integer error number from the error.
+// The syscall number is purely for context in the error case. Use -1 if
+// syscall number is unknown.
+func (t *Task) ExtractErrno(err error, sysno int) int {
+ switch err := err.(type) {
+ case nil:
+ return 0
+ case syscall.Errno:
+ return int(err)
+ case SyscallRestartErrno:
+ return int(err)
+ case *memmap.BusError:
+ // Bus errors may generate SIGBUS, but for syscalls they still
+ // return EFAULT. See case in task_run.go where the fault is
+ // handled (and the SIGBUS is delivered).
+ return int(syscall.EFAULT)
+ case *os.PathError:
+ return t.ExtractErrno(err.Err, sysno)
+ case *os.LinkError:
+ return t.ExtractErrno(err.Err, sysno)
+ case *os.SyscallError:
+ return t.ExtractErrno(err.Err, sysno)
+ default:
+ if errno, ok := syserror.TranslateError(err); ok {
+ return int(errno)
+ }
+ }
+ panic(fmt.Sprintf("Unknown syscall %d error: %v", sysno, err))
+}
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
new file mode 100644
index 000000000..461bd7316
--- /dev/null
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -0,0 +1,301 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// MAX_RW_COUNT is the maximum size in bytes of a single read or write.
+// Reads and writes that exceed this size may be silently truncated.
+// (Linux: include/linux/fs.h:MAX_RW_COUNT)
+var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown())
+
+// Activate ensures that the task has an active address space.
+func (t *Task) Activate() {
+ if mm := t.MemoryManager(); mm != nil {
+ if err := mm.Activate(); err != nil {
+ panic("unable to activate mm: " + err.Error())
+ }
+ }
+}
+
+// Deactivate relinquishes the task's active address space.
+func (t *Task) Deactivate() {
+ if mm := t.MemoryManager(); mm != nil {
+ mm.Deactivate()
+ }
+}
+
+// CopyIn copies a fixed-size value or slice of fixed-size values in from the
+// task's memory. The copy will fail with syscall.EFAULT if it traverses user
+// memory that is unmapped or not readable by the user.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyIn(addr usermem.Addr, dst interface{}) (int, error) {
+ return usermem.CopyObjectIn(t, t.MemoryManager(), addr, dst, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyInBytes is a fast version of CopyIn if the caller can serialize the
+// data without reflection and pass in a byte slice.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+ return t.MemoryManager().CopyIn(t, addr, dst, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyOut copies a fixed-size value or slice of fixed-size values out to the
+// task's memory. The copy will fail with syscall.EFAULT if it traverses user
+// memory that is unmapped or not writeable by the user.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyOut(addr usermem.Addr, src interface{}) (int, error) {
+ return usermem.CopyObjectOut(t, t.MemoryManager(), addr, src, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyOutBytes is a fast version of CopyOut if the caller can serialize the
+// data without reflection and pass in a byte slice.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+ return t.MemoryManager().CopyOut(t, addr, src, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyInString copies a NUL-terminated string of length at most maxlen in from
+// the task's memory. The copy will fail with syscall.EFAULT if it traverses
+// user memory that is unmapped or not readable by the user.
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyInString(addr usermem.Addr, maxlen int) (string, error) {
+ return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxlen, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+}
+
+// CopyInVector copies a NULL-terminated vector of strings from the task's
+// memory. The copy will fail with syscall.EFAULT if it traverses
+// user memory that is unmapped or not readable by the user.
+//
+// maxElemSize is the maximum size of each individual element.
+//
+// maxTotalSize is the maximum total length of all elements plus the total
+// number of elements. For example, the following strings correspond to
+// the following set of sizes:
+//
+// { "a", "b", "c" } => 6 (3 for lengths, 3 for elements)
+// { "abc" } => 4 (3 for length, 1 for elements)
+//
+// This Task's AddressSpace must be active.
+func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([]string, error) {
+ var v []string
+ for {
+ argAddr := t.Arch().Native(0)
+ if _, err := t.CopyIn(addr, argAddr); err != nil {
+ return v, err
+ }
+ if t.Arch().Value(argAddr) == 0 {
+ break
+ }
+ // Each string has a zero terminating byte counted, so copying out a string
+ // requires at least one byte of space. Also, see the calculation below.
+ if maxTotalSize <= 0 {
+ return nil, syserror.ENOMEM
+ }
+ thisMax := maxElemSize
+ if maxTotalSize < thisMax {
+ thisMax = maxTotalSize
+ }
+ arg, err := t.CopyInString(usermem.Addr(t.Arch().Value(argAddr)), thisMax)
+ if err != nil {
+ return v, err
+ }
+ v = append(v, arg)
+ addr += usermem.Addr(t.Arch().Width())
+ maxTotalSize -= len(arg) + 1
+ }
+ return v, nil
+}
+
+// CopyOutIovecs converts src to an array of struct iovecs and copies it to the
+// memory mapped at addr.
+//
+// Preconditions: As for usermem.IO.CopyOut. The caller must be running on the
+// task goroutine. t's AddressSpace must be active.
+func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error {
+ switch t.Arch().Width() {
+ case 8:
+ const itemLen = 16
+ if _, ok := addr.AddLength(uint64(src.NumRanges()) * itemLen); !ok {
+ return syserror.EFAULT
+ }
+
+ b := t.CopyScratchBuffer(itemLen)
+ for ; !src.IsEmpty(); src = src.Tail() {
+ ar := src.Head()
+ usermem.ByteOrder.PutUint64(b[0:8], uint64(ar.Start))
+ usermem.ByteOrder.PutUint64(b[8:16], uint64(ar.Length()))
+ if _, err := t.CopyOutBytes(addr, b); err != nil {
+ return err
+ }
+ addr += itemLen
+ }
+
+ default:
+ return syserror.ENOSYS
+ }
+
+ return nil
+}
+
+// CopyInIovecs copies an array of numIovecs struct iovecs from the memory
+// mapped at addr, converts them to usermem.AddrRanges, and returns them as a
+// usermem.AddrRangeSeq.
+//
+// CopyInIovecs shares the following properties with Linux's
+// lib/iov_iter.c:import_iovec() => fs/read_write.c:rw_copy_check_uvector():
+//
+// - If the length of any AddrRange would exceed the range of an ssize_t,
+// CopyInIovecs returns EINVAL.
+//
+// - If the length of any AddrRange would cause its end to overflow,
+// CopyInIovecs returns EFAULT.
+//
+// - If any AddrRange would include addresses outside the application address
+// range, CopyInIovecs returns EFAULT.
+//
+// - The combined length of all AddrRanges is limited to MAX_RW_COUNT. If the
+// combined length of all AddrRanges would otherwise exceed this amount, ranges
+// beyond MAX_RW_COUNT are silently truncated.
+//
+// Preconditions: As for usermem.IO.CopyIn. The caller must be running on the
+// task goroutine. t's AddressSpace must be active.
+func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRangeSeq, error) {
+ if numIovecs == 0 {
+ return usermem.AddrRangeSeq{}, nil
+ }
+
+ var dst []usermem.AddrRange
+ if numIovecs > 1 {
+ dst = make([]usermem.AddrRange, 0, numIovecs)
+ }
+
+ switch t.Arch().Width() {
+ case 8:
+ const itemLen = 16
+ if _, ok := addr.AddLength(uint64(numIovecs) * itemLen); !ok {
+ return usermem.AddrRangeSeq{}, syserror.EFAULT
+ }
+
+ b := t.CopyScratchBuffer(itemLen)
+ for i := 0; i < numIovecs; i++ {
+ if _, err := t.CopyInBytes(addr, b); err != nil {
+ return usermem.AddrRangeSeq{}, err
+ }
+
+ base := usermem.Addr(usermem.ByteOrder.Uint64(b[0:8]))
+ length := usermem.ByteOrder.Uint64(b[8:16])
+ if length > math.MaxInt64 {
+ return usermem.AddrRangeSeq{}, syserror.EINVAL
+ }
+ ar, ok := t.MemoryManager().CheckIORange(base, int64(length))
+ if !ok {
+ return usermem.AddrRangeSeq{}, syserror.EFAULT
+ }
+
+ if numIovecs == 1 {
+ // Special case to avoid allocating dst.
+ return usermem.AddrRangeSeqOf(ar).TakeFirst(MAX_RW_COUNT), nil
+ }
+ dst = append(dst, ar)
+
+ addr += itemLen
+ }
+
+ default:
+ return usermem.AddrRangeSeq{}, syserror.ENOSYS
+ }
+
+ // Truncate to MAX_RW_COUNT.
+ var total uint64
+ for i := range dst {
+ dstlen := uint64(dst[i].Length())
+ if rem := uint64(MAX_RW_COUNT) - total; rem < dstlen {
+ dst[i].End -= usermem.Addr(dstlen - rem)
+ dstlen = rem
+ }
+ total += dstlen
+ }
+
+ return usermem.AddrRangeSeqFromSlice(dst), nil
+}
+
+// SingleIOSequence returns a usermem.IOSequence representing [addr,
+// addr+length) in t's address space. If this contains addresses outside the
+// application address range, it returns EFAULT. If length exceeds
+// MAX_RW_COUNT, the range is silently truncated.
+//
+// SingleIOSequence is analogous to Linux's
+// lib/iov_iter.c:import_single_range(). (Note that the non-vectorized read and
+// write syscalls in Linux do not use import_single_range(). However they check
+// access_ok() in fs/read_write.c:vfs_read/vfs_write, and overflowing address
+// ranges are truncated to MAX_RW_COUNT by fs/read_write.c:rw_verify_area().)
+func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOpts) (usermem.IOSequence, error) {
+ if length > MAX_RW_COUNT {
+ length = MAX_RW_COUNT
+ }
+ ar, ok := t.MemoryManager().CheckIORange(addr, int64(length))
+ if !ok {
+ return usermem.IOSequence{}, syserror.EFAULT
+ }
+ return usermem.IOSequence{
+ IO: t.MemoryManager(),
+ Addrs: usermem.AddrRangeSeqOf(ar),
+ Opts: opts,
+ }, nil
+}
+
+// IovecsIOSequence returns a usermem.IOSequence representing the array of
+// iovcnt struct iovecs at addr in t's address space. opts applies to the
+// returned IOSequence, not the reading of the struct iovec array.
+//
+// IovecsIOSequence is analogous to Linux's lib/iov_iter.c:import_iovec().
+//
+// Preconditions: As for Task.CopyInIovecs.
+func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) {
+ if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV {
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+ ars, err := t.CopyInIovecs(addr, iovcnt)
+ if err != nil {
+ return usermem.IOSequence{}, err
+ }
+ return usermem.IOSequence{
+ IO: t.MemoryManager(),
+ Addrs: ars,
+ Opts: opts,
+ }, nil
+}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
new file mode 100644
index 000000000..8bd53928e
--- /dev/null
+++ b/pkg/sentry/kernel/thread_group.go
@@ -0,0 +1,330 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+)
+
+// A ThreadGroup is a logical grouping of tasks that has widespread
+// significance to other kernel features (e.g. signal handling). ("Thread
+// groups" are usually called "processes" in userspace documentation.)
+//
+// ThreadGroup is a superset of Linux's struct signal_struct.
+//
+// +stateify savable
+type ThreadGroup struct {
+ threadGroupNode
+
+ // signalHandlers is the set of signal handlers used by every task in this
+ // thread group. (signalHandlers may also be shared with other thread
+ // groups.)
+ //
+ // signalHandlers.mu (hereafter "the signal mutex") protects state related
+ // to signal handling, as well as state that usually needs to be atomic
+ // with signal handling, for all ThreadGroups and Tasks using
+ // signalHandlers. (This is analogous to Linux's use of struct
+ // sighand_struct::siglock.)
+ //
+ // The signalHandlers pointer can only be mutated during an execve
+ // (Task.finishExec). Consequently, when it's possible for a task in the
+ // thread group to be completing an execve, signalHandlers is protected by
+ // the owning TaskSet.mu. Otherwise, it is possible to read the
+ // signalHandlers pointer without synchronization. In particular,
+ // completing an execve requires that all other tasks in the thread group
+ // have exited, so task goroutines do not need the owning TaskSet.mu to
+ // read the signalHandlers pointer of their thread groups.
+ signalHandlers *SignalHandlers
+
+ // pendingSignals is the set of pending signals that may be handled by any
+ // task in this thread group.
+ //
+ // pendingSignals is protected by the signal mutex.
+ pendingSignals pendingSignals
+
+ // If groupStopDequeued is true, a task in the thread group has dequeued a
+ // stop signal, but has not yet initiated the group stop.
+ //
+ // groupStopDequeued is analogous to Linux's JOBCTL_STOP_DEQUEUED.
+ //
+ // groupStopDequeued is protected by the signal mutex.
+ groupStopDequeued bool
+
+ // groupStopSignal is the signal that caused a group stop to be initiated.
+ //
+ // groupStopSignal is protected by the signal mutex.
+ groupStopSignal linux.Signal
+
+ // groupStopPendingCount is the number of active tasks in the thread group
+ // for which Task.groupStopPending is set.
+ //
+ // groupStopPendingCount is analogous to Linux's
+ // signal_struct::group_stop_count.
+ //
+ // groupStopPendingCount is protected by the signal mutex.
+ groupStopPendingCount int
+
+ // If groupStopComplete is true, groupStopPendingCount transitioned from
+ // non-zero to zero without an intervening SIGCONT.
+ //
+ // groupStopComplete is analogous to Linux's SIGNAL_STOP_STOPPED.
+ //
+ // groupStopComplete is protected by the signal mutex.
+ groupStopComplete bool
+
+ // If groupStopWaitable is true, the thread group is indicating a waitable
+ // group stop event (as defined by EventChildGroupStop).
+ //
+ // Linux represents the analogous state as SIGNAL_STOP_STOPPED being set
+ // and group_exit_code being non-zero.
+ //
+ // groupStopWaitable is protected by the signal mutex.
+ groupStopWaitable bool
+
+ // If groupContNotify is true, then a SIGCONT has recently ended a group
+ // stop on this thread group, and the first task to observe it should
+ // notify its parent. groupContInterrupted is true iff SIGCONT ended an
+ // incomplete group stop. If groupContNotify is false, groupContInterrupted is
+ // meaningless.
+ //
+ // Analogues in Linux:
+ //
+ // - groupContNotify && groupContInterrupted is represented by
+ // SIGNAL_CLD_STOPPED.
+ //
+ // - groupContNotify && !groupContInterrupted is represented by
+ // SIGNAL_CLD_CONTINUED.
+ //
+ // - !groupContNotify is represented by neither flag being set.
+ //
+ // groupContNotify and groupContInterrupted are protected by the signal
+ // mutex.
+ groupContNotify bool
+ groupContInterrupted bool
+
+ // If groupContWaitable is true, the thread group is indicating a waitable
+ // continue event (as defined by EventGroupContinue).
+ //
+ // groupContWaitable is analogous to Linux's SIGNAL_STOP_CONTINUED.
+ //
+ // groupContWaitable is protected by the signal mutex.
+ groupContWaitable bool
+
+ // exiting is true if all tasks in the ThreadGroup should exit. exiting is
+ // analogous to Linux's SIGNAL_GROUP_EXIT.
+ //
+ // exiting is protected by the signal mutex. exiting can only transition
+ // from false to true.
+ exiting bool
+
+ // exitStatus is the thread group's exit status.
+ //
+ // While exiting is false, exitStatus is protected by the signal mutex.
+ // When exiting becomes true, exitStatus becomes immutable.
+ exitStatus ExitStatus
+
+ // terminationSignal is the signal that this thread group's leader will
+ // send to its parent when it exits.
+ //
+ // terminationSignal is protected by the TaskSet mutex.
+ terminationSignal linux.Signal
+
+ // liveGoroutines is the number of non-exited task goroutines in the thread
+ // group.
+ //
+ // liveGoroutines is not saved; it is reset as task goroutines are
+ // restarted by Task.Start.
+ liveGoroutines sync.WaitGroup `state:"nosave"`
+
+ timerMu sync.Mutex `state:"nosave"`
+
+ // itimerRealTimer implements ITIMER_REAL for the thread group.
+ itimerRealTimer *ktime.Timer
+
+ // itimerVirtSetting is the ITIMER_VIRTUAL setting for the thread group.
+ //
+ // itimerVirtSetting is protected by the signal mutex.
+ itimerVirtSetting ktime.Setting
+
+ // itimerProfSetting is the ITIMER_PROF setting for the thread group.
+ //
+ // itimerProfSetting is protected by the signal mutex.
+ itimerProfSetting ktime.Setting
+
+ // rlimitCPUSoftSetting is the setting for RLIMIT_CPU soft limit
+ // notifications for the thread group.
+ //
+ // rlimitCPUSoftSetting is protected by the signal mutex.
+ rlimitCPUSoftSetting ktime.Setting
+
+ // cpuTimersEnabled is non-zero if itimerVirtSetting.Enabled is true,
+ // itimerProfSetting.Enabled is true, rlimitCPUSoftSetting.Enabled is true,
+ // or limits.Get(CPU) is finite.
+ //
+ // cpuTimersEnabled is protected by the signal mutex. cpuTimersEnabled is
+ // accessed using atomic memory operations.
+ cpuTimersEnabled uint32
+
+ // timers is the thread group's POSIX interval timers. nextTimerID is the
+ // TimerID at which allocation should begin searching for an unused ID.
+ //
+ // timers and nextTimerID are protected by timerMu.
+ timers map[linux.TimerID]*IntervalTimer
+ nextTimerID linux.TimerID
+
+ // exitedCPUStats is the CPU usage for all exited tasks in the thread
+ // group. exitedCPUStats is protected by the TaskSet mutex.
+ exitedCPUStats usage.CPUStats
+
+ // childCPUStats is the CPU usage of all joined descendants of this thread
+ // group. childCPUStats is protected by the TaskSet mutex.
+ childCPUStats usage.CPUStats
+
+ // ioUsage is the I/O usage for all exited tasks in the thread group.
+ // The ioUsage pointer is immutable.
+ ioUsage *usage.IO
+
+ // maxRSS is the historical maximum resident set size of the thread group, updated when:
+ //
+ // - A task in the thread group exits, since after all tasks have
+ // exited the MemoryManager is no longer reachable.
+ //
+ // - The thread group completes an execve, since this changes
+ // MemoryManagers.
+ //
+ // maxRSS is protected by the TaskSet mutex.
+ maxRSS uint64
+
+ // childMaxRSS is the maximum resident set size in bytes of all joined
+ // descendants of this thread group.
+ //
+ // childMaxRSS is protected by the TaskSet mutex.
+ childMaxRSS uint64
+
+ // Resource limits for this ThreadGroup. The limits pointer is immutable.
+ limits *limits.LimitSet
+
+ // processGroup is the processGroup for this thread group.
+ //
+ // processGroup is protected by the TaskSet mutex.
+ processGroup *ProcessGroup
+
+ // execed indicates an exec has occurred since creation. This will be
+ // set by finishExec, and new TheadGroups will have this field cleared.
+ // When execed is set, the processGroup may no longer be changed.
+ //
+ // execed is protected by the TaskSet mutex.
+ execed bool
+
+ // rscr is the thread group's RSEQ critical region.
+ rscr atomic.Value `state:".(*RSEQCriticalRegion)"`
+}
+
+// newThreadGroup returns a new, empty thread group in PID namespace ns. The
+// thread group leader will send its parent terminationSignal when it exits.
+// The new thread group isn't visible to the system until a task has been
+// created inside of it by a successful call to TaskSet.NewTask.
+func (k *Kernel) newThreadGroup(ns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet, monotonicClock *timekeeperClock) *ThreadGroup {
+ tg := &ThreadGroup{
+ threadGroupNode: threadGroupNode{
+ pidns: ns,
+ },
+ signalHandlers: sh,
+ terminationSignal: terminationSignal,
+ ioUsage: &usage.IO{},
+ limits: limits,
+ }
+ tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg})
+ tg.timers = make(map[linux.TimerID]*IntervalTimer)
+ tg.rscr.Store(&RSEQCriticalRegion{})
+ return tg
+}
+
+// saveRscr is invopked by stateify.
+func (tg *ThreadGroup) saveRscr() *RSEQCriticalRegion {
+ return tg.rscr.Load().(*RSEQCriticalRegion)
+}
+
+// loadRscr is invoked by stateify.
+func (tg *ThreadGroup) loadRscr(rscr *RSEQCriticalRegion) {
+ tg.rscr.Store(rscr)
+}
+
+// SignalHandlers returns the signal handlers used by tg.
+//
+// Preconditions: The caller must provide the synchronization required to read
+// tg.signalHandlers, as described in the field's comment.
+func (tg *ThreadGroup) SignalHandlers() *SignalHandlers {
+ return tg.signalHandlers
+}
+
+// Limits returns tg's limits.
+func (tg *ThreadGroup) Limits() *limits.LimitSet {
+ return tg.limits
+}
+
+// release releases the thread group's resources.
+func (tg *ThreadGroup) release() {
+ // Timers must be destroyed without holding the TaskSet or signal mutexes
+ // since timers send signals with Timer.mu locked.
+ tg.itimerRealTimer.Destroy()
+ var its []*IntervalTimer
+ tg.pidns.owner.mu.Lock()
+ tg.signalHandlers.mu.Lock()
+ for _, it := range tg.timers {
+ its = append(its, it)
+ }
+ tg.timers = make(map[linux.TimerID]*IntervalTimer) // nil maps can't be saved
+ tg.signalHandlers.mu.Unlock()
+ tg.pidns.owner.mu.Unlock()
+ for _, it := range its {
+ it.DestroyTimer()
+ }
+}
+
+// forEachChildThreadGroupLocked indicates over all child ThreadGroups.
+//
+// Precondition: TaskSet.mu must be held.
+func (tg *ThreadGroup) forEachChildThreadGroupLocked(fn func(*ThreadGroup)) {
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ for child := range t.children {
+ if child == child.tg.leader {
+ fn(child.tg)
+ }
+ }
+ }
+}
+
+// itimerRealListener implements ktime.Listener for ITIMER_REAL expirations.
+//
+// +stateify savable
+type itimerRealListener struct {
+ tg *ThreadGroup
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (l *itimerRealListener) Notify(exp uint64) {
+ l.tg.SendSignal(SignalInfoPriv(linux.SIGALRM))
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (l *itimerRealListener) Destroy() {
+}
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
new file mode 100644
index 000000000..656bbd46c
--- /dev/null
+++ b/pkg/sentry/kernel/threads.go
@@ -0,0 +1,465 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// TasksLimit is the maximum number of threads for untrusted application.
+// Linux doesn't really limit this directly, rather it is limited by total
+// memory size, stacks allocated and a global maximum. There's no real reason
+// for us to limit it either, (esp. since threads are backed by go routines),
+// and we would expect to hit resource limits long before hitting this number.
+// However, for correctness, we still check that the user doesn't exceed this
+// number.
+//
+// Note that because of the way futexes are implemented, there *are* in fact
+// serious restrictions on valid thread IDs. They are limited to 2^30 - 1
+// (kernel/fork.c:MAX_THREADS).
+const TasksLimit = (1 << 16)
+
+// ThreadID is a generic thread identifier.
+type ThreadID int32
+
+// String returns a decimal representation of the ThreadID.
+func (tid ThreadID) String() string {
+ return fmt.Sprintf("%d", tid)
+}
+
+// InitTID is the TID given to the first task added to each PID namespace. The
+// thread group led by InitTID is called the namespace's init process. The
+// death of a PID namespace's init process causes all tasks visible in that
+// namespace to be killed.
+const InitTID ThreadID = 1
+
+// A TaskSet comprises all tasks in a system.
+//
+// +stateify savable
+type TaskSet struct {
+ // mu protects all relationships betweens tasks and thread groups in the
+ // TaskSet. (mu is approximately equivalent to Linux's tasklist_lock.)
+ mu sync.RWMutex `state:"nosave"`
+
+ // Root is the root PID namespace, in which all tasks in the TaskSet are
+ // visible. The Root pointer is immutable.
+ Root *PIDNamespace
+
+ // sessions is the set of all sessions.
+ sessions sessionList
+
+ // stopCount is the number of active external stops applicable to all tasks
+ // in the TaskSet (calls to TaskSet.BeginExternalStop that have not been
+ // paired with a call to TaskSet.EndExternalStop). stopCount is protected
+ // by mu.
+ //
+ // stopCount is not saved for the same reason as Task.stopCount; it is
+ // always reset to zero after restore.
+ stopCount int32 `state:"nosave"`
+
+ // liveGoroutines is the number of non-exited task goroutines in the
+ // TaskSet.
+ //
+ // liveGoroutines is not saved; it is reset as task goroutines are
+ // restarted by Task.Start.
+ liveGoroutines sync.WaitGroup `state:"nosave"`
+
+ // runningGoroutines is the number of running task goroutines in the
+ // TaskSet.
+ //
+ // runningGoroutines is not saved; its counter value is required to be zero
+ // at time of save (but note that this is not necessarily the same thing as
+ // sync.WaitGroup's zero value).
+ runningGoroutines sync.WaitGroup `state:"nosave"`
+}
+
+// newTaskSet returns a new, empty TaskSet.
+func newTaskSet() *TaskSet {
+ ts := &TaskSet{}
+ ts.Root = newPIDNamespace(ts, nil /* parent */, auth.NewRootUserNamespace())
+ return ts
+}
+
+// forEachThreadGroupLocked applies f to each thread group in ts.
+//
+// Preconditions: ts.mu must be locked (for reading or writing).
+func (ts *TaskSet) forEachThreadGroupLocked(f func(tg *ThreadGroup)) {
+ for tg := range ts.Root.tgids {
+ f(tg)
+ }
+}
+
+// A PIDNamespace represents a PID namespace, a bimap between thread IDs and
+// tasks. See the pid_namespaces(7) man page for further details.
+//
+// N.B. A task is said to be visible in a PID namespace if the PID namespace
+// contains a thread ID that maps to that task.
+//
+// +stateify savable
+type PIDNamespace struct {
+ // owner is the TaskSet that this PID namespace belongs to. The owner
+ // pointer is immutable.
+ owner *TaskSet
+
+ // parent is the PID namespace of the process that created this one. If
+ // this is the root PID namespace, parent is nil. The parent pointer is
+ // immutable.
+ //
+ // Invariant: All tasks that are visible in this namespace are also visible
+ // in all ancestor namespaces.
+ parent *PIDNamespace
+
+ // userns is the user namespace with which this PID namespace is
+ // associated. Privileged operations on this PID namespace must have
+ // appropriate capabilities in userns. The userns pointer is immutable.
+ userns *auth.UserNamespace
+
+ // The following fields are protected by owner.mu.
+
+ // last is the last ThreadID to be allocated in this namespace.
+ last ThreadID
+
+ // tasks is a mapping from ThreadIDs in this namespace to tasks visible in
+ // the namespace.
+ tasks map[ThreadID]*Task
+
+ // tids is a mapping from tasks visible in this namespace to their
+ // identifiers in this namespace.
+ tids map[*Task]ThreadID
+
+ // tgids is a mapping from thread groups visible in this namespace to
+ // their identifiers in this namespace.
+ //
+ // The content of tgids is equivalent to tids[tg.leader]. This exists
+ // primarily as an optimization to quickly find all thread groups.
+ tgids map[*ThreadGroup]ThreadID
+
+ // sessions is a mapping from SessionIDs in this namespace to sessions
+ // visible in the namespace.
+ sessions map[SessionID]*Session
+
+ // sids is a mapping from sessions visible in this namespace to their
+ // identifiers in this namespace.
+ sids map[*Session]SessionID
+
+ // processGroups is a mapping from ProcessGroupIDs in this namespace to
+ // process groups visible in the namespace.
+ processGroups map[ProcessGroupID]*ProcessGroup
+
+ // pgids is a mapping from process groups visible in this namespace to
+ // their identifiers in this namespace.
+ pgids map[*ProcessGroup]ProcessGroupID
+
+ // exiting indicates that the namespace's init process is exiting or has
+ // exited.
+ exiting bool
+}
+
+func newPIDNamespace(ts *TaskSet, parent *PIDNamespace, userns *auth.UserNamespace) *PIDNamespace {
+ return &PIDNamespace{
+ owner: ts,
+ parent: parent,
+ userns: userns,
+ tasks: make(map[ThreadID]*Task),
+ tids: make(map[*Task]ThreadID),
+ tgids: make(map[*ThreadGroup]ThreadID),
+ sessions: make(map[SessionID]*Session),
+ sids: make(map[*Session]SessionID),
+ processGroups: make(map[ProcessGroupID]*ProcessGroup),
+ pgids: make(map[*ProcessGroup]ProcessGroupID),
+ }
+}
+
+// NewChild returns a new, empty PID namespace that is a child of ns. Authority
+// over the new PID namespace is controlled by userns.
+func (ns *PIDNamespace) NewChild(userns *auth.UserNamespace) *PIDNamespace {
+ return newPIDNamespace(ns.owner, ns, userns)
+}
+
+// TaskWithID returns the task with thread ID tid in PID namespace ns. If no
+// task has that TID, TaskWithID returns nil.
+func (ns *PIDNamespace) TaskWithID(tid ThreadID) *Task {
+ ns.owner.mu.RLock()
+ t := ns.tasks[tid]
+ ns.owner.mu.RUnlock()
+ return t
+}
+
+// ThreadGroupWithID returns the thread group lead by the task with thread ID
+// tid in PID namespace ns. If no task has that TID, or if the task with that
+// TID is not a thread group leader, ThreadGroupWithID returns nil.
+func (ns *PIDNamespace) ThreadGroupWithID(tid ThreadID) *ThreadGroup {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ t := ns.tasks[tid]
+ if t == nil {
+ return nil
+ }
+ if t != t.tg.leader {
+ return nil
+ }
+ return t.tg
+}
+
+// IDOfTask returns the TID assigned to the given task in PID namespace ns. If
+// the task is not visible in that namespace, IDOfTask returns 0. (This return
+// value is significant in some cases, e.g. getppid() is documented as
+// returning 0 if the caller's parent is in an ancestor namespace and
+// consequently not visible to the caller.) If the task is nil, IDOfTask returns
+// 0.
+func (ns *PIDNamespace) IDOfTask(t *Task) ThreadID {
+ ns.owner.mu.RLock()
+ id := ns.tids[t]
+ ns.owner.mu.RUnlock()
+ return id
+}
+
+// IDOfThreadGroup returns the TID assigned to tg's leader in PID namespace ns.
+// If the task is not visible in that namespace, IDOfThreadGroup returns 0.
+func (ns *PIDNamespace) IDOfThreadGroup(tg *ThreadGroup) ThreadID {
+ ns.owner.mu.RLock()
+ id := ns.tgids[tg]
+ ns.owner.mu.RUnlock()
+ return id
+}
+
+// Tasks returns a snapshot of the tasks in ns.
+func (ns *PIDNamespace) Tasks() []*Task {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ tasks := make([]*Task, 0, len(ns.tasks))
+ for t := range ns.tids {
+ tasks = append(tasks, t)
+ }
+ return tasks
+}
+
+// ThreadGroups returns a snapshot of the thread groups in ns.
+func (ns *PIDNamespace) ThreadGroups() []*ThreadGroup {
+ return ns.ThreadGroupsAppend(nil)
+}
+
+// ThreadGroupsAppend appends a snapshot of the thread groups in ns to tgs.
+func (ns *PIDNamespace) ThreadGroupsAppend(tgs []*ThreadGroup) []*ThreadGroup {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ for tg := range ns.tgids {
+ tgs = append(tgs, tg)
+ }
+ return tgs
+}
+
+// UserNamespace returns the user namespace associated with PID namespace ns.
+func (ns *PIDNamespace) UserNamespace() *auth.UserNamespace {
+ return ns.userns
+}
+
+// A threadGroupNode defines the relationship between a thread group and the
+// rest of the system. Conceptually, threadGroupNode is data belonging to the
+// owning TaskSet, as if TaskSet contained a field `nodes
+// map[*ThreadGroup]*threadGroupNode`. However, for practical reasons,
+// threadGroupNode is embedded in the ThreadGroup it represents.
+// (threadGroupNode is an anonymous field in ThreadGroup; this is to expose
+// threadGroupEntry's methods on ThreadGroup to make it implement
+// threadGroupLinker.)
+//
+// +stateify savable
+type threadGroupNode struct {
+ // pidns is the PID namespace containing the thread group and all of its
+ // member tasks. The pidns pointer is immutable.
+ pidns *PIDNamespace
+
+ // eventQueue is notified whenever a event of interest to Task.Wait occurs
+ // in a child of this thread group, or a ptrace tracee of a task in this
+ // thread group. Events are defined in task_exit.go.
+ //
+ // Note that we cannot check and save this wait queue similarly to other
+ // wait queues, as the queue will not be empty by the time of saving, due
+ // to the wait sourced from Exec().
+ eventQueue waiter.Queue `state:"nosave"`
+
+ // leader is the thread group's leader, which is the oldest task in the
+ // thread group; usually the last task in the thread group to call
+ // execve(), or if no such task exists then the first task in the thread
+ // group, which was created by a call to fork() or clone() without
+ // CLONE_THREAD. Once a thread group has been made visible to the rest of
+ // the system by TaskSet.newTask, leader is never nil.
+ //
+ // Note that it's possible for the leader to exit without causing the rest
+ // of the thread group to exit; in such a case, leader will still be valid
+ // and non-nil, but leader will not be in tasks.
+ //
+ // leader is protected by the TaskSet mutex.
+ leader *Task
+
+ // If execing is not nil, it is a task in the thread group that has killed
+ // all other tasks so that it can become the thread group leader and
+ // perform an execve. (execing may already be the thread group leader.)
+ //
+ // execing is analogous to Linux's signal_struct::group_exit_task.
+ //
+ // execing is protected by the TaskSet mutex.
+ execing *Task
+
+ // tasks is all tasks in the thread group that have not yet been reaped.
+ //
+ // tasks is protected by both the TaskSet mutex and the signal mutex:
+ // Mutating tasks requires locking the TaskSet mutex for writing *and*
+ // locking the signal mutex. Reading tasks requires locking the TaskSet
+ // mutex *or* locking the signal mutex.
+ tasks taskList
+
+ // tasksCount is the number of tasks in the thread group that have not yet
+ // been reaped; equivalently, tasksCount is the number of tasks in tasks.
+ //
+ // tasksCount is protected by both the TaskSet mutex and the signal mutex,
+ // as with tasks.
+ tasksCount int
+
+ // liveTasks is the number of tasks in the thread group that have not yet
+ // reached TaskExitZombie.
+ //
+ // liveTasks is protected by the TaskSet mutex (NOT the signal mutex).
+ liveTasks int
+
+ // activeTasks is the number of tasks in the thread group that have not yet
+ // reached TaskExitInitiated.
+ //
+ // activeTasks is protected by both the TaskSet mutex and the signal mutex,
+ // as with tasks.
+ activeTasks int
+}
+
+// PIDNamespace returns the PID namespace containing tg.
+func (tg *ThreadGroup) PIDNamespace() *PIDNamespace {
+ return tg.pidns
+}
+
+// TaskSet returns the TaskSet containing tg.
+func (tg *ThreadGroup) TaskSet() *TaskSet {
+ return tg.pidns.owner
+}
+
+// Leader returns tg's leader.
+func (tg *ThreadGroup) Leader() *Task {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ return tg.leader
+}
+
+// Count returns the number of non-exited threads in the group.
+func (tg *ThreadGroup) Count() int {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ var count int
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ count++
+ }
+ return count
+}
+
+// MemberIDs returns a snapshot of the ThreadIDs (in PID namespace pidns) for
+// all tasks in tg.
+func (tg *ThreadGroup) MemberIDs(pidns *PIDNamespace) []ThreadID {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+
+ var tasks []ThreadID
+ for t := tg.tasks.Front(); t != nil; t = t.Next() {
+ if id, ok := pidns.tids[t]; ok {
+ tasks = append(tasks, id)
+ }
+ }
+ return tasks
+}
+
+// ID returns tg's leader's thread ID in its own PID namespace. If tg's leader
+// is dead, ID returns 0.
+func (tg *ThreadGroup) ID() ThreadID {
+ tg.pidns.owner.mu.RLock()
+ id := tg.pidns.tgids[tg]
+ tg.pidns.owner.mu.RUnlock()
+ return id
+}
+
+// A taskNode defines the relationship between a task and the rest of the
+// system. The comments on threadGroupNode also apply to taskNode.
+//
+// +stateify savable
+type taskNode struct {
+ // tg is the thread group that this task belongs to. The tg pointer is
+ // immutable.
+ tg *ThreadGroup `state:"wait"`
+
+ // taskEntry links into tg.tasks. Note that this means that
+ // Task.Next/Prev/SetNext/SetPrev refer to sibling tasks in the same thread
+ // group. See threadGroupNode.tasks for synchronization info.
+ taskEntry
+
+ // parent is the task's parent. parent may be nil.
+ //
+ // parent is protected by the TaskSet mutex.
+ parent *Task
+
+ // children is this task's children.
+ //
+ // children is protected by the TaskSet mutex.
+ children map[*Task]struct{}
+
+ // If childPIDNamespace is not nil, all new tasks created by this task will
+ // be members of childPIDNamespace rather than this one. (As a corollary,
+ // this task becomes unable to create sibling tasks in the same thread
+ // group.)
+ //
+ // childPIDNamespace is exclusive to the task goroutine.
+ childPIDNamespace *PIDNamespace
+}
+
+// ThreadGroup returns the thread group containing t.
+func (t *Task) ThreadGroup() *ThreadGroup {
+ return t.tg
+}
+
+// PIDNamespace returns the PID namespace containing t.
+func (t *Task) PIDNamespace() *PIDNamespace {
+ return t.tg.pidns
+}
+
+// TaskSet returns the TaskSet containing t.
+func (t *Task) TaskSet() *TaskSet {
+ return t.tg.pidns.owner
+}
+
+// Timekeeper returns the system Timekeeper.
+func (t *Task) Timekeeper() *Timekeeper {
+ return t.k.timekeeper
+}
+
+// Parent returns t's parent.
+func (t *Task) Parent() *Task {
+ t.tg.pidns.owner.mu.RLock()
+ defer t.tg.pidns.owner.mu.RUnlock()
+ return t.parent
+}
+
+// ThreadID returns t's thread ID in its own PID namespace. If the task is
+// dead, ThreadID returns 0.
+func (t *Task) ThreadID() ThreadID {
+ return t.tg.pidns.IDOfTask(t)
+}
diff --git a/pkg/sentry/kernel/time/context.go b/pkg/sentry/kernel/time/context.go
new file mode 100644
index 000000000..c0660d362
--- /dev/null
+++ b/pkg/sentry/kernel/time/context.go
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// contextID is the time package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxRealtimeClock is a Context.Value key for the current real time.
+ CtxRealtimeClock contextID = iota
+)
+
+// RealtimeClockFromContext returns the real time clock associated with context
+// ctx.
+func RealtimeClockFromContext(ctx context.Context) Clock {
+ if v := ctx.Value(CtxRealtimeClock); v != nil {
+ return v.(Clock)
+ }
+ return nil
+}
+
+// NowFromContext returns the current real time associated with context ctx.
+func NowFromContext(ctx context.Context) Time {
+ if clk := RealtimeClockFromContext(ctx); clk != nil {
+ return clk.Now()
+ }
+ panic("encountered context without RealtimeClock")
+}
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
new file mode 100644
index 000000000..3846cf1ea
--- /dev/null
+++ b/pkg/sentry/kernel/time/time.go
@@ -0,0 +1,691 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package time defines the Timer type, which provides a periodic timer that
+// works by sampling a user-provided clock.
+package time
+
+import (
+ "fmt"
+ "math"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Events that may be generated by a Clock.
+const (
+ // ClockEventSet occurs when a Clock undergoes a discontinuous change.
+ ClockEventSet waiter.EventMask = 1 << iota
+
+ // ClockEventRateIncrease occurs when the rate at which a Clock advances
+ // increases significantly, such that values returned by previous calls to
+ // Clock.WallTimeUntil may be too large.
+ ClockEventRateIncrease
+)
+
+// Time represents an instant in time with nanosecond precision.
+//
+// Time may represent time with respect to any clock and may not have any
+// meaning in the real world.
+//
+// +stateify savable
+type Time struct {
+ ns int64
+}
+
+var (
+ // MinTime is the zero time instant, the lowest possible time that can
+ // be represented by Time.
+ MinTime = Time{ns: math.MinInt64}
+
+ // MaxTime is the highest possible time that can be represented by
+ // Time.
+ MaxTime = Time{ns: math.MaxInt64}
+
+ // ZeroTime represents the zero time in an unspecified Clock's domain.
+ ZeroTime = Time{ns: 0}
+)
+
+const (
+ // MinDuration is the minimum duration representable by time.Duration.
+ MinDuration = time.Duration(math.MinInt64)
+
+ // MaxDuration is the maximum duration representable by time.Duration.
+ MaxDuration = time.Duration(math.MaxInt64)
+)
+
+// FromNanoseconds returns a Time representing the point ns nanoseconds after
+// an unspecified Clock's zero time.
+func FromNanoseconds(ns int64) Time {
+ return Time{ns}
+}
+
+// FromSeconds returns a Time representing the point s seconds after an
+// unspecified Clock's zero time.
+func FromSeconds(s int64) Time {
+ if s > math.MaxInt64/time.Second.Nanoseconds() {
+ return MaxTime
+ }
+ return Time{s * 1e9}
+}
+
+// FromUnix converts from Unix seconds and nanoseconds to Time, assuming a real
+// time Unix clock domain.
+func FromUnix(s int64, ns int64) Time {
+ if s > math.MaxInt64/time.Second.Nanoseconds() {
+ return MaxTime
+ }
+ t := s * 1e9
+ if t > math.MaxInt64-ns {
+ return MaxTime
+ }
+ return Time{t + ns}
+}
+
+// FromTimespec converts from Linux Timespec to Time.
+func FromTimespec(ts linux.Timespec) Time {
+ return Time{ts.ToNsecCapped()}
+}
+
+// FromTimeval converts a Linux Timeval to Time.
+func FromTimeval(tv linux.Timeval) Time {
+ return Time{tv.ToNsecCapped()}
+}
+
+// Nanoseconds returns nanoseconds elapsed since the zero time in t's Clock
+// domain. If t represents walltime, this is nanoseconds since the Unix epoch.
+func (t Time) Nanoseconds() int64 {
+ return t.ns
+}
+
+// Seconds returns seconds elapsed since the zero time in t's Clock domain. If
+// t represents walltime, this is seconds since Unix epoch.
+func (t Time) Seconds() int64 {
+ return t.Nanoseconds() / time.Second.Nanoseconds()
+}
+
+// Timespec converts Time to a Linux timespec.
+func (t Time) Timespec() linux.Timespec {
+ return linux.NsecToTimespec(t.Nanoseconds())
+}
+
+// Unix returns the (seconds, nanoseconds) representation of t such that
+// seconds*1e9 + nanoseconds = t.
+func (t Time) Unix() (s int64, ns int64) {
+ s = t.ns / 1e9
+ ns = t.ns % 1e9
+ return
+}
+
+// TimeT converts Time to a Linux time_t.
+func (t Time) TimeT() linux.TimeT {
+ return linux.NsecToTimeT(t.Nanoseconds())
+}
+
+// Timeval converts Time to a Linux timeval.
+func (t Time) Timeval() linux.Timeval {
+ return linux.NsecToTimeval(t.Nanoseconds())
+}
+
+// Add adds the duration of d to t.
+func (t Time) Add(d time.Duration) Time {
+ if t.ns > 0 && d.Nanoseconds() > math.MaxInt64-int64(t.ns) {
+ return MaxTime
+ }
+ if t.ns < 0 && d.Nanoseconds() < math.MinInt64-int64(t.ns) {
+ return MinTime
+ }
+ return Time{int64(t.ns) + d.Nanoseconds()}
+}
+
+// AddTime adds the duration of u to t.
+func (t Time) AddTime(u Time) Time {
+ return t.Add(time.Duration(u.ns))
+}
+
+// Equal reports whether the two times represent the same instant in time.
+func (t Time) Equal(u Time) bool {
+ return t.ns == u.ns
+}
+
+// Before reports whether the instant t is before the instant u.
+func (t Time) Before(u Time) bool {
+ return t.ns < u.ns
+}
+
+// After reports whether the instant t is after the instant u.
+func (t Time) After(u Time) bool {
+ return t.ns > u.ns
+}
+
+// Sub returns the duration of t - u.
+//
+// N.B. This measure may not make sense for every Time returned by ktime.Clock.
+// Callers who need wall time duration can use ktime.Clock.WallTimeUntil to
+// estimate that wall time.
+func (t Time) Sub(u Time) time.Duration {
+ dur := time.Duration(int64(t.ns)-int64(u.ns)) * time.Nanosecond
+ switch {
+ case u.Add(dur).Equal(t):
+ return dur
+ case t.Before(u):
+ return MinDuration
+ default:
+ return MaxDuration
+ }
+}
+
+// IsMin returns whether t represents the lowest possible time instant.
+func (t Time) IsMin() bool {
+ return t == MinTime
+}
+
+// IsZero returns whether t represents the zero time instant in t's Clock domain.
+func (t Time) IsZero() bool {
+ return t == ZeroTime
+}
+
+// String returns the time represented in nanoseconds as a string.
+func (t Time) String() string {
+ return fmt.Sprintf("%dns", t.Nanoseconds())
+}
+
+// A Clock is an abstract time source.
+type Clock interface {
+ // Now returns the current time in nanoseconds according to the Clock.
+ Now() Time
+
+ // WallTimeUntil returns the estimated wall time until Now will return a
+ // value greater than or equal to t, given that a recent call to Now
+ // returned now. If t has already passed, WallTimeUntil may return 0 or a
+ // negative value.
+ //
+ // WallTimeUntil must be abstract to support Clocks that do not represent
+ // wall time (e.g. thread group execution timers). Clocks that represent
+ // wall times may embed the WallRateClock type to obtain an appropriate
+ // trivial implementation of WallTimeUntil.
+ //
+ // WallTimeUntil is used to determine when associated Timers should next
+ // check for expirations. Returning too small a value may result in
+ // spurious Timer goroutine wakeups, while returning too large a value may
+ // result in late expirations. Implementations should usually err on the
+ // side of underestimating.
+ WallTimeUntil(t, now Time) time.Duration
+
+ // Waitable methods may be used to subscribe to Clock events. Waiters will
+ // not be preserved by Save and must be re-established during restore.
+ //
+ // Since Clock events are transient, implementations of
+ // waiter.Waitable.Readiness should return 0.
+ waiter.Waitable
+}
+
+// WallRateClock implements Clock.WallTimeUntil for Clocks that elapse at the
+// same rate as wall time.
+type WallRateClock struct{}
+
+// WallTimeUntil implements Clock.WallTimeUntil.
+func (WallRateClock) WallTimeUntil(t, now Time) time.Duration {
+ return t.Sub(now)
+}
+
+// NoClockEvents implements waiter.Waitable for Clocks that do not generate
+// events.
+type NoClockEvents struct{}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return 0
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (NoClockEvents) EventUnregister(e *waiter.Entry) {
+}
+
+// ClockEventsQueue implements waiter.Waitable by wrapping waiter.Queue and
+// defining waiter.Waitable.Readiness as required by Clock.
+type ClockEventsQueue struct {
+ waiter.Queue
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return 0
+}
+
+// A TimerListener receives expirations from a Timer.
+type TimerListener interface {
+ // Notify is called when its associated Timer expires. exp is the number of
+ // expirations.
+ //
+ // Notify is called with the associated Timer's mutex locked, so Notify
+ // must not take any locks that precede Timer.mu in lock order.
+ //
+ // Preconditions: exp > 0.
+ Notify(exp uint64)
+
+ // Destroy is called when the timer is destroyed.
+ Destroy()
+}
+
+// Setting contains user-controlled mutable Timer properties.
+//
+// +stateify savable
+type Setting struct {
+ // Enabled is true if the timer is running.
+ Enabled bool
+
+ // Next is the time in nanoseconds of the next expiration.
+ Next Time
+
+ // Period is the time in nanoseconds between expirations. If Period is
+ // zero, the timer will not automatically restart after expiring.
+ //
+ // Invariant: Period >= 0.
+ Period time.Duration
+}
+
+// SettingFromSpec converts a (value, interval) pair to a Setting based on a
+// reading from c. value is interpreted as a time relative to c.Now().
+func SettingFromSpec(value time.Duration, interval time.Duration, c Clock) (Setting, error) {
+ return SettingFromSpecAt(value, interval, c.Now())
+}
+
+// SettingFromSpecAt converts a (value, interval) pair to a Setting. value is
+// interpreted as a time relative to now.
+func SettingFromSpecAt(value time.Duration, interval time.Duration, now Time) (Setting, error) {
+ if value < 0 {
+ return Setting{}, syserror.EINVAL
+ }
+ if value == 0 {
+ return Setting{Period: interval}, nil
+ }
+ return Setting{
+ Enabled: true,
+ Next: now.Add(value),
+ Period: interval,
+ }, nil
+}
+
+// SettingFromAbsSpec converts a (value, interval) pair to a Setting. value is
+// interpreted as an absolute time.
+func SettingFromAbsSpec(value Time, interval time.Duration) (Setting, error) {
+ if value.Before(ZeroTime) {
+ return Setting{}, syserror.EINVAL
+ }
+ if value.IsZero() {
+ return Setting{Period: interval}, nil
+ }
+ return Setting{
+ Enabled: true,
+ Next: value,
+ Period: interval,
+ }, nil
+}
+
+// SettingFromItimerspec converts a linux.Itimerspec to a Setting. If abs is
+// true, its.Value is interpreted as an absolute time. Otherwise, it is
+// interpreted as a time relative to c.Now().
+func SettingFromItimerspec(its linux.Itimerspec, abs bool, c Clock) (Setting, error) {
+ if abs {
+ return SettingFromAbsSpec(FromTimespec(its.Value), its.Interval.ToDuration())
+ }
+ return SettingFromSpec(its.Value.ToDuration(), its.Interval.ToDuration(), c)
+}
+
+// SpecFromSetting converts a timestamp and a Setting to a (relative value,
+// interval) pair, as used by most Linux syscalls that return a struct
+// itimerval or struct itimerspec.
+func SpecFromSetting(now Time, s Setting) (value, period time.Duration) {
+ if !s.Enabled {
+ return 0, s.Period
+ }
+ return s.Next.Sub(now), s.Period
+}
+
+// ItimerspecFromSetting converts a Setting to a linux.Itimerspec.
+func ItimerspecFromSetting(now Time, s Setting) linux.Itimerspec {
+ val, iv := SpecFromSetting(now, s)
+ return linux.Itimerspec{
+ Interval: linux.DurationToTimespec(iv),
+ Value: linux.DurationToTimespec(val),
+ }
+}
+
+// At returns an updated Setting and a number of expirations after the
+// associated Clock indicates a time of now.
+//
+// Settings may be created by successive calls to At with decreasing
+// values of now (i.e. time may appear to go backward). Supporting this is
+// required to support non-monotonic clocks, as well as allowing
+// Timer.clock.Now() to be called without holding Timer.mu.
+func (s Setting) At(now Time) (Setting, uint64) {
+ if !s.Enabled {
+ return s, 0
+ }
+ if s.Next.After(now) {
+ return s, 0
+ }
+ if s.Period == 0 {
+ s.Enabled = false
+ return s, 1
+ }
+ exp := 1 + uint64(now.Sub(s.Next).Nanoseconds())/uint64(s.Period)
+ s.Next = s.Next.Add(time.Duration(uint64(s.Period) * exp))
+ return s, exp
+}
+
+// Timer is an optionally-periodic timer driven by sampling a user-specified
+// Clock. Timer's semantics support the requirements of Linux's interval timers
+// (setitimer(2), timer_create(2), timerfd_create(2)).
+//
+// Timers should be created using NewTimer and must be cleaned up by calling
+// Timer.Destroy when no longer used.
+//
+// +stateify savable
+type Timer struct {
+ // clock is the time source. clock is immutable.
+ clock Clock
+
+ // listener is notified of expirations. listener is immutable.
+ listener TimerListener
+
+ // mu protects the following mutable fields.
+ mu sync.Mutex `state:"nosave"`
+
+ // setting is the timer setting. setting is protected by mu.
+ setting Setting
+
+ // paused is true if the Timer is paused. paused is protected by mu.
+ paused bool
+
+ // kicker is used to wake the Timer goroutine. The kicker pointer is
+ // immutable, but its state is protected by mu.
+ kicker *time.Timer `state:"nosave"`
+
+ // entry is registered with clock.EventRegister. entry is immutable.
+ //
+ // Per comment in Clock, entry must be re-registered after restore; per
+ // comment in Timer.Load, this is done in Timer.Resume.
+ entry waiter.Entry `state:"nosave"`
+
+ // events is the channel that will be notified whenever entry receives an
+ // event. It is also closed by Timer.Destroy to instruct the Timer
+ // goroutine to exit.
+ events chan struct{} `state:"nosave"`
+}
+
+// timerTickEvents are Clock events that require the Timer goroutine to Tick
+// prematurely.
+const timerTickEvents = ClockEventSet | ClockEventRateIncrease
+
+// NewTimer returns a new Timer that will obtain time from clock and send
+// expirations to listener. The Timer is initially stopped and has no first
+// expiration or period configured.
+func NewTimer(clock Clock, listener TimerListener) *Timer {
+ t := &Timer{
+ clock: clock,
+ listener: listener,
+ }
+ t.init()
+ return t
+}
+
+// After waits for the duration to elapse according to clock and then sends a
+// notification on the returned channel. The timer is started immediately and
+// will fire exactly once. The second return value is the start time used with
+// the duration.
+//
+// Callers must call Timer.Destroy.
+func After(clock Clock, duration time.Duration) (*Timer, Time, <-chan struct{}) {
+ notifier, tchan := NewChannelNotifier()
+ t := NewTimer(clock, notifier)
+ now := clock.Now()
+
+ t.Swap(Setting{
+ Enabled: true,
+ Period: 0,
+ Next: now.Add(duration),
+ })
+ return t, now, tchan
+}
+
+// init initializes Timer state that is not preserved across save/restore. If
+// init has already been called, calling it again is a no-op.
+//
+// Preconditions: t.mu must be locked, or the caller must have exclusive access
+// to t.
+func (t *Timer) init() {
+ if t.kicker != nil {
+ return
+ }
+ // If t.kicker is nil, the Timer goroutine can't be running, so we can't
+ // race with it.
+ t.kicker = time.NewTimer(0)
+ t.entry, t.events = waiter.NewChannelEntry(nil)
+ t.clock.EventRegister(&t.entry, timerTickEvents)
+ go t.runGoroutine() // S/R-SAFE: synchronized by t.mu
+}
+
+// Destroy releases resources owned by the Timer. A Destroyed Timer must not be
+// used again; in particular, a Destroyed Timer should not be Saved.
+func (t *Timer) Destroy() {
+ // Stop the Timer, ensuring that the Timer goroutine will not call
+ // t.kicker.Reset, before calling t.kicker.Stop.
+ t.mu.Lock()
+ t.setting.Enabled = false
+ t.mu.Unlock()
+ t.kicker.Stop()
+ // Unregister t.entry, ensuring that the Clock will not send to t.events,
+ // before closing t.events to instruct the Timer goroutine to exit.
+ t.clock.EventUnregister(&t.entry)
+ close(t.events)
+ t.listener.Destroy()
+}
+
+func (t *Timer) runGoroutine() {
+ for {
+ select {
+ case <-t.kicker.C:
+ case _, ok := <-t.events:
+ if !ok {
+ // Channel closed by Destroy.
+ return
+ }
+ }
+ t.Tick()
+ }
+}
+
+// Tick requests that the Timer immediately check for expirations and
+// re-evaluate when it should next check for expirations.
+func (t *Timer) Tick() {
+ now := t.clock.Now()
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.paused {
+ return
+ }
+ s, exp := t.setting.At(now)
+ t.setting = s
+ if exp > 0 {
+ t.listener.Notify(exp)
+ }
+ t.resetKickerLocked(now)
+}
+
+// Pause pauses the Timer, ensuring that it does not generate any further
+// expirations until Resume is called. If the Timer is already paused, Pause
+// has no effect.
+func (t *Timer) Pause() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.paused = true
+ // t.kicker may be nil if we were restored but never resumed.
+ if t.kicker != nil {
+ t.kicker.Stop()
+ }
+}
+
+// Resume ends the effect of Pause. If the Timer is not paused, Resume has no
+// effect.
+func (t *Timer) Resume() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if !t.paused {
+ return
+ }
+ t.paused = false
+
+ // Lazily initialize the Timer. We can't call Timer.init until Timer.Resume
+ // because save/restore will restore Timers before
+ // kernel.Timekeeper.SetClocks() has been called, so if t.clock is backed
+ // by a kernel.Timekeeper then the Timer goroutine will panic if it calls
+ // t.clock.Now().
+ t.init()
+
+ // Kick the Timer goroutine in case it was already initialized, but the
+ // Timer goroutine was sleeping.
+ t.kicker.Reset(0)
+}
+
+// Get returns a snapshot of the Timer's current Setting and the time
+// (according to the Timer's Clock) at which the snapshot was taken.
+//
+// Preconditions: The Timer must not be paused (since its Setting cannot
+// be advanced to the current time while it is paused.)
+func (t *Timer) Get() (Time, Setting) {
+ now := t.clock.Now()
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.paused {
+ panic(fmt.Sprintf("Timer.Get called on paused Timer %p", t))
+ }
+ s, exp := t.setting.At(now)
+ t.setting = s
+ if exp > 0 {
+ t.listener.Notify(exp)
+ }
+ t.resetKickerLocked(now)
+ return now, s
+}
+
+// Swap atomically changes the Timer's Setting and returns the Timer's previous
+// Setting and the time (according to the Timer's Clock) at which the snapshot
+// was taken. Setting s.Enabled to true starts the Timer, while setting
+// s.Enabled to false stops it.
+//
+// Preconditions: The Timer must not be paused.
+func (t *Timer) Swap(s Setting) (Time, Setting) {
+ return t.SwapAnd(s, nil)
+}
+
+// SwapAnd atomically changes the Timer's Setting, calls f if it is not nil,
+// and returns the Timer's previous Setting and the time (according to the
+// Timer's Clock) at which the Setting was changed. Setting s.Enabled to true
+// starts the timer, while setting s.Enabled to false stops it.
+//
+// Preconditions: The Timer must not be paused. f cannot call any Timer methods
+// since it is called with the Timer mutex locked.
+func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
+ now := t.clock.Now()
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.paused {
+ panic(fmt.Sprintf("Timer.SwapAnd called on paused Timer %p", t))
+ }
+ oldS, oldExp := t.setting.At(now)
+ if oldExp > 0 {
+ t.listener.Notify(oldExp)
+ }
+ if f != nil {
+ f()
+ }
+ newS, newExp := s.At(now)
+ t.setting = newS
+ if newExp > 0 {
+ t.listener.Notify(newExp)
+ }
+ t.resetKickerLocked(now)
+ return now, oldS
+}
+
+// Atomically invokes f atomically with respect to expirations of t; that is, t
+// cannot generate expirations while f is being called.
+//
+// Preconditions: f cannot call any Timer methods since it is called with the
+// Timer mutex locked.
+func (t *Timer) Atomically(f func()) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ f()
+}
+
+// Preconditions: t.mu must be locked.
+func (t *Timer) resetKickerLocked(now Time) {
+ if t.setting.Enabled {
+ // Clock.WallTimeUntil may return a negative value. This is fine;
+ // time.when treats negative Durations as 0.
+ t.kicker.Reset(t.clock.WallTimeUntil(t.setting.Next, now))
+ }
+ // We don't call t.kicker.Stop if !t.setting.Enabled because in most cases
+ // resetKickerLocked will be called from the Timer goroutine itself, in
+ // which case t.kicker has already fired and t.kicker.Stop will be an
+ // expensive no-op (time.Timer.Stop => time.stopTimer => runtime.stopTimer
+ // => runtime.deltimer).
+}
+
+// Clock returns the Clock used by t.
+func (t *Timer) Clock() Clock {
+ return t.clock
+}
+
+// ChannelNotifier is a TimerListener that sends a message on an empty struct
+// channel.
+//
+// ChannelNotifier cannot be saved or loaded.
+type ChannelNotifier struct {
+ // tchan must be a buffered channel.
+ tchan chan struct{}
+}
+
+// NewChannelNotifier creates a new channel notifier.
+//
+// If the notifier is used with a timer, Timer.Destroy will close the channel
+// returned here.
+func NewChannelNotifier() (TimerListener, <-chan struct{}) {
+ tchan := make(chan struct{}, 1)
+ return &ChannelNotifier{tchan}, tchan
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (c *ChannelNotifier) Notify(uint64) {
+ select {
+ case c.tchan <- struct{}{}:
+ default:
+ }
+}
+
+// Destroy implements ktime.TimerListener.Destroy and will close the channel.
+func (c *ChannelNotifier) Destroy() {
+ close(c.tchan)
+}
diff --git a/pkg/sentry/kernel/time/time_state_autogen.go b/pkg/sentry/kernel/time/time_state_autogen.go
new file mode 100755
index 000000000..1750b55d6
--- /dev/null
+++ b/pkg/sentry/kernel/time/time_state_autogen.go
@@ -0,0 +1,56 @@
+// automatically generated by stateify.
+
+package time
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Time) beforeSave() {}
+func (x *Time) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ns", &x.ns)
+}
+
+func (x *Time) afterLoad() {}
+func (x *Time) load(m state.Map) {
+ m.Load("ns", &x.ns)
+}
+
+func (x *Setting) beforeSave() {}
+func (x *Setting) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Enabled", &x.Enabled)
+ m.Save("Next", &x.Next)
+ m.Save("Period", &x.Period)
+}
+
+func (x *Setting) afterLoad() {}
+func (x *Setting) load(m state.Map) {
+ m.Load("Enabled", &x.Enabled)
+ m.Load("Next", &x.Next)
+ m.Load("Period", &x.Period)
+}
+
+func (x *Timer) beforeSave() {}
+func (x *Timer) save(m state.Map) {
+ x.beforeSave()
+ m.Save("clock", &x.clock)
+ m.Save("listener", &x.listener)
+ m.Save("setting", &x.setting)
+ m.Save("paused", &x.paused)
+}
+
+func (x *Timer) afterLoad() {}
+func (x *Timer) load(m state.Map) {
+ m.Load("clock", &x.clock)
+ m.Load("listener", &x.listener)
+ m.Load("setting", &x.setting)
+ m.Load("paused", &x.paused)
+}
+
+func init() {
+ state.Register("time.Time", (*Time)(nil), state.Fns{Save: (*Time).save, Load: (*Time).load})
+ state.Register("time.Setting", (*Setting)(nil), state.Fns{Save: (*Setting).save, Load: (*Setting).load})
+ state.Register("time.Timer", (*Timer)(nil), state.Fns{Save: (*Timer).save, Load: (*Timer).load})
+}
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
new file mode 100644
index 000000000..505a4fa4f
--- /dev/null
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ sentrytime "gvisor.googlesource.com/gvisor/pkg/sentry/time"
+)
+
+// Timekeeper manages all of the kernel clocks.
+//
+// +stateify savable
+type Timekeeper struct {
+ // clocks are the clock sources.
+ //
+ // These are not saved directly, as the new machine's clock may behave
+ // differently.
+ //
+ // It is set only once, by SetClocks.
+ clocks sentrytime.Clocks `state:"nosave"`
+
+ // bootTime is the realtime when the system "booted". i.e., when
+ // SetClocks was called in the initial (not restored) run.
+ bootTime ktime.Time
+
+ // monotonicOffset is the offset to apply to the monotonic clock output
+ // from clocks.
+ //
+ // It is set only once, by SetClocks.
+ monotonicOffset int64 `state:"nosave"`
+
+ // restored, if non-nil, indicates that this Timekeeper was restored
+ // from a state file. The clocks are not set until restored is closed.
+ restored chan struct{} `state:"nosave"`
+
+ // saveMonotonic is the (offset) value of the monotonic clock at the
+ // time of save.
+ //
+ // It is only valid if restored is non-nil.
+ //
+ // It is only used in SetClocks after restore to compute the new
+ // monotonicOffset.
+ saveMonotonic int64
+
+ // saveRealtime is the value of the realtime clock at the time of save.
+ //
+ // It is only valid if restored is non-nil.
+ //
+ // It is only used in SetClocks after restore to compute the new
+ // monotonicOffset.
+ saveRealtime int64
+
+ // params manages the parameter page.
+ params *VDSOParamPage
+
+ // mu protects destruction with stop and wg.
+ mu sync.Mutex `state:"nosave"`
+
+ // stop is used to tell the update goroutine to exit.
+ stop chan struct{} `state:"nosave"`
+
+ // wg is used to indicate that the update goroutine has exited.
+ wg sync.WaitGroup `state:"nosave"`
+}
+
+// NewTimekeeper returns a Timekeeper that is automatically kept up-to-date.
+// NewTimekeeper does not take ownership of paramPage.
+//
+// SetClocks must be called on the returned Timekeeper before it is usable.
+func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) {
+ return &Timekeeper{
+ params: NewVDSOParamPage(mfp, paramPage),
+ }, nil
+}
+
+// SetClocks the backing clock source.
+//
+// SetClocks must be called before the Timekeeper is used, and it may not be
+// called more than once, as changing the clock source without extra correction
+// could cause time discontinuities.
+//
+// It must also be called after Load.
+func (t *Timekeeper) SetClocks(c sentrytime.Clocks) {
+ // Update the params, marking them "not ready", as we may need to
+ // restart calibration on this new machine.
+ if t.restored != nil {
+ if err := t.params.Write(func() vdsoParams {
+ return vdsoParams{}
+ }); err != nil {
+ panic("unable to reset VDSO params: " + err.Error())
+ }
+ }
+
+ if t.clocks != nil {
+ panic("SetClocks called on previously-initialized Timekeeper")
+ }
+
+ t.clocks = c
+
+ // Compute the offset of the monotonic clock from the base Clocks.
+ //
+ // In a fresh (not restored) sentry, monotonic time starts at zero.
+ //
+ // In a restored sentry, monotonic time jumps forward by approximately
+ // the same amount as real time. There are no guarantees here, we are
+ // just making a best-effort attempt to to make it appear that the app
+ // was simply not scheduled for a long period, rather than that the
+ // real time clock was changed.
+ //
+ // If real time went backwards, it remains the same.
+ wantMonotonic := int64(0)
+
+ nowMonotonic, err := t.clocks.GetTime(sentrytime.Monotonic)
+ if err != nil {
+ panic("Unable to get current monotonic time: " + err.Error())
+ }
+
+ nowRealtime, err := t.clocks.GetTime(sentrytime.Realtime)
+ if err != nil {
+ panic("Unable to get current realtime: " + err.Error())
+ }
+
+ if t.restored != nil {
+ wantMonotonic = t.saveMonotonic
+ elapsed := nowRealtime - t.saveRealtime
+ if elapsed > 0 {
+ wantMonotonic += elapsed
+ }
+ }
+
+ t.monotonicOffset = wantMonotonic - nowMonotonic
+
+ if t.restored == nil {
+ // Hold on to the initial "boot" time.
+ t.bootTime = ktime.FromNanoseconds(nowRealtime)
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.startUpdater()
+
+ if t.restored != nil {
+ close(t.restored)
+ }
+}
+
+// startUpdater starts an update goroutine that keeps the clocks updated.
+//
+// mu must be held.
+func (t *Timekeeper) startUpdater() {
+ if t.stop != nil {
+ // Timekeeper already started
+ return
+ }
+ t.stop = make(chan struct{})
+
+ // Keep the clocks up to date.
+ //
+ // Note that the Go runtime uses host CLOCK_MONOTONIC to service the
+ // timer, so it may run at a *slightly* different rate from the
+ // application CLOCK_MONOTONIC. That is fine, as we only need to update
+ // at approximately this rate.
+ timer := time.NewTicker(sentrytime.ApproxUpdateInterval)
+ t.wg.Add(1)
+ go func() { // S/R-SAFE: stopped during save.
+ for {
+ // Start with an update immediately, so the clocks are
+ // ready ASAP.
+
+ // Call Update within a Write block to prevent the VDSO
+ // from using the old params between Update and
+ // Write.
+ if err := t.params.Write(func() vdsoParams {
+ monotonicParams, monotonicOk, realtimeParams, realtimeOk := t.clocks.Update()
+
+ var p vdsoParams
+ if monotonicOk {
+ p.monotonicReady = 1
+ p.monotonicBaseCycles = int64(monotonicParams.BaseCycles)
+ p.monotonicBaseRef = int64(monotonicParams.BaseRef) + t.monotonicOffset
+ p.monotonicFrequency = monotonicParams.Frequency
+ }
+ if realtimeOk {
+ p.realtimeReady = 1
+ p.realtimeBaseCycles = int64(realtimeParams.BaseCycles)
+ p.realtimeBaseRef = int64(realtimeParams.BaseRef)
+ p.realtimeFrequency = realtimeParams.Frequency
+ }
+
+ log.Debugf("Updating VDSO parameters: %+v", p)
+
+ return p
+ }); err != nil {
+ log.Warningf("Unable to update VDSO parameter page: %v", err)
+ }
+
+ select {
+ case <-timer.C:
+ case <-t.stop:
+ t.wg.Done()
+ return
+ }
+ }
+ }()
+}
+
+// stopUpdater stops the update goroutine, blocking until it exits.
+//
+// mu must be held.
+func (t *Timekeeper) stopUpdater() {
+ if t.stop == nil {
+ // Updater not running.
+ return
+ }
+
+ close(t.stop)
+ t.wg.Wait()
+ t.stop = nil
+}
+
+// Destroy destroys the Timekeeper, freeing all associated resources.
+func (t *Timekeeper) Destroy() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ t.stopUpdater()
+}
+
+// PauseUpdates stops clock parameter updates. This should only be used when
+// Tasks are not running and thus cannot access the clock.
+func (t *Timekeeper) PauseUpdates() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.stopUpdater()
+}
+
+// ResumeUpdates restarts clock parameter updates stopped by PauseUpdates.
+func (t *Timekeeper) ResumeUpdates() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.startUpdater()
+}
+
+// GetTime returns the current time in nanoseconds.
+func (t *Timekeeper) GetTime(c sentrytime.ClockID) (int64, error) {
+ if t.clocks == nil {
+ if t.restored == nil {
+ panic("Timekeeper used before initialized with SetClocks")
+ }
+ <-t.restored
+ }
+ now, err := t.clocks.GetTime(c)
+ if err == nil && c == sentrytime.Monotonic {
+ now += t.monotonicOffset
+ }
+ return now, err
+}
+
+// BootTime returns the system boot real time.
+func (t *Timekeeper) BootTime() ktime.Time {
+ return t.bootTime
+}
+
+// timekeeperClock is a ktime.Clock that reads time from a
+// kernel.Timekeeper-managed clock.
+//
+// +stateify savable
+type timekeeperClock struct {
+ tk *Timekeeper
+ c sentrytime.ClockID
+
+ // Implements ktime.Clock.WallTimeUntil.
+ ktime.WallRateClock `state:"nosave"`
+
+ // Implements waiter.Waitable. (We have no ability to detect
+ // discontinuities from external changes to CLOCK_REALTIME).
+ ktime.NoClockEvents `state:"nosave"`
+}
+
+// Now implements ktime.Clock.Now.
+func (tc *timekeeperClock) Now() ktime.Time {
+ now, err := tc.tk.GetTime(tc.c)
+ if err != nil {
+ panic(fmt.Sprintf("timekeeperClock(ClockID=%v)).Now: %v", tc.c, err))
+ }
+ return ktime.FromNanoseconds(now)
+}
diff --git a/pkg/sentry/kernel/timekeeper_state.go b/pkg/sentry/kernel/timekeeper_state.go
new file mode 100644
index 000000000..6ce358a05
--- /dev/null
+++ b/pkg/sentry/kernel/timekeeper_state.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/time"
+)
+
+// beforeSave is invoked by stateify.
+func (t *Timekeeper) beforeSave() {
+ if t.stop != nil {
+ panic("pauseUpdates must be called before Save")
+ }
+
+ // N.B. we want the *offset* monotonic time.
+ var err error
+ if t.saveMonotonic, err = t.GetTime(time.Monotonic); err != nil {
+ panic("unable to get current monotonic time: " + err.Error())
+ }
+
+ if t.saveRealtime, err = t.GetTime(time.Realtime); err != nil {
+ panic("unable to get current realtime: " + err.Error())
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (t *Timekeeper) afterLoad() {
+ t.restored = make(chan struct{})
+}
diff --git a/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go b/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go
new file mode 100755
index 000000000..6f5580ebe
--- /dev/null
+++ b/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go
@@ -0,0 +1,119 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: pkg/sentry/kernel/uncaught_signal.proto
+
+package gvisor
+
+import (
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ registers_go_proto "gvisor.googlesource.com/gvisor/pkg/sentry/arch/registers_go_proto"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type UncaughtSignal struct {
+ Tid int32 `protobuf:"varint,1,opt,name=tid,proto3" json:"tid,omitempty"`
+ Pid int32 `protobuf:"varint,2,opt,name=pid,proto3" json:"pid,omitempty"`
+ Registers *registers_go_proto.Registers `protobuf:"bytes,3,opt,name=registers,proto3" json:"registers,omitempty"`
+ SignalNumber int32 `protobuf:"varint,4,opt,name=signal_number,json=signalNumber,proto3" json:"signal_number,omitempty"`
+ FaultAddr uint64 `protobuf:"varint,5,opt,name=fault_addr,json=faultAddr,proto3" json:"fault_addr,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *UncaughtSignal) Reset() { *m = UncaughtSignal{} }
+func (m *UncaughtSignal) String() string { return proto.CompactTextString(m) }
+func (*UncaughtSignal) ProtoMessage() {}
+func (*UncaughtSignal) Descriptor() ([]byte, []int) {
+ return fileDescriptor_5ca9e03e13704688, []int{0}
+}
+
+func (m *UncaughtSignal) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_UncaughtSignal.Unmarshal(m, b)
+}
+func (m *UncaughtSignal) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_UncaughtSignal.Marshal(b, m, deterministic)
+}
+func (m *UncaughtSignal) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_UncaughtSignal.Merge(m, src)
+}
+func (m *UncaughtSignal) XXX_Size() int {
+ return xxx_messageInfo_UncaughtSignal.Size(m)
+}
+func (m *UncaughtSignal) XXX_DiscardUnknown() {
+ xxx_messageInfo_UncaughtSignal.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_UncaughtSignal proto.InternalMessageInfo
+
+func (m *UncaughtSignal) GetTid() int32 {
+ if m != nil {
+ return m.Tid
+ }
+ return 0
+}
+
+func (m *UncaughtSignal) GetPid() int32 {
+ if m != nil {
+ return m.Pid
+ }
+ return 0
+}
+
+func (m *UncaughtSignal) GetRegisters() *registers_go_proto.Registers {
+ if m != nil {
+ return m.Registers
+ }
+ return nil
+}
+
+func (m *UncaughtSignal) GetSignalNumber() int32 {
+ if m != nil {
+ return m.SignalNumber
+ }
+ return 0
+}
+
+func (m *UncaughtSignal) GetFaultAddr() uint64 {
+ if m != nil {
+ return m.FaultAddr
+ }
+ return 0
+}
+
+func init() {
+ proto.RegisterType((*UncaughtSignal)(nil), "gvisor.UncaughtSignal")
+}
+
+func init() {
+ proto.RegisterFile("pkg/sentry/kernel/uncaught_signal.proto", fileDescriptor_5ca9e03e13704688)
+}
+
+var fileDescriptor_5ca9e03e13704688 = []byte{
+ // 210 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x4c, 0x8e, 0x4d, 0x4a, 0xc6, 0x30,
+ 0x10, 0x86, 0x89, 0xfd, 0x81, 0xc6, 0x1f, 0x34, 0xab, 0x20, 0x88, 0x45, 0x17, 0x76, 0xd5, 0x80,
+ 0x9e, 0xc0, 0x0b, 0xb8, 0x88, 0xb8, 0x2e, 0x69, 0x13, 0xd3, 0xd0, 0x9a, 0x86, 0x49, 0x22, 0x78,
+ 0x24, 0x6f, 0x29, 0x4d, 0xd4, 0xef, 0xdb, 0x0d, 0xcf, 0xbc, 0xf3, 0xcc, 0x8b, 0x1f, 0xdc, 0xa2,
+ 0x99, 0x57, 0x36, 0xc0, 0x17, 0x5b, 0x14, 0x58, 0xb5, 0xb2, 0x68, 0x27, 0x11, 0xf5, 0x1c, 0x06,
+ 0x6f, 0xb4, 0x15, 0x6b, 0xef, 0x60, 0x0b, 0x1b, 0xa9, 0xf5, 0xa7, 0xf1, 0x1b, 0x5c, 0xdf, 0x1e,
+ 0x1d, 0x08, 0x98, 0x66, 0x06, 0x4a, 0x1b, 0x1f, 0x14, 0xf8, 0x1c, 0xbc, 0xfb, 0x46, 0xf8, 0xe2,
+ 0xed, 0x57, 0xf1, 0x9a, 0x0c, 0xe4, 0x12, 0x17, 0xc1, 0x48, 0x8a, 0x5a, 0xd4, 0x55, 0x7c, 0x1f,
+ 0x77, 0xe2, 0x8c, 0xa4, 0x27, 0x99, 0x38, 0x23, 0x09, 0xc3, 0xcd, 0xbf, 0x89, 0x16, 0x2d, 0xea,
+ 0x4e, 0x1f, 0xaf, 0xfa, 0xfc, 0xb3, 0xe7, 0x7f, 0x0b, 0x7e, 0xc8, 0x90, 0x7b, 0x7c, 0x9e, 0x0b,
+ 0x0e, 0x36, 0x7e, 0x8c, 0x0a, 0x68, 0x99, 0x64, 0x67, 0x19, 0xbe, 0x24, 0x46, 0x6e, 0x30, 0x7e,
+ 0x17, 0x71, 0x0d, 0x83, 0x90, 0x12, 0x68, 0xd5, 0xa2, 0xae, 0xe4, 0x4d, 0x22, 0xcf, 0x52, 0xc2,
+ 0x58, 0xa7, 0xca, 0x4f, 0x3f, 0x01, 0x00, 0x00, 0xff, 0xff, 0xfd, 0x62, 0x54, 0xdf, 0x06, 0x01,
+ 0x00, 0x00,
+}
diff --git a/pkg/sentry/kernel/uts_namespace.go b/pkg/sentry/kernel/uts_namespace.go
new file mode 100644
index 000000000..96fe3cbb9
--- /dev/null
+++ b/pkg/sentry/kernel/uts_namespace.go
@@ -0,0 +1,102 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+)
+
+// UTSNamespace represents a UTS namespace, a holder of two system identifiers:
+// the hostname and domain name.
+//
+// +stateify savable
+type UTSNamespace struct {
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
+ hostName string
+ domainName string
+
+ // userns is the user namespace associated with the UTSNamespace.
+ // Privileged operations on this UTSNamespace must have appropriate
+ // capabilities in userns.
+ //
+ // userns is immutable.
+ userns *auth.UserNamespace
+}
+
+// NewUTSNamespace creates a new UTS namespace.
+func NewUTSNamespace(hostName, domainName string, userns *auth.UserNamespace) *UTSNamespace {
+ return &UTSNamespace{
+ hostName: hostName,
+ domainName: domainName,
+ userns: userns,
+ }
+}
+
+// UTSNamespace returns the task's UTS namespace.
+func (t *Task) UTSNamespace() *UTSNamespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.utsns
+}
+
+// HostName returns the host name of this UTS namespace.
+func (u *UTSNamespace) HostName() string {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ return u.hostName
+}
+
+// SetHostName sets the host name of this UTS namespace.
+func (u *UTSNamespace) SetHostName(host string) {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ u.hostName = host
+}
+
+// DomainName returns the domain name of this UTS namespace.
+func (u *UTSNamespace) DomainName() string {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ return u.domainName
+}
+
+// SetDomainName sets the domain name of this UTS namespace.
+func (u *UTSNamespace) SetDomainName(domain string) {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ u.domainName = domain
+}
+
+// UserNamespace returns the user namespace associated with this UTS namespace.
+func (u *UTSNamespace) UserNamespace() *auth.UserNamespace {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ return u.userns
+}
+
+// Clone makes a copy of this UTS namespace, associating the given user
+// namespace.
+func (u *UTSNamespace) Clone(userns *auth.UserNamespace) *UTSNamespace {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ return &UTSNamespace{
+ hostName: u.hostName,
+ domainName: u.domainName,
+ userns: userns,
+ }
+}
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
new file mode 100644
index 000000000..d40ad74f4
--- /dev/null
+++ b/pkg/sentry/kernel/vdso.go
@@ -0,0 +1,148 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// vdsoParams are the parameters exposed to the VDSO.
+//
+// They are exposed to the VDSO via a parameter page managed by VDSOParamPage,
+// which also includes a sequence counter.
+type vdsoParams struct {
+ monotonicReady uint64
+ monotonicBaseCycles int64
+ monotonicBaseRef int64
+ monotonicFrequency uint64
+
+ realtimeReady uint64
+ realtimeBaseCycles int64
+ realtimeBaseRef int64
+ realtimeFrequency uint64
+}
+
+// VDSOParamPage manages a VDSO parameter page.
+//
+// Its memory layout looks like:
+//
+// type page struct {
+// // seq is a sequence counter that protects the fields below.
+// seq uint64
+// vdsoParams
+// }
+//
+// Everything in the struct is 8 bytes for easy alignment.
+//
+// It must be kept in sync with params in vdso/vdso_time.cc.
+//
+// +stateify savable
+type VDSOParamPage struct {
+ // The parameter page is fr, allocated from mfp.MemoryFile().
+ mfp pgalloc.MemoryFileProvider
+ fr platform.FileRange
+
+ // seq is the current sequence count written to the page.
+ //
+ // A write is in progress if bit 1 of the counter is set.
+ //
+ // Timekeeper's updater goroutine may call Write before equality is
+ // checked in state_test_util tests, causing this field to change across
+ // save / restore.
+ seq uint64
+}
+
+// NewVDSOParamPage returns a VDSOParamPage.
+//
+// Preconditions:
+//
+// * fr is a single page allocated from mfp.MemoryFile(). VDSOParamPage does
+// not take ownership of fr; it must remain allocated for the lifetime of the
+// VDSOParamPage.
+//
+// * VDSOParamPage must be the only writer to fr.
+//
+// * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block.
+func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage {
+ return &VDSOParamPage{mfp: mfp, fr: fr}
+}
+
+// access returns a mapping of the param page.
+func (v *VDSOParamPage) access() (safemem.Block, error) {
+ bs, err := v.mfp.MemoryFile().MapInternal(v.fr, usermem.ReadWrite)
+ if err != nil {
+ return safemem.Block{}, err
+ }
+ if bs.NumBlocks() != 1 {
+ panic(fmt.Sprintf("Multiple blocks (%d) in VDSO param BlockSeq", bs.NumBlocks()))
+ }
+ return bs.Head(), nil
+}
+
+// incrementSeq increments the sequence counter in the param page.
+func (v *VDSOParamPage) incrementSeq(paramPage safemem.Block) error {
+ next := v.seq + 1
+ old, err := safemem.SwapUint64(paramPage, next)
+ if err != nil {
+ return err
+ }
+
+ if old != v.seq {
+ return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d. Application may hang or get incorrect time from the VDSO.", old, v.seq)
+ }
+
+ v.seq = next
+ return nil
+}
+
+// Write updates the VDSO parameters.
+//
+// Write starts a write block, calls f to get the new parameters, writes
+// out the new parameters, then ends the write block.
+func (v *VDSOParamPage) Write(f func() vdsoParams) error {
+ paramPage, err := v.access()
+ if err != nil {
+ return err
+ }
+
+ // Write begin.
+ next := v.seq + 1
+ if next%2 != 1 {
+ panic("Out-of-order sequence count")
+ }
+
+ err = v.incrementSeq(paramPage)
+ if err != nil {
+ return err
+ }
+
+ // Get the new params.
+ p := f()
+ buf := binary.Marshal(nil, usermem.ByteOrder, p)
+
+ // Skip the sequence counter.
+ if _, err := safemem.Copy(paramPage.DropFirst(8), safemem.BlockFromSafeSlice(buf)); err != nil {
+ panic(fmt.Sprintf("Unable to get set VDSO parameters: %v", err))
+ }
+
+ // Write end.
+ return v.incrementSeq(paramPage)
+}
diff --git a/pkg/sentry/kernel/version.go b/pkg/sentry/kernel/version.go
new file mode 100644
index 000000000..5640dd71d
--- /dev/null
+++ b/pkg/sentry/kernel/version.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// Version defines the application-visible system version.
+type Version struct {
+ // Operating system name (e.g. "Linux").
+ Sysname string
+
+ // Operating system release (e.g. "4.4-amd64").
+ Release string
+
+ // Operating system version. On Linux this takes the shape
+ // "#VERSION CONFIG_FLAGS TIMESTAMP"
+ // where:
+ // - VERSION is a sequence counter incremented on every successful build
+ // - CONFIG_FLAGS is a space-separated list of major enabled kernel features
+ // (e.g. "SMP" and "PREEMPT")
+ // - TIMESTAMP is the build timestamp as returned by `date`
+ Version string
+}
diff --git a/pkg/sentry/limits/context.go b/pkg/sentry/limits/context.go
new file mode 100644
index 000000000..9200edb52
--- /dev/null
+++ b/pkg/sentry/limits/context.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package limits
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// contextID is the limit package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxLimits is a Context.Value key for a LimitSet.
+ CtxLimits contextID = iota
+)
+
+// FromContext returns the limits that apply to ctx.
+func FromContext(ctx context.Context) *LimitSet {
+ if v := ctx.Value(CtxLimits); v != nil {
+ return v.(*LimitSet)
+ }
+ return nil
+}
diff --git a/pkg/sentry/limits/limits.go b/pkg/sentry/limits/limits.go
new file mode 100644
index 000000000..b6c22656b
--- /dev/null
+++ b/pkg/sentry/limits/limits.go
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package limits provides resource limits.
+package limits
+
+import (
+ "sync"
+ "syscall"
+)
+
+// LimitType defines a type of resource limit.
+type LimitType int
+
+// Set of constants defining the different types of resource limits.
+const (
+ CPU LimitType = iota
+ FileSize
+ Data
+ Stack
+ Core
+ Rss
+ ProcessCount
+ NumberOfFiles
+ MemoryLocked
+ AS
+ Locks
+ SignalsPending
+ MessageQueueBytes
+ Nice
+ RealTimePriority
+ Rttime
+)
+
+// Infinity is a constant representing a resource with no limit.
+const Infinity = ^uint64(0)
+
+// Limit specifies a system limit.
+//
+// +stateify savable
+type Limit struct {
+ // Cur specifies the current limit.
+ Cur uint64
+ // Max specifies the maximum settable limit.
+ Max uint64
+}
+
+// LimitSet represents the Limits that correspond to each LimitType.
+//
+// +stateify savable
+type LimitSet struct {
+ mu sync.Mutex `state:"nosave"`
+ data map[LimitType]Limit
+}
+
+// NewLimitSet creates a new, empty LimitSet.
+func NewLimitSet() *LimitSet {
+ return &LimitSet{
+ data: make(map[LimitType]Limit),
+ }
+}
+
+// GetCopy returns a clone of the LimitSet.
+func (l *LimitSet) GetCopy() *LimitSet {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ copyData := make(map[LimitType]Limit)
+ for k, v := range l.data {
+ copyData[k] = v
+ }
+ return &LimitSet{
+ data: copyData,
+ }
+}
+
+// Get returns the resource limit associated with LimitType t.
+// If no limit is provided, it defaults to an infinite limit.Infinity.
+func (l *LimitSet) Get(t LimitType) Limit {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ s, ok := l.data[t]
+ if !ok {
+ return Limit{Cur: Infinity, Max: Infinity}
+ }
+ return s
+}
+
+// GetCapped returns the current value for the limit, capped as specified.
+func (l *LimitSet) GetCapped(t LimitType, max uint64) uint64 {
+ s := l.Get(t)
+ if s.Cur == Infinity || s.Cur > max {
+ return max
+ }
+ return s.Cur
+}
+
+// SetUnchecked assigns value v to resource of LimitType t.
+func (l *LimitSet) SetUnchecked(t LimitType, v Limit) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ l.data[t] = v
+}
+
+// Set assigns value v to resource of LimitType t and returns the old value.
+// privileged should be true only when either the caller has CAP_SYS_RESOURCE
+// or when creating limits for a new kernel.
+func (l *LimitSet) Set(t LimitType, v Limit, privileged bool) (Limit, error) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ // If a limit is already set, make sure the new limit doesn't
+ // exceed the previous max limit.
+ if _, ok := l.data[t]; ok {
+ // Unprivileged users can only lower their hard limits.
+ if l.data[t].Max < v.Max && !privileged {
+ return Limit{}, syscall.EPERM
+ }
+ if v.Cur > v.Max {
+ return Limit{}, syscall.EINVAL
+ }
+ }
+ old := l.data[t]
+ l.data[t] = v
+ return old, nil
+}
diff --git a/pkg/sentry/limits/limits_state_autogen.go b/pkg/sentry/limits/limits_state_autogen.go
new file mode 100755
index 000000000..b1cdb4c49
--- /dev/null
+++ b/pkg/sentry/limits/limits_state_autogen.go
@@ -0,0 +1,36 @@
+// automatically generated by stateify.
+
+package limits
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Limit) beforeSave() {}
+func (x *Limit) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Cur", &x.Cur)
+ m.Save("Max", &x.Max)
+}
+
+func (x *Limit) afterLoad() {}
+func (x *Limit) load(m state.Map) {
+ m.Load("Cur", &x.Cur)
+ m.Load("Max", &x.Max)
+}
+
+func (x *LimitSet) beforeSave() {}
+func (x *LimitSet) save(m state.Map) {
+ x.beforeSave()
+ m.Save("data", &x.data)
+}
+
+func (x *LimitSet) afterLoad() {}
+func (x *LimitSet) load(m state.Map) {
+ m.Load("data", &x.data)
+}
+
+func init() {
+ state.Register("limits.Limit", (*Limit)(nil), state.Fns{Save: (*Limit).save, Load: (*Limit).load})
+ state.Register("limits.LimitSet", (*LimitSet)(nil), state.Fns{Save: (*LimitSet).save, Load: (*LimitSet).load})
+}
diff --git a/pkg/sentry/limits/linux.go b/pkg/sentry/limits/linux.go
new file mode 100644
index 000000000..a2b401e3d
--- /dev/null
+++ b/pkg/sentry/limits/linux.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package limits
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// FromLinuxResource maps linux resources to sentry LimitTypes.
+var FromLinuxResource = map[int]LimitType{
+ linux.RLIMIT_CPU: CPU,
+ linux.RLIMIT_FSIZE: FileSize,
+ linux.RLIMIT_DATA: Data,
+ linux.RLIMIT_STACK: Stack,
+ linux.RLIMIT_CORE: Core,
+ linux.RLIMIT_RSS: Rss,
+ linux.RLIMIT_NPROC: ProcessCount,
+ linux.RLIMIT_NOFILE: NumberOfFiles,
+ linux.RLIMIT_MEMLOCK: MemoryLocked,
+ linux.RLIMIT_AS: AS,
+ linux.RLIMIT_LOCKS: Locks,
+ linux.RLIMIT_SIGPENDING: SignalsPending,
+ linux.RLIMIT_MSGQUEUE: MessageQueueBytes,
+ linux.RLIMIT_NICE: Nice,
+ linux.RLIMIT_RTPRIO: RealTimePriority,
+ linux.RLIMIT_RTTIME: Rttime,
+}
+
+// FromLinux maps linux rlimit values to sentry Limits, being careful to handle
+// infinities.
+func FromLinux(rl uint64) uint64 {
+ if rl == linux.RLimInfinity {
+ return Infinity
+ }
+ return rl
+}
+
+// ToLinux maps sentry Limits to linux rlimit values, being careful to handle
+// infinities.
+func ToLinux(l uint64) uint64 {
+ if l == Infinity {
+ return linux.RLimInfinity
+ }
+ return l
+}
+
+// NewLinuxLimitSet returns a LimitSet whose values match the default rlimits
+// in Linux.
+func NewLinuxLimitSet() (*LimitSet, error) {
+ ls := NewLimitSet()
+ for rlt, rl := range linux.InitRLimits {
+ lt, ok := FromLinuxResource[rlt]
+ if !ok {
+ return nil, fmt.Errorf("unknown rlimit type %v", rlt)
+ }
+ ls.SetUnchecked(lt, Limit{
+ Cur: FromLinux(rl.Cur),
+ Max: FromLinux(rl.Max),
+ })
+ }
+ return ls, nil
+}
+
+// NewLinuxDistroLimitSet returns a new LimitSet whose values are typical
+// for a booted Linux distro.
+//
+// Many Linux init systems adjust the default Linux limits to values more
+// expected by the rest of the userspace. NewLinuxDistroLimitSet returns a
+// LimitSet with sensible defaults for applications that aren't starting
+// their own init system.
+func NewLinuxDistroLimitSet() (*LimitSet, error) {
+ ls, err := NewLinuxLimitSet()
+ if err != nil {
+ return nil, err
+ }
+
+ // Adjust ProcessCount to a lower value because GNU bash allocates 16
+ // bytes per proc and OOMs if this number is set too high. Value was
+ // picked arbitrarily.
+ //
+ // 1,048,576 ought to be enough for anyone.
+ l := ls.Get(ProcessCount)
+ l.Cur = 1 << 20
+ ls.Set(ProcessCount, l, true /* privileged */)
+ return ls, nil
+}
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
new file mode 100644
index 000000000..900236531
--- /dev/null
+++ b/pkg/sentry/loader/elf.go
@@ -0,0 +1,669 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package loader
+
+import (
+ "bytes"
+ "debug/elf"
+ "fmt"
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const (
+ // elfMagic identifies an ELF file.
+ elfMagic = "\x7fELF"
+
+ // maxTotalPhdrSize is the maximum combined size of all program
+ // headers. Linux limits this to one page.
+ maxTotalPhdrSize = usermem.PageSize
+)
+
+var (
+ // header64Size is the size of elf.Header64.
+ header64Size = int(binary.Size(elf.Header64{}))
+
+ // Prog64Size is the size of elf.Prog64.
+ prog64Size = int(binary.Size(elf.Prog64{}))
+)
+
+func progFlagsAsPerms(f elf.ProgFlag) usermem.AccessType {
+ var p usermem.AccessType
+ if f&elf.PF_R == elf.PF_R {
+ p.Read = true
+ }
+ if f&elf.PF_W == elf.PF_W {
+ p.Write = true
+ }
+ if f&elf.PF_X == elf.PF_X {
+ p.Execute = true
+ }
+ return p
+}
+
+// elfInfo contains the metadata needed to load an ELF binary.
+type elfInfo struct {
+ // os is the target OS of the ELF.
+ os abi.OS
+
+ // arch is the target architecture of the ELF.
+ arch arch.Arch
+
+ // entry is the program entry point.
+ entry usermem.Addr
+
+ // phdrs are the program headers.
+ phdrs []elf.ProgHeader
+
+ // phdrSize is the size of a single program header in the ELF.
+ phdrSize int
+
+ // phdrOff is the offset of the program headers in the file.
+ phdrOff uint64
+
+ // sharedObject is true if the ELF represents a shared object.
+ sharedObject bool
+}
+
+// parseHeader parse the ELF header, verifying that this is a supported ELF
+// file and returning the ELF program headers.
+//
+// This is similar to elf.NewFile, except that it is more strict about what it
+// accepts from the ELF, and it doesn't parse unnecessary parts of the file.
+//
+// ctx may be nil if f does not need it.
+func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
+ // Check ident first; it will tell us the endianness of the rest of the
+ // structs.
+ var ident [elf.EI_NIDENT]byte
+ _, err := readFull(ctx, f, usermem.BytesIOSequence(ident[:]), 0)
+ if err != nil {
+ log.Infof("Error reading ELF ident: %v", err)
+ // The entire ident array always exists.
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ err = syserror.ENOEXEC
+ }
+ return elfInfo{}, err
+ }
+
+ // Only some callers pre-check the ELF magic.
+ if !bytes.Equal(ident[:len(elfMagic)], []byte(elfMagic)) {
+ log.Infof("File is not an ELF")
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ // We only support 64-bit, little endian binaries
+ if class := elf.Class(ident[elf.EI_CLASS]); class != elf.ELFCLASS64 {
+ log.Infof("Unsupported ELF class: %v", class)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ if endian := elf.Data(ident[elf.EI_DATA]); endian != elf.ELFDATA2LSB {
+ log.Infof("Unsupported ELF endianness: %v", endian)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ byteOrder := binary.LittleEndian
+
+ if version := elf.Version(ident[elf.EI_VERSION]); version != elf.EV_CURRENT {
+ log.Infof("Unsupported ELF version: %v", version)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ // EI_OSABI is ignored by Linux, which is the only OS supported.
+ os := abi.Linux
+
+ var hdr elf.Header64
+ hdrBuf := make([]byte, header64Size)
+ _, err = readFull(ctx, f, usermem.BytesIOSequence(hdrBuf), 0)
+ if err != nil {
+ log.Infof("Error reading ELF header: %v", err)
+ // The entire header always exists.
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ err = syserror.ENOEXEC
+ }
+ return elfInfo{}, err
+ }
+ binary.Unmarshal(hdrBuf, byteOrder, &hdr)
+
+ // We only support amd64.
+ if machine := elf.Machine(hdr.Machine); machine != elf.EM_X86_64 {
+ log.Infof("Unsupported ELF machine %d", machine)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ a := arch.AMD64
+
+ var sharedObject bool
+ elfType := elf.Type(hdr.Type)
+ switch elfType {
+ case elf.ET_EXEC:
+ sharedObject = false
+ case elf.ET_DYN:
+ sharedObject = true
+ default:
+ log.Infof("Unsupported ELF type %v", elfType)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ if int(hdr.Phentsize) != prog64Size {
+ log.Infof("Unsupported phdr size %d", hdr.Phentsize)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ totalPhdrSize := prog64Size * int(hdr.Phnum)
+ if totalPhdrSize < prog64Size {
+ log.Warningf("No phdrs or total phdr size overflows: prog64Size: %d phnum: %d", prog64Size, int(hdr.Phnum))
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ if totalPhdrSize > maxTotalPhdrSize {
+ log.Infof("Too many phdrs (%d): total size %d > %d", hdr.Phnum, totalPhdrSize, maxTotalPhdrSize)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ phdrBuf := make([]byte, totalPhdrSize)
+ _, err = readFull(ctx, f, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff))
+ if err != nil {
+ log.Infof("Error reading ELF phdrs: %v", err)
+ // If phdrs were specified, they should all exist.
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ err = syserror.ENOEXEC
+ }
+ return elfInfo{}, err
+ }
+
+ phdrs := make([]elf.ProgHeader, hdr.Phnum)
+ for i := range phdrs {
+ var prog64 elf.Prog64
+ binary.Unmarshal(phdrBuf[:prog64Size], byteOrder, &prog64)
+ phdrBuf = phdrBuf[prog64Size:]
+ phdrs[i] = elf.ProgHeader{
+ Type: elf.ProgType(prog64.Type),
+ Flags: elf.ProgFlag(prog64.Flags),
+ Off: prog64.Off,
+ Vaddr: prog64.Vaddr,
+ Paddr: prog64.Paddr,
+ Filesz: prog64.Filesz,
+ Memsz: prog64.Memsz,
+ Align: prog64.Align,
+ }
+ }
+
+ return elfInfo{
+ os: os,
+ arch: a,
+ entry: usermem.Addr(hdr.Entry),
+ phdrs: phdrs,
+ phdrOff: hdr.Phoff,
+ phdrSize: prog64Size,
+ sharedObject: sharedObject,
+ }, nil
+}
+
+// mapSegment maps a phdr into the Task. offset is the offset to apply to
+// phdr.Vaddr.
+func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.ProgHeader, offset usermem.Addr) error {
+ // We must make a page-aligned mapping.
+ adjust := usermem.Addr(phdr.Vaddr).PageOffset()
+
+ addr, ok := offset.AddLength(phdr.Vaddr)
+ if !ok {
+ // If offset != 0 we should have ensured this would fit.
+ ctx.Warningf("Computed segment load address overflows: %#x + %#x", phdr.Vaddr, offset)
+ return syserror.ENOEXEC
+ }
+ addr -= usermem.Addr(adjust)
+
+ fileSize := phdr.Filesz + adjust
+ if fileSize < phdr.Filesz {
+ ctx.Infof("Computed segment file size overflows: %#x + %#x", phdr.Filesz, adjust)
+ return syserror.ENOEXEC
+ }
+ ms, ok := usermem.Addr(fileSize).RoundUp()
+ if !ok {
+ ctx.Infof("fileSize %#x too large", fileSize)
+ return syserror.ENOEXEC
+ }
+ mapSize := uint64(ms)
+
+ if mapSize > 0 {
+ // This must result in a page-aligned offset. i.e., the original
+ // phdr.Off must have the same alignment as phdr.Vaddr. If that is not
+ // true, MMap will reject the mapping.
+ fileOffset := phdr.Off - adjust
+
+ prot := progFlagsAsPerms(phdr.Flags)
+ mopts := memmap.MMapOpts{
+ Length: mapSize,
+ Offset: fileOffset,
+ Addr: addr,
+ Fixed: true,
+ // Linux will happily allow conflicting segments to map over
+ // one another.
+ Unmap: true,
+ Private: true,
+ Perms: prot,
+ MaxPerms: usermem.AnyAccess,
+ }
+ defer func() {
+ if mopts.MappingIdentity != nil {
+ mopts.MappingIdentity.DecRef()
+ }
+ }()
+ if err := f.ConfigureMMap(ctx, &mopts); err != nil {
+ ctx.Infof("File is not memory-mappable: %v", err)
+ return err
+ }
+ if _, err := m.MMap(ctx, mopts); err != nil {
+ ctx.Infof("Error mapping PT_LOAD segment %+v at %#x: %v", phdr, addr, err)
+ return err
+ }
+
+ // We need to clear the end of the last page that exceeds fileSize so
+ // we don't map part of the file beyond fileSize.
+ //
+ // Note that Linux *does not* clear the portion of the first page
+ // before phdr.Off.
+ if mapSize > fileSize {
+ zeroAddr, ok := addr.AddLength(fileSize)
+ if !ok {
+ panic(fmt.Sprintf("successfully mmaped address overflows? %#x + %#x", addr, fileSize))
+ }
+ zeroSize := int64(mapSize - fileSize)
+ if zeroSize < 0 {
+ panic(fmt.Sprintf("zeroSize too big? %#x", uint64(zeroSize)))
+ }
+ if _, err := m.ZeroOut(ctx, zeroAddr, zeroSize, usermem.IOOpts{IgnorePermissions: true}); err != nil {
+ ctx.Warningf("Failed to zero end of page [%#x, %#x): %v", zeroAddr, zeroAddr+usermem.Addr(zeroSize), err)
+ return err
+ }
+ }
+ }
+
+ memSize := phdr.Memsz + adjust
+ if memSize < phdr.Memsz {
+ ctx.Infof("Computed segment mem size overflows: %#x + %#x", phdr.Memsz, adjust)
+ return syserror.ENOEXEC
+ }
+
+ // Allocate more anonymous pages if necessary.
+ if mapSize < memSize {
+ anonAddr, ok := addr.AddLength(mapSize)
+ if !ok {
+ panic(fmt.Sprintf("anonymous memory doesn't fit in pre-sized range? %#x + %#x", addr, mapSize))
+ }
+ anonSize, ok := usermem.Addr(memSize - mapSize).RoundUp()
+ if !ok {
+ ctx.Infof("extra anon pages too large: %#x", memSize-mapSize)
+ return syserror.ENOEXEC
+ }
+
+ if _, err := m.MMap(ctx, memmap.MMapOpts{
+ Length: uint64(anonSize),
+ Addr: anonAddr,
+ // Fixed without Unmap will fail the mmap if something is
+ // already at addr.
+ Fixed: true,
+ Private: true,
+ // N.B. Linux uses vm_brk to map these pages, ignoring
+ // the segment protections, instead always mapping RW.
+ // These pages are not included in the final brk
+ // region.
+ Perms: usermem.ReadWrite,
+ MaxPerms: usermem.AnyAccess,
+ }); err != nil {
+ ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err)
+ return err
+ }
+ }
+
+ return nil
+}
+
+// loadedELF describes an ELF that has been successfully loaded.
+type loadedELF struct {
+ // os is the target OS of the ELF.
+ os abi.OS
+
+ // arch is the target architecture of the ELF.
+ arch arch.Arch
+
+ // entry is the entry point of the ELF.
+ entry usermem.Addr
+
+ // start is the end of the ELF.
+ start usermem.Addr
+
+ // end is the end of the ELF.
+ end usermem.Addr
+
+ // interpter is the path to the ELF interpreter.
+ interpreter string
+
+ // phdrAddr is the address of the ELF program headers.
+ phdrAddr usermem.Addr
+
+ // phdrSize is the size of a single program header in the ELF.
+ phdrSize int
+
+ // phdrNum is the number of program headers.
+ phdrNum int
+
+ // auxv contains a subset of ELF-specific auxiliary vector entries:
+ // * AT_PHDR
+ // * AT_PHENT
+ // * AT_PHNUM
+ // * AT_BASE
+ // * AT_ENTRY
+ auxv arch.Auxv
+}
+
+// loadParsedELF loads f into mm.
+//
+// info is the parsed elfInfo from the header.
+//
+// It does not load the ELF interpreter, or return any auxv entries.
+//
+// Preconditions:
+// * f is an ELF file
+func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) {
+ first := true
+ var start, end usermem.Addr
+ var interpreter string
+ for _, phdr := range info.phdrs {
+ switch phdr.Type {
+ case elf.PT_LOAD:
+ vaddr := usermem.Addr(phdr.Vaddr)
+ if first {
+ first = false
+ start = vaddr
+ }
+ if vaddr < end {
+ ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+ var ok bool
+ end, ok = vaddr.AddLength(phdr.Memsz)
+ if !ok {
+ ctx.Infof("PT_LOAD header size overflows. %#x + %#x", vaddr, phdr.Memsz)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ case elf.PT_INTERP:
+ if phdr.Filesz < 2 {
+ ctx.Infof("PT_INTERP path too small: %v", phdr.Filesz)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+ if phdr.Filesz > linux.PATH_MAX {
+ ctx.Infof("PT_INTERP path too big: %v", phdr.Filesz)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ path := make([]byte, phdr.Filesz)
+ _, err := readFull(ctx, f, usermem.BytesIOSequence(path), int64(phdr.Off))
+ if err != nil {
+ // If an interpreter was specified, it should exist.
+ ctx.Infof("Error reading PT_INTERP path: %v", err)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ if path[len(path)-1] != 0 {
+ ctx.Infof("PT_INTERP path not NUL-terminated: %v", path)
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ // Strip NUL-terminator and everything beyond from
+ // string. Note that there may be a NUL-terminator
+ // before len(path)-1.
+ interpreter = string(path[:bytes.IndexByte(path, '\x00')])
+ if interpreter == "" {
+ // Linux actually attempts to open_exec("\0").
+ // open_exec -> do_open_execat fails to check
+ // that name != '\0' before calling
+ // do_filp_open, which thus opens the working
+ // directory. do_open_execat returns EACCES
+ // because the directory is not a regular file.
+ //
+ // We bypass that nonsense and simply
+ // short-circuit with EACCES. Those this does
+ // mean that there may be some edge cases where
+ // the open path would return a different
+ // error.
+ ctx.Infof("PT_INTERP path is empty: %v", path)
+ return loadedELF{}, syserror.EACCES
+ }
+ }
+ }
+
+ // Shared objects don't have fixed load addresses. We need to pick a
+ // base address big enough to fit all segments, so we first create a
+ // mapping for the total size just to find a region that is big enough.
+ //
+ // It is safe to unmap it immediately with racing with another mapping
+ // because we are the only one in control of the MemoryManager.
+ //
+ // Note that the vaddr of the first PT_LOAD segment is ignored when
+ // choosing the load address (even if it is non-zero). The vaddr does
+ // become an offset from that load address.
+ var offset usermem.Addr
+ if info.sharedObject {
+ totalSize := end - start
+ totalSize, ok := totalSize.RoundUp()
+ if !ok {
+ ctx.Infof("ELF PT_LOAD segments too big")
+ return loadedELF{}, syserror.ENOEXEC
+ }
+
+ var err error
+ offset, err = m.MMap(ctx, memmap.MMapOpts{
+ Length: uint64(totalSize),
+ Addr: sharedLoadOffset,
+ Private: true,
+ })
+ if err != nil {
+ ctx.Infof("Error allocating address space for shared object: %v", err)
+ return loadedELF{}, err
+ }
+ if err := m.MUnmap(ctx, offset, uint64(totalSize)); err != nil {
+ panic(fmt.Sprintf("Failed to unmap base address: %v", err))
+ }
+
+ start, ok = start.AddLength(uint64(offset))
+ if !ok {
+ panic(fmt.Sprintf("Start %#x + offset %#x overflows?", start, offset))
+ }
+
+ end, ok = end.AddLength(uint64(offset))
+ if !ok {
+ panic(fmt.Sprintf("End %#x + offset %#x overflows?", end, offset))
+ }
+
+ info.entry, ok = info.entry.AddLength(uint64(offset))
+ if !ok {
+ ctx.Infof("Entrypoint %#x + offset %#x overflows? Is the entrypoint within a segment?", info.entry, offset)
+ return loadedELF{}, err
+ }
+ }
+
+ // Map PT_LOAD segments.
+ for _, phdr := range info.phdrs {
+ switch phdr.Type {
+ case elf.PT_LOAD:
+ if phdr.Memsz == 0 {
+ // No need to load segments with size 0, but
+ // they exist in some binaries.
+ continue
+ }
+
+ if err := mapSegment(ctx, m, f, &phdr, offset); err != nil {
+ ctx.Infof("Failed to map PT_LOAD segment: %+v", phdr)
+ return loadedELF{}, err
+ }
+ }
+ }
+
+ // This assumes that the first segment contains the ELF headers. This
+ // may not be true in a malformed ELF, but Linux makes the same
+ // assumption.
+ phdrAddr, ok := start.AddLength(info.phdrOff)
+ if !ok {
+ ctx.Warningf("ELF start address %#x + phdr offset %#x overflows", start, info.phdrOff)
+ phdrAddr = 0
+ }
+
+ return loadedELF{
+ os: info.os,
+ arch: info.arch,
+ entry: info.entry,
+ start: start,
+ end: end,
+ interpreter: interpreter,
+ phdrAddr: phdrAddr,
+ phdrSize: info.phdrSize,
+ phdrNum: len(info.phdrs),
+ }, nil
+}
+
+// loadInitialELF loads f into mm.
+//
+// It creates an arch.Context for the ELF and prepares the mm for this arch.
+//
+// It does not load the ELF interpreter, or return any auxv entries.
+//
+// Preconditions:
+// * f is an ELF file
+// * f is the first ELF loaded into m
+func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) {
+ info, err := parseHeader(ctx, f)
+ if err != nil {
+ ctx.Infof("Failed to parse initial ELF: %v", err)
+ return loadedELF{}, nil, err
+ }
+
+ // Create the arch.Context now so we can prepare the mmap layout before
+ // mapping anything.
+ ac := arch.New(info.arch, fs)
+
+ l, err := m.SetMmapLayout(ac, limits.FromContext(ctx))
+ if err != nil {
+ ctx.Warningf("Failed to set mmap layout: %v", err)
+ return loadedELF{}, nil, err
+ }
+
+ // PIELoadAddress tries to move the ELF out of the way of the default
+ // mmap base to ensure that the initial brk has sufficient space to
+ // grow.
+ le, err := loadParsedELF(ctx, m, f, info, ac.PIELoadAddress(l))
+ return le, ac, err
+}
+
+// loadInterpreterELF loads f into mm.
+//
+// The interpreter must be for the same OS/Arch as the initial ELF.
+//
+// It does not return any auxv entries.
+//
+// Preconditions:
+// * f is an ELF file
+func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, initial loadedELF) (loadedELF, error) {
+ info, err := parseHeader(ctx, f)
+ if err != nil {
+ if err == syserror.ENOEXEC {
+ // Bad interpreter.
+ err = syserror.ELIBBAD
+ }
+ return loadedELF{}, err
+ }
+
+ if info.os != initial.os {
+ ctx.Infof("Initial ELF OS %v and interpreter ELF OS %v differ", initial.os, info.os)
+ return loadedELF{}, syserror.ELIBBAD
+ }
+ if info.arch != initial.arch {
+ ctx.Infof("Initial ELF arch %v and interpreter ELF arch %v differ", initial.arch, info.arch)
+ return loadedELF{}, syserror.ELIBBAD
+ }
+
+ // The interpreter is not given a load offset, as its location does not
+ // affect brk.
+ return loadParsedELF(ctx, m, f, info, 0)
+}
+
+// loadELF loads f into the Task address space.
+//
+// If loadELF returns ErrSwitchFile it should be called again with the returned
+// path and argv.
+//
+// Preconditions:
+// * f is an ELF file
+func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) {
+ bin, ac, err := loadInitialELF(ctx, m, fs, f)
+ if err != nil {
+ ctx.Infof("Error loading binary: %v", err)
+ return loadedELF{}, nil, err
+ }
+
+ var interp loadedELF
+ if bin.interpreter != "" {
+ d, i, err := openPath(ctx, mounts, root, wd, maxTraversals, bin.interpreter)
+ if err != nil {
+ ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err)
+ return loadedELF{}, nil, err
+ }
+ defer i.DecRef()
+ // We don't need the Dirent.
+ d.DecRef()
+
+ interp, err = loadInterpreterELF(ctx, m, i, bin)
+ if err != nil {
+ ctx.Infof("Error loading interpreter: %v", err)
+ return loadedELF{}, nil, err
+ }
+
+ if interp.interpreter != "" {
+ // No recursive interpreters!
+ ctx.Infof("Interpreter requires an interpreter")
+ return loadedELF{}, nil, syserror.ENOEXEC
+ }
+ }
+
+ // ELF-specific auxv entries.
+ bin.auxv = arch.Auxv{
+ arch.AuxEntry{linux.AT_PHDR, bin.phdrAddr},
+ arch.AuxEntry{linux.AT_PHENT, usermem.Addr(bin.phdrSize)},
+ arch.AuxEntry{linux.AT_PHNUM, usermem.Addr(bin.phdrNum)},
+ arch.AuxEntry{linux.AT_ENTRY, bin.entry},
+ }
+ if bin.interpreter != "" {
+ bin.auxv = append(bin.auxv, arch.AuxEntry{linux.AT_BASE, interp.start})
+
+ // Start in the interpreter.
+ // N.B. AT_ENTRY above contains the *original* entry point.
+ bin.entry = interp.entry
+ } else {
+ // Always add AT_BASE even if there is no interpreter.
+ bin.auxv = append(bin.auxv, arch.AuxEntry{linux.AT_BASE, 0})
+ }
+
+ return bin, ac, nil
+}
diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go
new file mode 100644
index 000000000..b88062ae5
--- /dev/null
+++ b/pkg/sentry/loader/interpreter.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package loader
+
+import (
+ "bytes"
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const (
+ // interpreterScriptMagic identifies an interpreter script.
+ interpreterScriptMagic = "#!"
+
+ // interpMaxLineLength is the maximum length for the first line of an
+ // interpreter script.
+ //
+ // From execve(2): "A maximum line length of 127 characters is allowed
+ // for the first line in a #! executable shell script."
+ interpMaxLineLength = 127
+)
+
+// parseInterpreterScript returns the interpreter path and argv.
+func parseInterpreterScript(ctx context.Context, filename string, f *fs.File, argv []string) (newpath string, newargv []string, err error) {
+ line := make([]byte, interpMaxLineLength)
+ n, err := readFull(ctx, f, usermem.BytesIOSequence(line), 0)
+ // Short read is OK.
+ if err != nil && err != io.ErrUnexpectedEOF {
+ if err == io.EOF {
+ err = syserror.ENOEXEC
+ }
+ return "", []string{}, err
+ }
+ line = line[:n]
+
+ if !bytes.Equal(line[:2], []byte(interpreterScriptMagic)) {
+ return "", []string{}, syserror.ENOEXEC
+ }
+ // Ignore #!.
+ line = line[2:]
+
+ // Ignore everything after newline.
+ // Linux silently truncates the remainder of the line if it exceeds
+ // interpMaxLineLength.
+ i := bytes.IndexByte(line, '\n')
+ if i > 0 {
+ line = line[:i]
+ }
+
+ // Skip any whitespace before the interpeter.
+ line = bytes.TrimLeft(line, " \t")
+
+ // Linux only looks for spaces or tabs delimiting the interpreter and
+ // arg.
+ //
+ // execve(2): "On Linux, the entire string following the interpreter
+ // name is passed as a single argument to the interpreter, and this
+ // string can include white space."
+ interp := line
+ var arg []byte
+ i = bytes.IndexAny(line, " \t")
+ if i >= 0 {
+ interp = line[:i]
+ arg = bytes.TrimLeft(line[i:], " \t")
+ }
+
+ if string(interp) == "" {
+ ctx.Infof("Interpreter script contains no interpreter: %v", line)
+ return "", []string{}, syserror.ENOEXEC
+ }
+
+ // Build the new argument list:
+ //
+ // 1. The interpreter.
+ newargv = append(newargv, string(interp))
+
+ // 2. The optional interpreter argument.
+ if len(arg) > 0 {
+ newargv = append(newargv, string(arg))
+ }
+
+ // 3. The original arguments. The original argv[0] is replaced with the
+ // full script filename.
+ if len(argv) > 0 {
+ argv[0] = filename
+ } else {
+ argv = []string{filename}
+ }
+ newargv = append(newargv, argv...)
+
+ return string(interp), newargv, nil
+}
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
new file mode 100644
index 000000000..dc1a52398
--- /dev/null
+++ b/pkg/sentry/loader/loader.go
@@ -0,0 +1,283 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package loader loads a binary into a MemoryManager.
+package loader
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "path"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// readFull behaves like io.ReadFull for an *fs.File.
+func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ var total int64
+ for dst.NumBytes() > 0 {
+ n, err := f.Preadv(ctx, dst, offset+total)
+ total += n
+ if err == io.EOF && total != 0 {
+ return total, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return total, err
+ }
+ dst = dst.DropFirst64(n)
+ }
+ return total, nil
+}
+
+// openPath opens name for loading.
+//
+// openPath returns the fs.Dirent and an *fs.File for name, which is not
+// installed in the Task FDMap. The caller takes ownership of both.
+//
+// name must be a readable, executable, regular file.
+func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, name string) (*fs.Dirent, *fs.File, error) {
+ if name == "" {
+ ctx.Infof("cannot open empty name")
+ return nil, nil, syserror.ENOENT
+ }
+
+ d, err := mm.FindInode(ctx, root, wd, name, maxTraversals)
+ if err != nil {
+ return nil, nil, err
+ }
+ defer d.DecRef()
+
+ perms := fs.PermMask{
+ // TODO(gvisor.dev/issue/160): Linux requires only execute
+ // permission, not read. However, our backing filesystems may
+ // prevent us from reading the file without read permission.
+ //
+ // Additionally, a task with a non-readable executable has
+ // additional constraints on access via ptrace and procfs.
+ Read: true,
+ Execute: true,
+ }
+ if err := d.Inode.CheckPermission(ctx, perms); err != nil {
+ return nil, nil, err
+ }
+
+ // If they claim it's a directory, then make sure.
+ //
+ // N.B. we reject directories below, but we must first reject
+ // non-directories passed as directories.
+ if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) {
+ return nil, nil, syserror.ENOTDIR
+ }
+
+ // No exec-ing directories, pipes, etc!
+ if !fs.IsRegular(d.Inode.StableAttr) {
+ ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr)
+ return nil, nil, syserror.EACCES
+ }
+
+ // Create a new file.
+ file, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // We must be able to read at arbitrary offsets.
+ if !file.Flags().Pread {
+ file.DecRef()
+ ctx.Infof("%s cannot be read at an offset: %+v", name, file.Flags())
+ return nil, nil, syserror.EACCES
+ }
+
+ // Grab a reference for the caller.
+ d.IncRef()
+ return d, file, nil
+}
+
+// allocStack allocates and maps a stack in to any available part of the address space.
+func allocStack(ctx context.Context, m *mm.MemoryManager, a arch.Context) (*arch.Stack, error) {
+ ar, err := m.MapStack(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return &arch.Stack{a, m, ar.End}, nil
+}
+
+const (
+ // maxLoaderAttempts is the maximum number of attempts to try to load
+ // an interpreter scripts, to prevent loops. 6 (initial + 5 changes) is
+ // what the Linux kernel allows (fs/exec.c:search_binary_handler).
+ maxLoaderAttempts = 6
+)
+
+// loadPath resolves filename to a binary and loads it.
+//
+// It returns:
+// * loadedELF, description of the loaded binary
+// * arch.Context matching the binary arch
+// * fs.Dirent of the binary file
+// * Possibly updated argv
+func loadPath(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, fs *cpuid.FeatureSet, filename string, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
+ for i := 0; i < maxLoaderAttempts; i++ {
+ d, f, err := openPath(ctx, mounts, root, wd, remainingTraversals, filename)
+ if err != nil {
+ ctx.Infof("Error opening %s: %v", filename, err)
+ return loadedELF{}, nil, nil, nil, err
+ }
+ defer f.DecRef()
+ // We will return d in the successful case, but defer a DecRef
+ // for intermediate loops and failure cases.
+ defer d.DecRef()
+
+ // Check the header. Is this an ELF or interpreter script?
+ var hdr [4]uint8
+ // N.B. We assume that reading from a regular file cannot block.
+ _, err = readFull(ctx, f, usermem.BytesIOSequence(hdr[:]), 0)
+ // Allow unexpected EOF, as a valid executable could be only three
+ // bytes (e.g., #!a).
+ if err != nil && err != io.ErrUnexpectedEOF {
+ if err == io.EOF {
+ err = syserror.ENOEXEC
+ }
+ return loadedELF{}, nil, nil, nil, err
+ }
+
+ switch {
+ case bytes.Equal(hdr[:], []byte(elfMagic)):
+ loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, fs, f)
+ if err != nil {
+ ctx.Infof("Error loading ELF: %v", err)
+ return loadedELF{}, nil, nil, nil, err
+ }
+ // An ELF is always terminal. Hold on to d.
+ d.IncRef()
+ return loaded, ac, d, argv, err
+ case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)):
+ newpath, newargv, err := parseInterpreterScript(ctx, filename, f, argv)
+ if err != nil {
+ ctx.Infof("Error loading interpreter script: %v", err)
+ return loadedELF{}, nil, nil, nil, err
+ }
+ filename = newpath
+ argv = newargv
+ default:
+ ctx.Infof("Unknown magic: %v", hdr)
+ return loadedELF{}, nil, nil, nil, syserror.ENOEXEC
+ }
+ }
+
+ return loadedELF{}, nil, nil, nil, syserror.ELOOP
+}
+
+// Load loads filename into a MemoryManager.
+//
+// If Load returns ErrSwitchFile it should be called again with the returned
+// path and argv.
+//
+// Preconditions:
+// * The Task MemoryManager is empty.
+// * Load is called on the Task goroutine.
+func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
+ // Load the binary itself.
+ loaded, ac, d, argv, err := loadPath(ctx, m, mounts, root, wd, maxTraversals, fs, filename, argv)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", filename, err), syserr.FromError(err).ToLinux())
+ }
+ defer d.DecRef()
+
+ // Load the VDSO.
+ vdsoAddr, err := loadVDSO(ctx, m, vdso, loaded)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ // Setup the heap. brk starts at the next page after the end of the
+ // binary. Userspace can assume that the remainer of the page after
+ // loaded.end is available for its use.
+ e, ok := loaded.end.RoundUp()
+ if !ok {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC)
+ }
+ m.BrkSetup(ctx, e)
+
+ // Allocate our stack.
+ stack, err := allocStack(ctx, m, ac)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to allocate stack: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ // Push the original filename to the stack, for AT_EXECFN.
+ execfn, err := stack.Push(filename)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ // Push 16 random bytes on the stack which AT_RANDOM will point to.
+ var b [16]byte
+ if _, err := rand.Read(b[:]); err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to read random bytes: %v", err), syserr.FromError(err).ToLinux())
+ }
+ random, err := stack.Push(b)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push random bytes: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ c := auth.CredentialsFromContext(ctx)
+
+ // Add generic auxv entries.
+ auxv := append(loaded.auxv, arch.Auxv{
+ arch.AuxEntry{linux.AT_UID, usermem.Addr(c.RealKUID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC},
+ arch.AuxEntry{linux.AT_EXECFN, execfn},
+ arch.AuxEntry{linux.AT_RANDOM, random},
+ arch.AuxEntry{linux.AT_PAGESZ, usermem.PageSize},
+ arch.AuxEntry{linux.AT_SYSINFO_EHDR, vdsoAddr},
+ }...)
+ auxv = append(auxv, extraAuxv...)
+
+ sl, err := stack.Load(argv, envv, auxv)
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load stack: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ m.SetArgvStart(sl.ArgvStart)
+ m.SetArgvEnd(sl.ArgvEnd)
+ m.SetEnvvStart(sl.EnvvStart)
+ m.SetEnvvEnd(sl.EnvvEnd)
+ m.SetAuxv(auxv)
+ m.SetExecutable(d)
+
+ ac.SetIP(uintptr(loaded.entry))
+ ac.SetStack(uintptr(stack.Bottom))
+
+ name := path.Base(filename)
+ if len(name) > linux.TASK_COMM_LEN-1 {
+ name = name[:linux.TASK_COMM_LEN-1]
+ }
+
+ return loaded.os, ac, name, nil
+}
diff --git a/pkg/sentry/loader/loader_state_autogen.go b/pkg/sentry/loader/loader_state_autogen.go
new file mode 100755
index 000000000..351aff778
--- /dev/null
+++ b/pkg/sentry/loader/loader_state_autogen.go
@@ -0,0 +1,57 @@
+// automatically generated by stateify.
+
+package loader
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *VDSO) beforeSave() {}
+func (x *VDSO) save(m state.Map) {
+ x.beforeSave()
+ var phdrs []elfProgHeader = x.savePhdrs()
+ m.SaveValue("phdrs", phdrs)
+ m.Save("ParamPage", &x.ParamPage)
+ m.Save("vdso", &x.vdso)
+ m.Save("os", &x.os)
+ m.Save("arch", &x.arch)
+}
+
+func (x *VDSO) afterLoad() {}
+func (x *VDSO) load(m state.Map) {
+ m.Load("ParamPage", &x.ParamPage)
+ m.Load("vdso", &x.vdso)
+ m.Load("os", &x.os)
+ m.Load("arch", &x.arch)
+ m.LoadValue("phdrs", new([]elfProgHeader), func(y interface{}) { x.loadPhdrs(y.([]elfProgHeader)) })
+}
+
+func (x *elfProgHeader) beforeSave() {}
+func (x *elfProgHeader) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Type", &x.Type)
+ m.Save("Flags", &x.Flags)
+ m.Save("Off", &x.Off)
+ m.Save("Vaddr", &x.Vaddr)
+ m.Save("Paddr", &x.Paddr)
+ m.Save("Filesz", &x.Filesz)
+ m.Save("Memsz", &x.Memsz)
+ m.Save("Align", &x.Align)
+}
+
+func (x *elfProgHeader) afterLoad() {}
+func (x *elfProgHeader) load(m state.Map) {
+ m.Load("Type", &x.Type)
+ m.Load("Flags", &x.Flags)
+ m.Load("Off", &x.Off)
+ m.Load("Vaddr", &x.Vaddr)
+ m.Load("Paddr", &x.Paddr)
+ m.Load("Filesz", &x.Filesz)
+ m.Load("Memsz", &x.Memsz)
+ m.Load("Align", &x.Align)
+}
+
+func init() {
+ state.Register("loader.VDSO", (*VDSO)(nil), state.Fns{Save: (*VDSO).save, Load: (*VDSO).load})
+ state.Register("loader.elfProgHeader", (*elfProgHeader)(nil), state.Fns{Save: (*elfProgHeader).save, Load: (*elfProgHeader).load})
+}
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
new file mode 100644
index 000000000..4e73527cf
--- /dev/null
+++ b/pkg/sentry/loader/vdso.go
@@ -0,0 +1,402 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package loader
+
+import (
+ "debug/elf"
+ "fmt"
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/anon"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+type fileContext struct {
+ context.Context
+}
+
+func (f *fileContext) Value(key interface{}) interface{} {
+ switch key {
+ case uniqueid.CtxGlobalUniqueID:
+ return uint64(0)
+ default:
+ return f.Context.Value(key)
+ }
+}
+
+// byteReader implements fs.FileOperations for reading from a []byte source.
+type byteReader struct {
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ data []byte
+}
+
+var _ fs.FileOperations = (*byteReader)(nil)
+
+// newByteReaderFile creates a fake file to read data from.
+func newByteReaderFile(data []byte) *fs.File {
+ // Create a fake inode.
+ inode := fs.NewInode(
+ &fsutil.SimpleFileInode{},
+ fs.NewPseudoMountSource(),
+ fs.StableAttr{
+ Type: fs.Anonymous,
+ DeviceID: anon.PseudoDevice.DeviceID(),
+ InodeID: anon.PseudoDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ })
+
+ // Use the fake inode to create a fake dirent.
+ dirent := fs.NewTransientDirent(inode)
+ defer dirent.DecRef()
+
+ // Use the fake dirent to make a fake file.
+ flags := fs.FileFlags{Read: true, Pread: true}
+ return fs.NewFile(&fileContext{Context: context.Background()}, dirent, flags, &byteReader{
+ data: data,
+ })
+}
+
+func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset >= int64(len(b.data)) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, b.data[offset:])
+ return int64(n), err
+}
+
+func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ panic("Write not supported")
+}
+
+// validateVDSO checks that the VDSO can be loaded by loadVDSO.
+//
+// VDSOs are special (see below). Since we are going to map the VDSO directly
+// rather than using a normal loading process, we require that the PT_LOAD
+// segments have the same layout in the ELF as they expect to have in memory.
+//
+// Namely, this means that we must verify:
+// * PT_LOAD file offsets are equivalent to the memory offset from the first
+// segment.
+// * No extra zeroed space (memsz) is required.
+// * PT_LOAD segments are in order.
+// * No two PT_LOAD segments occupy parts of the same page.
+// * PT_LOAD segments don't extend beyond the end of the file.
+//
+// ctx may be nil if f does not need it.
+func validateVDSO(ctx context.Context, f *fs.File, size uint64) (elfInfo, error) {
+ info, err := parseHeader(ctx, f)
+ if err != nil {
+ log.Infof("Unable to parse VDSO header: %v", err)
+ return elfInfo{}, err
+ }
+
+ var first *elf.ProgHeader
+ var prev *elf.ProgHeader
+ var prevEnd usermem.Addr
+ for i, phdr := range info.phdrs {
+ if phdr.Type != elf.PT_LOAD {
+ continue
+ }
+
+ if first == nil {
+ first = &info.phdrs[i]
+ if phdr.Off != 0 {
+ log.Warningf("First PT_LOAD segment has non-zero file offset")
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ }
+
+ memoryOffset := phdr.Vaddr - first.Vaddr
+ if memoryOffset != phdr.Off {
+ log.Warningf("PT_LOAD segment memory offset %#x != file offset %#x", memoryOffset, phdr.Off)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ // memsz larger than filesz means that extra zeroed space should be
+ // provided at the end of the segment. Since we are mapping the ELF
+ // directly, we don't want to just overwrite part of the ELF with
+ // zeroes.
+ if phdr.Memsz != phdr.Filesz {
+ log.Warningf("PT_LOAD segment memsz %#x != filesz %#x", phdr.Memsz, phdr.Filesz)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ start := usermem.Addr(memoryOffset)
+ end, ok := start.AddLength(phdr.Memsz)
+ if !ok {
+ log.Warningf("PT_LOAD segment size overflows: %#x + %#x", start, end)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ if uint64(end) > size {
+ log.Warningf("PT_LOAD segment end %#x extends beyond end of file %#x", end, size)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ if prev != nil {
+ if start < prevEnd {
+ log.Warningf("PT_LOAD segments out of order")
+ return elfInfo{}, syserror.ENOEXEC
+ }
+
+ // We mprotect entire pages, so each segment must be in
+ // its own page.
+ prevEndPage := prevEnd.RoundDown()
+ startPage := start.RoundDown()
+ if prevEndPage >= startPage {
+ log.Warningf("PT_LOAD segments share a page: %#x", prevEndPage)
+ return elfInfo{}, syserror.ENOEXEC
+ }
+ }
+ prev = &info.phdrs[i]
+ prevEnd = end
+ }
+
+ return info, nil
+}
+
+// VDSO describes a VDSO.
+//
+// NOTE(mpratt): to support multiple architectures or operating systems, this
+// would need to contain a VDSO for each.
+//
+// +stateify savable
+type VDSO struct {
+ // ParamPage is the VDSO parameter page. This page should be updated to
+ // inform the VDSO for timekeeping data.
+ ParamPage *mm.SpecialMappable
+
+ // vdso is the VDSO ELF itself.
+ vdso *mm.SpecialMappable
+
+ // os is the operating system targeted by the VDSO.
+ os abi.OS
+
+ // arch is the architecture targeted by the VDSO.
+ arch arch.Arch
+
+ // phdrs are the VDSO ELF phdrs.
+ phdrs []elf.ProgHeader `state:".([]elfProgHeader)"`
+}
+
+// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
+// param page for updating by the kernel.
+func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
+ vdsoFile := newByteReaderFile(vdsoBin)
+
+ // First make sure the VDSO is valid. vdsoFile does not use ctx, so a
+ // nil context can be passed.
+ info, err := validateVDSO(nil, vdsoFile, uint64(len(vdsoBin)))
+ vdsoFile.DecRef()
+ if err != nil {
+ return nil, err
+ }
+
+ // Then copy it into a VDSO mapping.
+ size, ok := usermem.Addr(len(vdsoBin)).RoundUp()
+ if !ok {
+ return nil, fmt.Errorf("VDSO size overflows? %#x", len(vdsoBin))
+ }
+
+ mf := mfp.MemoryFile()
+ vdso, err := mf.Allocate(uint64(size), usage.System)
+ if err != nil {
+ return nil, fmt.Errorf("unable to allocate VDSO memory: %v", err)
+ }
+
+ ims, err := mf.MapInternal(vdso, usermem.ReadWrite)
+ if err != nil {
+ mf.DecRef(vdso)
+ return nil, fmt.Errorf("unable to map VDSO memory: %v", err)
+ }
+
+ _, err = safemem.CopySeq(ims, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(vdsoBin)))
+ if err != nil {
+ mf.DecRef(vdso)
+ return nil, fmt.Errorf("unable to copy VDSO into memory: %v", err)
+ }
+
+ // Finally, allocate a param page for this VDSO.
+ paramPage, err := mf.Allocate(usermem.PageSize, usage.System)
+ if err != nil {
+ mf.DecRef(vdso)
+ return nil, fmt.Errorf("unable to allocate VDSO param page: %v", err)
+ }
+
+ return &VDSO{
+ ParamPage: mm.NewSpecialMappable("[vvar]", mfp, paramPage),
+ // TODO(gvisor.dev/issue/157): Don't advertise the VDSO, as
+ // some applications may not be able to handle multiple [vdso]
+ // hints.
+ vdso: mm.NewSpecialMappable("", mfp, vdso),
+ phdrs: info.phdrs,
+ }, nil
+}
+
+// loadVDSO loads the VDSO into m.
+//
+// VDSOs are special.
+//
+// VDSOs are fully position independent. However, instead of loading a VDSO
+// like a normal ELF binary, mapping only the PT_LOAD segments, the Linux
+// kernel simply directly maps the entire file into process memory, with very
+// little real ELF parsing.
+//
+// NOTE(b/25323870): This means that userspace can, and unfortunately does,
+// depend on parts of the ELF that would normally not be mapped. To maintain
+// compatibility with such binaries, we load the VDSO much like Linux.
+//
+// loadVDSO takes a reference on the VDSO and parameter page FrameRegions.
+func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) (usermem.Addr, error) {
+ if v.os != bin.os {
+ ctx.Warningf("Binary ELF OS %v and VDSO ELF OS %v differ", bin.os, v.os)
+ return 0, syserror.ENOEXEC
+ }
+ if v.arch != bin.arch {
+ ctx.Warningf("Binary ELF arch %v and VDSO ELF arch %v differ", bin.arch, v.arch)
+ return 0, syserror.ENOEXEC
+ }
+
+ // Reserve address space for the VDSO and its parameter page, which is
+ // mapped just before the VDSO.
+ mapSize := v.vdso.Length() + v.ParamPage.Length()
+ addr, err := m.MMap(ctx, memmap.MMapOpts{
+ Length: mapSize,
+ Private: true,
+ })
+ if err != nil {
+ ctx.Infof("Unable to reserve VDSO address space: %v", err)
+ return 0, err
+ }
+
+ // Now map the param page.
+ _, err = m.MMap(ctx, memmap.MMapOpts{
+ Length: v.ParamPage.Length(),
+ MappingIdentity: v.ParamPage,
+ Mappable: v.ParamPage,
+ Addr: addr,
+ Fixed: true,
+ Unmap: true,
+ Private: true,
+ Perms: usermem.Read,
+ MaxPerms: usermem.Read,
+ })
+ if err != nil {
+ ctx.Infof("Unable to map VDSO param page: %v", err)
+ return 0, err
+ }
+
+ // Now map the VDSO itself.
+ vdsoAddr, ok := addr.AddLength(v.ParamPage.Length())
+ if !ok {
+ panic(fmt.Sprintf("Part of mapped range overflows? %#x + %#x", addr, v.ParamPage.Length()))
+ }
+ _, err = m.MMap(ctx, memmap.MMapOpts{
+ Length: v.vdso.Length(),
+ MappingIdentity: v.vdso,
+ Mappable: v.vdso,
+ Addr: vdsoAddr,
+ Fixed: true,
+ Unmap: true,
+ Private: true,
+ Perms: usermem.Read,
+ MaxPerms: usermem.AnyAccess,
+ })
+ if err != nil {
+ ctx.Infof("Unable to map VDSO: %v", err)
+ return 0, err
+ }
+
+ vdsoEnd, ok := vdsoAddr.AddLength(v.vdso.Length())
+ if !ok {
+ panic(fmt.Sprintf("VDSO mapping overflows? %#x + %#x", vdsoAddr, v.vdso.Length()))
+ }
+
+ // Set additional protections for the individual segments.
+ var first *elf.ProgHeader
+ for i, phdr := range v.phdrs {
+ if phdr.Type != elf.PT_LOAD {
+ continue
+ }
+
+ if first == nil {
+ first = &v.phdrs[i]
+ }
+
+ memoryOffset := phdr.Vaddr - first.Vaddr
+ segAddr, ok := vdsoAddr.AddLength(memoryOffset)
+ if !ok {
+ ctx.Warningf("PT_LOAD segment address overflows: %#x + %#x", segAddr, memoryOffset)
+ return 0, syserror.ENOEXEC
+ }
+ segPage := segAddr.RoundDown()
+ segSize := usermem.Addr(phdr.Memsz)
+ segSize, ok = segSize.AddLength(segAddr.PageOffset())
+ if !ok {
+ ctx.Warningf("PT_LOAD segment memsize %#x + offset %#x overflows", phdr.Memsz, segAddr.PageOffset())
+ return 0, syserror.ENOEXEC
+ }
+ segSize, ok = segSize.RoundUp()
+ if !ok {
+ ctx.Warningf("PT_LOAD segment size overflows: %#x", phdr.Memsz+segAddr.PageOffset())
+ return 0, syserror.ENOEXEC
+ }
+ segEnd, ok := segPage.AddLength(uint64(segSize))
+ if !ok {
+ ctx.Warningf("PT_LOAD segment range overflows: %#x + %#x", segAddr, segSize)
+ return 0, syserror.ENOEXEC
+ }
+ if segEnd > vdsoEnd {
+ ctx.Warningf("PT_LOAD segment ends beyond VDSO: %#x > %#x", segEnd, vdsoEnd)
+ return 0, syserror.ENOEXEC
+ }
+
+ perms := progFlagsAsPerms(phdr.Flags)
+ if perms != usermem.Read {
+ if err := m.MProtect(segPage, uint64(segSize), perms, false); err != nil {
+ ctx.Warningf("Unable to set PT_LOAD segment protections %+v at [%#x, %#x): %v", perms, segAddr, segEnd, err)
+ return 0, syserror.ENOEXEC
+ }
+ }
+ }
+
+ return vdsoAddr, nil
+}
diff --git a/pkg/sentry/loader/vdso_bin.go b/pkg/sentry/loader/vdso_bin.go
new file mode 100755
index 000000000..cf351f9dc
--- /dev/null
+++ b/pkg/sentry/loader/vdso_bin.go
@@ -0,0 +1,5 @@
+// Generated by go_embed_data for //pkg/sentry/loader:vdso_bin. DO NOT EDIT.
+
+package loader
+
+var vdsoBin = []byte("ELF\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00>\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x008\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00p\xff\xff\xff\xff\xff\x96\x00\x00\x00\x00\x00\x00\x96\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00x\x00\x00\x00\x00\x00\x00xp\xff\xff\xff\xff\xffxp\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00$\x00\x00\x00\x00\x00\x00$p\xff\xff\xff\xff\xff$p\xff\xff\xff\xff\xff@\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00P\xe5td\x00\x00\x00d\x00\x00\x00\x00\x00\x00dp\xff\xff\xff\xff\xffdp\xff\xff\xff\xff\xff<\x00\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00@p\xff\xff\xff\xff\xff(\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00 \x00pp\xff\xff\xff\xff\xffl\x00\x00\x00\x00\x00\x00\x00D\x00\x00\x00\"\x00 \x00@p\xff\xff\xff\xff\xff(\x00\x00\x00\x00\x00\x00\x00R\x00\x00\x00\x00 \x00\xe0p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00^\x00\x00\x00\"\x00 \x00pp\xff\xff\xff\xff\xffl\x00\x00\x00\x00\x00\x00\x00k\x00\x00\x00\x00 \x00p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00y\x00\x00\x00\"\x00 \x00\xe0p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00~\x00\x00\x00\"\x00 \x00p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00linux-vdso.so.1\x00LINUX_2.6\x00__vdso_clock_gettime\x00__vdso_gettimeofday\x00clock_gettime\x00__vdso_time\x00gettimeofday\x00__vdso_getcpu\x00time\x00getcpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa1\xbf\xee \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf6u\xae\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\x00\x00GNU\x00gold 1.11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00GNU\x00\x8e\xf8\x94\x9b\xe5\xdd[.\xba\xd9T\xc5\xdcFc.Jd;8\x00\x00\x00\x00\x00\x00\xdc \x00\x00T\x00\x00\x00 \x00\x00l\x00\x00\x00| \x00\x00\x9c\x00\x00\x00\xac \x00\x00\xbc\x00\x00\x00\xbc \x00\x00\xd4\x00\x00\x00|\x00\x00\xf4\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00zR\x00x \x90\x00\x00\x00\x00\x00\x00\x00\x00\x80 \x00\x00(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,\x00\x00\x004\x00\x00\x00\x98 \x00\x00l\x00\x00\x00\x00A\x86A\x83G0X\nAAE \x00\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\xd8 \x00\x00\"\x00\x00\x00\x00A\x83G XA\x00\x00\x00\x00\x84\x00\x00\x00\xe8 \x00\x00\n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x9c\x00\x00\x00\xe0 \x00\x00\xb3\x00\x00\x00\x00A\x83\nh J\x00\x00\x00\xbc\x00\x00\x00\x80 \x00\x00\xb6\x00\x00\x00\x00A\x83\nh M\x00\x00\x00\x00\x00\x00\x00`p\xff\xff\xff\xff\xff \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\xff\xffo\x00\x00\x00\x00\xd6p\xff\xff\xff\xff\xff\xfc\xff\xffo\x00\x00\x00\x00\xecp\xff\xff\xff\xff\xff\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x85\xfft\x83\xffuH\x89\xf7\xe9\x8f\x00\x00\x80\x00\x00\x00\x00\xb8\xe4\x00\x00\x00\xc3H\x89\xf7\xe9\xb8\x00\x00\x00\x84\x00\x00\x00\x00\x00USH\x89\xf3H\x83\xecH\x85\xfft;H\x89\xfdH\x89\xe7\xe8\x97\x00\x00\x00\x85\xc0u@H\x8b$H\x8bL$H\xba\xcf\xf7S㥛\xc4 H\x89E\x00H\x89\xc8H\xc1\xf9?H\xf7\xeaH\xc1\xfaH)\xcaH\x89UH\x85\xdbt\xc7\x00\x00\x00\x00\xc7C\x00\x00\x00\x001\xc0H\x83\xc4[]\xc3@\x001\xc0\xeb\xf1@\x00SH\x89\xfbH\x83\xecH\x89\xe7\xe80\x00\x00\x00H\x85\xdbH\x8b$tH\x89H\x83\xc4[\xc3@\x00f.\x84\x00\x00\x00\x00\x00\xb85\x00\x00H\x98Ð\x90\x90\x90\x90\x90SH\x89\xfeH\x8d \xd5\xde\xff\xffH\x8b9\x83\xe7\xfeH\x8bY(L\x8bA8Lc\xcfL\x8bY0L\x8bQ@\xae\xe81H\x8b9L9\xcfu\xddH\x85\xdbtrH\x89щ\xc0H\xc1\xe1 H \xc11\xc0I9\xcb(H\xb8\x00\x00\x00\x00\x00ʚ;1\xd2L)\xd9I\xf7\xf2H\x89\xcfH\xc1\xff?H\xaf\xf8H\xf7\xe1H\xfaH\xac\xd0 I\x8d<\x00H\xb9SZ\x9b\xa0/\xb8D\x00[H\x89\xf8H\xc1\xe8 H\xf7\xe1H\x89\xd0H\xc1\xe8 H\x89Hi\xc0\x00ʚ;H)\xc71\xc0H\x89~\xc3\x00\xb8\xe4\x00\x00\x001\xff[\xc3\x00f.\x84\x00\x00\x00\x00\x00SH\x89\xfeH\x8d \xde\xff\xffH\x8b9\x83\xe7\xfeH\x8bYL\x8bALc\xcfL\x8bYL\x8bQ \xae\xe81H\x8b9L9\xcfu\xddH\x85\xdbtrH\x89щ\xc0H\xc1\xe1 H \xc11\xc0I9\xcb(H\xb8\x00\x00\x00\x00\x00ʚ;1\xd2L)\xd9I\xf7\xf2H\x89\xcfH\xc1\xff?H\xaf\xf8H\xf7\xe1H\xfaH\xac\xd0 I\x8d<\x00H\xb9SZ\x9b\xa0/\xb8D\x00[H\x89\xf8H\xc1\xe8 H\xf7\xe1H\x89\xd0H\xc1\xe8 H\x89Hi\xc0\x00ʚ;H)\xc71\xc0H\x89~\xc3\x00\xb8\xe4\x00\x00\x00\xbf\x00\x00\x00[\xc3\x00GCC: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00 p\xff\xff\xff\xff\xff\xb3\x00\x00\x00\x00\x00\x00\x009\x00\x00\x00\x00 \x00\xe0p\xff\xff\xff\xff\xff\xb6\x00\x00\x00\x00\x00\x00\x00]\x00\x00\x00 \x00xp\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00f\x00\x00\x00\x00\x00\xf1\xff\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00s\x00\x00\x00\x00\x00\xf1\xff\x00\xf0o\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00{\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00 \x00@p\xff\xff\xff\xff\xff(\x00\x00\x00\x00\x00\x00\x00\x9a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xb0\x00\x00\x00\x00 \x00pp\xff\xff\xff\xff\xffl\x00\x00\x00\x00\x00\x00\x00\xc4\x00\x00\x00\"\x00 \x00@p\xff\xff\xff\xff\xff(\x00\x00\x00\x00\x00\x00\x00\xd2\x00\x00\x00\x00 \x00\xe0p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00\xde\x00\x00\x00\"\x00 \x00pp\xff\xff\xff\xff\xffl\x00\x00\x00\x00\x00\x00\x00\xeb\x00\x00\x00\x00 \x00p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\xf9\x00\x00\x00\"\x00 \x00\xe0p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00\xfe\x00\x00\x00\"\x00 \x00p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x00vdso.cc\x00vdso_time.cc\x00_ZN4vdso13ClockRealtimeEP8timespec\x00_ZN4vdso14ClockMonotonicEP8timespec\x00_DYNAMIC\x00VDSO_PRELINK\x00_params\x00LINUX_2.6\x00__vdso_clock_gettime\x00_GLOBAL_OFFSET_TABLE_\x00__vdso_gettimeofday\x00clock_gettime\x00__vdso_time\x00gettimeofday\x00__vdso_getcpu\x00time\x00getcpu\x00\x00.text\x00.comment\x00.bss\x00.dynstr\x00.eh_frame_hdr\x00.gnu.version\x00.dynsym\x00.hash\x00.note\x00.eh_frame\x00.gnu.version_d\x00.dynamic\x00.shstrtab\x00.strtab\x00.symtab\x00.data\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff \x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x008\x00\x00\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00`p\xff\xff\xff\xff\xff`\x00\x00\x00\x00\x00\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xffP\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00+\x00\x00\x00\xff\xff\xffo\x00\x00\x00\x00\x00\x00\x00\xd6p\xff\xff\xff\xff\xff\xd6\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00V\x00\x00\x00\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00\xecp\xff\xff\xff\xff\xff\xec\x00\x00\x00\x00\x00\x008\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00F\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00$p\xff\xff\xff\xff\xff$\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00dp\xff\xff\xff\xff\xffd\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00L\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa0p\xff\xff\xff\xff\xff\xa0\x00\x00\x00\x00\x00\x00\xd8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00xp\xff\xff\xff\xff\xffx\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x88\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff@\x00\x00\x00\x00\x00\x00V\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x96\x00\x00\x00\x00\x00\x006\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xd0\x00\x00\x00\x00\x00\x00\xb0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00x\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x00\x00\x8e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")
diff --git a/pkg/sentry/loader/vdso_state.go b/pkg/sentry/loader/vdso_state.go
new file mode 100644
index 000000000..db378e90a
--- /dev/null
+++ b/pkg/sentry/loader/vdso_state.go
@@ -0,0 +1,48 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package loader
+
+import (
+ "debug/elf"
+)
+
+// +stateify savable
+type elfProgHeader struct {
+ Type elf.ProgType
+ Flags elf.ProgFlag
+ Off uint64
+ Vaddr uint64
+ Paddr uint64
+ Filesz uint64
+ Memsz uint64
+ Align uint64
+}
+
+// savePhdrs is invoked by stateify.
+func (v *VDSO) savePhdrs() []elfProgHeader {
+ s := make([]elfProgHeader, 0, len(v.phdrs))
+ for _, h := range v.phdrs {
+ s = append(s, elfProgHeader(h))
+ }
+ return s
+}
+
+// loadPhdrs is invoked by stateify.
+func (v *VDSO) loadPhdrs(s []elfProgHeader) {
+ v.phdrs = make([]elf.ProgHeader, 0, len(s))
+ for _, h := range s {
+ v.phdrs = append(v.phdrs, elf.ProgHeader(h))
+ }
+}
diff --git a/pkg/sentry/memmap/mappable_range.go b/pkg/sentry/memmap/mappable_range.go
new file mode 100755
index 000000000..6b6c2c685
--- /dev/null
+++ b/pkg/sentry/memmap/mappable_range.go
@@ -0,0 +1,62 @@
+package memmap
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type MappableRange struct {
+ // Start is the inclusive start of the range.
+ Start uint64
+
+ // End is the exclusive end of the range.
+ End uint64
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r MappableRange) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r MappableRange) Length() uint64 {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r MappableRange) Contains(x uint64) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r MappableRange) Overlaps(r2 MappableRange) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r MappableRange) IsSupersetOf(r2 MappableRange) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+func (r MappableRange) Intersect(r2 MappableRange) MappableRange {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+func (r MappableRange) CanSplitAt(x uint64) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/sentry/memmap/mapping_set.go b/pkg/sentry/memmap/mapping_set.go
new file mode 100644
index 000000000..3cf2b338f
--- /dev/null
+++ b/pkg/sentry/memmap/mapping_set.go
@@ -0,0 +1,253 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memmap
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// MappingSet maps offsets into a Mappable to mappings of those offsets. It is
+// used to implement Mappable.AddMapping and RemoveMapping for Mappables that
+// may need to call MappingSpace.Invalidate.
+//
+// type MappingSet <generated by go_generics>
+
+// MappingsOfRange is the value type of MappingSet, and represents the set of
+// all mappings of the corresponding MappableRange.
+//
+// Using a map offers O(1) lookups in RemoveMapping and
+// mappingSetFunctions.Merge.
+type MappingsOfRange map[MappingOfRange]struct{}
+
+// MappingOfRange represents a mapping of a MappableRange.
+//
+// +stateify savable
+type MappingOfRange struct {
+ MappingSpace MappingSpace
+ AddrRange usermem.AddrRange
+ Writable bool
+}
+
+func (r MappingOfRange) invalidate(opts InvalidateOpts) {
+ r.MappingSpace.Invalidate(r.AddrRange, opts)
+}
+
+// String implements fmt.Stringer.String.
+func (r MappingOfRange) String() string {
+ return fmt.Sprintf("%#v", r.AddrRange)
+}
+
+// mappingSetFunctions implements segment.Functions for MappingSet.
+type mappingSetFunctions struct{}
+
+// MinKey implements segment.Functions.MinKey.
+func (mappingSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+// MaxKey implements segment.Functions.MaxKey.
+func (mappingSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+// ClearValue implements segment.Functions.ClearValue.
+func (mappingSetFunctions) ClearValue(v *MappingsOfRange) {
+ *v = MappingsOfRange{}
+}
+
+// Merge implements segment.Functions.Merge.
+//
+// Since each value is a map of MappingOfRanges, values can only be merged if
+// all MappingOfRanges in each map have an exact pair in the other map, forming
+// one contiguous region.
+func (mappingSetFunctions) Merge(r1 MappableRange, val1 MappingsOfRange, r2 MappableRange, val2 MappingsOfRange) (MappingsOfRange, bool) {
+ if len(val1) != len(val2) {
+ return nil, false
+ }
+
+ merged := make(MappingsOfRange, len(val1))
+
+ // Each MappingOfRange in val1 must have a matching region in val2, forming
+ // one contiguous region.
+ for k1 := range val1 {
+ // We expect val2 to to contain a key that forms a contiguous
+ // region with k1.
+ k2 := MappingOfRange{
+ MappingSpace: k1.MappingSpace,
+ AddrRange: usermem.AddrRange{
+ Start: k1.AddrRange.End,
+ End: k1.AddrRange.End + usermem.Addr(r2.Length()),
+ },
+ Writable: k1.Writable,
+ }
+ if _, ok := val2[k2]; !ok {
+ return nil, false
+ }
+
+ // OK. Add it to the merged map.
+ merged[MappingOfRange{
+ MappingSpace: k1.MappingSpace,
+ AddrRange: usermem.AddrRange{
+ Start: k1.AddrRange.Start,
+ End: k2.AddrRange.End,
+ },
+ Writable: k1.Writable,
+ }] = struct{}{}
+ }
+
+ return merged, true
+}
+
+// Split implements segment.Functions.Split.
+func (mappingSetFunctions) Split(r MappableRange, val MappingsOfRange, split uint64) (MappingsOfRange, MappingsOfRange) {
+ if split <= r.Start || split >= r.End {
+ panic(fmt.Sprintf("split is not within range %v", r))
+ }
+
+ m1 := make(MappingsOfRange, len(val))
+ m2 := make(MappingsOfRange, len(val))
+
+ // split is a value in MappableRange, we need the offset into the
+ // corresponding MappingsOfRange.
+ offset := usermem.Addr(split - r.Start)
+ for k := range val {
+ k1 := MappingOfRange{
+ MappingSpace: k.MappingSpace,
+ AddrRange: usermem.AddrRange{
+ Start: k.AddrRange.Start,
+ End: k.AddrRange.Start + offset,
+ },
+ Writable: k.Writable,
+ }
+ m1[k1] = struct{}{}
+
+ k2 := MappingOfRange{
+ MappingSpace: k.MappingSpace,
+ AddrRange: usermem.AddrRange{
+ Start: k.AddrRange.Start + offset,
+ End: k.AddrRange.End,
+ },
+ Writable: k.Writable,
+ }
+ m2[k2] = struct{}{}
+ }
+
+ return m1, m2
+}
+
+// subsetMapping returns the MappingOfRange that maps subsetRange, given that
+// ms maps wholeRange beginning at addr.
+//
+// For instance, suppose wholeRange = [0x0, 0x2000) and addr = 0x4000,
+// indicating that ms maps addresses [0x4000, 0x6000) to MappableRange [0x0,
+// 0x2000). Then for subsetRange = [0x1000, 0x2000), subsetMapping returns a
+// MappingOfRange for which AddrRange = [0x5000, 0x6000).
+func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr usermem.Addr, writable bool) MappingOfRange {
+ if !wholeRange.IsSupersetOf(subsetRange) {
+ panic(fmt.Sprintf("%v is not a superset of %v", wholeRange, subsetRange))
+ }
+
+ offset := subsetRange.Start - wholeRange.Start
+ start := addr + usermem.Addr(offset)
+ return MappingOfRange{
+ MappingSpace: ms,
+ AddrRange: usermem.AddrRange{
+ Start: start,
+ End: start + usermem.Addr(subsetRange.Length()),
+ },
+ Writable: writable,
+ }
+}
+
+// AddMapping adds the given mapping and returns the set of MappableRanges that
+// previously had no mappings.
+//
+// Preconditions: As for Mappable.AddMapping.
+func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange {
+ mr := MappableRange{offset, offset + uint64(ar.Length())}
+ var mapped []MappableRange
+ seg, gap := s.Find(mr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < mr.End:
+ seg = s.Isolate(seg, mr)
+ seg.Value()[subsetMapping(mr, seg.Range(), ms, ar.Start, writable)] = struct{}{}
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok() && gap.Start() < mr.End:
+ gapMR := gap.Range().Intersect(mr)
+ mapped = append(mapped, gapMR)
+ // Insert a set and continue from the above case.
+ seg, gap = s.Insert(gap, gapMR, make(MappingsOfRange)), MappingGapIterator{}
+
+ default:
+ return mapped
+ }
+ }
+}
+
+// RemoveMapping removes the given mapping and returns the set of
+// MappableRanges that now have no mappings.
+//
+// Preconditions: As for Mappable.RemoveMapping.
+func (s *MappingSet) RemoveMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange {
+ mr := MappableRange{offset, offset + uint64(ar.Length())}
+ var unmapped []MappableRange
+
+ seg := s.FindSegment(mr.Start)
+ if !seg.Ok() {
+ panic(fmt.Sprintf("MappingSet.RemoveMapping(%v): no segment containing %#x: %v", mr, mr.Start, s))
+ }
+ for seg.Ok() && seg.Start() < mr.End {
+ // Ensure this segment is limited to our range.
+ seg = s.Isolate(seg, mr)
+
+ // Remove this part of the mapping.
+ mappings := seg.Value()
+ delete(mappings, subsetMapping(mr, seg.Range(), ms, ar.Start, writable))
+
+ if len(mappings) == 0 {
+ unmapped = append(unmapped, seg.Range())
+ seg = s.Remove(seg).NextSegment()
+ } else {
+ seg = seg.NextSegment()
+ }
+ }
+ s.MergeAdjacent(mr)
+ return unmapped
+}
+
+// Invalidate calls MappingSpace.Invalidate for all mappings of offsets in mr.
+func (s *MappingSet) Invalidate(mr MappableRange, opts InvalidateOpts) {
+ for seg := s.LowerBoundSegment(mr.Start); seg.Ok() && seg.Start() < mr.End; seg = seg.NextSegment() {
+ segMR := seg.Range()
+ for m := range seg.Value() {
+ region := subsetMapping(segMR, segMR.Intersect(mr), m.MappingSpace, m.AddrRange.Start, m.Writable)
+ region.invalidate(opts)
+ }
+ }
+}
+
+// InvalidateAll calls MappingSpace.Invalidate for all mappings of s.
+func (s *MappingSet) InvalidateAll(opts InvalidateOpts) {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ for m := range seg.Value() {
+ m.invalidate(opts)
+ }
+ }
+}
diff --git a/pkg/sentry/memmap/mapping_set_impl.go b/pkg/sentry/memmap/mapping_set_impl.go
new file mode 100755
index 000000000..eb3071e89
--- /dev/null
+++ b/pkg/sentry/memmap/mapping_set_impl.go
@@ -0,0 +1,1270 @@
+package memmap
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ MappingminDegree = 3
+
+ MappingmaxDegree = 2 * MappingminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type MappingSet struct {
+ root Mappingnode `state:".(*MappingSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *MappingSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *MappingSet) IsEmptyRange(r MappableRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *MappingSet) Span() uint64 {
+ var sz uint64
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *MappingSet) SpanRange(r MappableRange) uint64 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint64
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *MappingSet) FirstSegment() MappingIterator {
+ if s.root.nrSegments == 0 {
+ return MappingIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *MappingSet) LastSegment() MappingIterator {
+ if s.root.nrSegments == 0 {
+ return MappingIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *MappingSet) FirstGap() MappingGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return MappingGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *MappingSet) LastGap() MappingGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return MappingGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *MappingSet) Find(key uint64) (MappingIterator, MappingGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return MappingIterator{n, i}, MappingGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return MappingIterator{}, MappingGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *MappingSet) FindSegment(key uint64) MappingIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *MappingSet) LowerBoundSegment(min uint64) MappingIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *MappingSet) UpperBoundSegment(max uint64) MappingIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *MappingSet) FindGap(key uint64) MappingGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *MappingSet) LowerBoundGap(min uint64) MappingGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *MappingSet) UpperBoundGap(max uint64) MappingGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *MappingSet) Add(r MappableRange, val MappingsOfRange) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *MappingSet) AddWithoutMerging(r MappableRange, val MappingsOfRange) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *MappingSet) Insert(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (mappingSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *MappingSet) InsertWithoutMerging(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *MappingSet) InsertWithoutMergingUnchecked(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return MappingIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *MappingSet) Remove(seg MappingIterator) MappingGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ mappingSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(MappingGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *MappingSet) RemoveAll() {
+ s.root = Mappingnode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *MappingSet) RemoveRange(r MappableRange) MappingGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *MappingSet) Merge(first, second MappingIterator) MappingIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *MappingSet) MergeUnchecked(first, second MappingIterator) MappingIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (mappingSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return MappingIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *MappingSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *MappingSet) MergeRange(r MappableRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *MappingSet) MergeAdjacent(r MappableRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *MappingSet) Split(seg MappingIterator, split uint64) (MappingIterator, MappingIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *MappingSet) SplitUnchecked(seg MappingIterator, split uint64) (MappingIterator, MappingIterator) {
+ val1, val2 := (mappingSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), MappableRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *MappingSet) SplitAt(split uint64) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *MappingSet) Isolate(seg MappingIterator, r MappableRange) MappingIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *MappingSet) ApplyContiguous(r MappableRange, fn func(seg MappingIterator)) MappingGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return MappingGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return MappingGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type Mappingnode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *Mappingnode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [MappingmaxDegree - 1]MappableRange
+ values [MappingmaxDegree - 1]MappingsOfRange
+ children [MappingmaxDegree]*Mappingnode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *Mappingnode) firstSegment() MappingIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return MappingIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *Mappingnode) lastSegment() MappingIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return MappingIterator{n, n.nrSegments - 1}
+}
+
+func (n *Mappingnode) prevSibling() *Mappingnode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *Mappingnode) nextSibling() *Mappingnode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *Mappingnode) rebalanceBeforeInsert(gap MappingGapIterator) MappingGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < MappingmaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &Mappingnode{
+ nrSegments: MappingminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &Mappingnode{
+ nrSegments: MappingminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:MappingminDegree-1], n.keys[:MappingminDegree-1])
+ copy(left.values[:MappingminDegree-1], n.values[:MappingminDegree-1])
+ copy(right.keys[:MappingminDegree-1], n.keys[MappingminDegree:])
+ copy(right.values[:MappingminDegree-1], n.values[MappingminDegree:])
+ n.keys[0], n.values[0] = n.keys[MappingminDegree-1], n.values[MappingminDegree-1]
+ MappingzeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:MappingminDegree], n.children[:MappingminDegree])
+ copy(right.children[:MappingminDegree], n.children[MappingminDegree:])
+ MappingzeroNodeSlice(n.children[2:])
+ for i := 0; i < MappingminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < MappingminDegree {
+ return MappingGapIterator{left, gap.index}
+ }
+ return MappingGapIterator{right, gap.index - MappingminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[MappingminDegree-1], n.values[MappingminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &Mappingnode{
+ nrSegments: MappingminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:MappingminDegree-1], n.keys[MappingminDegree:])
+ copy(sibling.values[:MappingminDegree-1], n.values[MappingminDegree:])
+ MappingzeroValueSlice(n.values[MappingminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:MappingminDegree], n.children[MappingminDegree:])
+ MappingzeroNodeSlice(n.children[MappingminDegree:])
+ for i := 0; i < MappingminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = MappingminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < MappingminDegree {
+ return gap
+ }
+ return MappingGapIterator{sibling, gap.index - MappingminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *Mappingnode) rebalanceAfterRemove(gap MappingGapIterator) MappingGapIterator {
+ for {
+ if n.nrSegments >= MappingminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= MappingminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ mappingSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return MappingGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return MappingGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= MappingminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ mappingSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return MappingGapIterator{n, n.nrSegments}
+ }
+ return MappingGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return MappingGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return MappingGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *Mappingnode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = MappingGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ mappingSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type MappingIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *Mappingnode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg MappingIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg MappingIterator) Range() MappableRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg MappingIterator) Start() uint64 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg MappingIterator) End() uint64 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg MappingIterator) SetRangeUnchecked(r MappableRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg MappingIterator) SetRange(r MappableRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg MappingIterator) SetStartUnchecked(start uint64) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg MappingIterator) SetStart(start uint64) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg MappingIterator) SetEndUnchecked(end uint64) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg MappingIterator) SetEnd(end uint64) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg MappingIterator) Value() MappingsOfRange {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg MappingIterator) ValuePtr() *MappingsOfRange {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg MappingIterator) SetValue(val MappingsOfRange) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg MappingIterator) PrevSegment() MappingIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return MappingIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return MappingIterator{}
+ }
+ return MappingsegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg MappingIterator) NextSegment() MappingIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return MappingIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return MappingIterator{}
+ }
+ return MappingsegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg MappingIterator) PrevGap() MappingGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return MappingGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg MappingIterator) NextGap() MappingGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return MappingGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg MappingIterator) PrevNonEmpty() (MappingIterator, MappingGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return MappingIterator{}, gap
+ }
+ return gap.PrevSegment(), MappingGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg MappingIterator) NextNonEmpty() (MappingIterator, MappingGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return MappingIterator{}, gap
+ }
+ return gap.NextSegment(), MappingGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type MappingGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *Mappingnode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap MappingGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap MappingGapIterator) Range() MappableRange {
+ return MappableRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap MappingGapIterator) Start() uint64 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return mappingSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap MappingGapIterator) End() uint64 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return mappingSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap MappingGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap MappingGapIterator) PrevSegment() MappingIterator {
+ return MappingsegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap MappingGapIterator) NextSegment() MappingIterator {
+ return MappingsegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap MappingGapIterator) PrevGap() MappingGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return MappingGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap MappingGapIterator) NextGap() MappingGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return MappingGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func MappingsegmentBeforePosition(n *Mappingnode, i int) MappingIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return MappingIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return MappingIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func MappingsegmentAfterPosition(n *Mappingnode, i int) MappingIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return MappingIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return MappingIterator{n, i}
+}
+
+func MappingzeroValueSlice(slice []MappingsOfRange) {
+
+ for i := range slice {
+ mappingSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func MappingzeroNodeSlice(slice []*Mappingnode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *MappingSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *Mappingnode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *Mappingnode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type MappingSegmentDataSlices struct {
+ Start []uint64
+ End []uint64
+ Values []MappingsOfRange
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *MappingSet) ExportSortedSlices() *MappingSegmentDataSlices {
+ var sds MappingSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *MappingSet) ImportSortedSlices(sds *MappingSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := MappableRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *MappingSet) saveRoot() *MappingSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *MappingSet) loadRoot(sds *MappingSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
new file mode 100644
index 000000000..0106c857d
--- /dev/null
+++ b/pkg/sentry/memmap/memmap.go
@@ -0,0 +1,361 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memmap defines semantics for memory mappings.
+package memmap
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// Mappable represents a memory-mappable object, a mutable mapping from uint64
+// offsets to (platform.File, uint64 File offset) pairs.
+//
+// See mm/mm.go for Mappable's place in the lock order.
+//
+// Preconditions: For all Mappable methods, usermem.AddrRanges and
+// MappableRanges must be non-empty (Length() != 0), and usermem.Addrs and
+// Mappable offsets must be page-aligned.
+type Mappable interface {
+ // AddMapping notifies the Mappable of a mapping from addresses ar in ms to
+ // offsets [offset, offset+ar.Length()) in this Mappable.
+ //
+ // The writable flag indicates whether the backing data for a Mappable can
+ // be modified through the mapping. Effectively, this means a shared mapping
+ // where Translate may be called with at.Write == true. This is a property
+ // established at mapping creation and must remain constant throughout the
+ // lifetime of the mapping.
+ //
+ // Preconditions: offset+ar.Length() does not overflow.
+ AddMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error
+
+ // RemoveMapping notifies the Mappable of the removal of a mapping from
+ // addresses ar in ms to offsets [offset, offset+ar.Length()) in this
+ // Mappable.
+ //
+ // Preconditions: offset+ar.Length() does not overflow. The removed mapping
+ // must exist. writable must match the corresponding call to AddMapping.
+ RemoveMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool)
+
+ // CopyMapping notifies the Mappable of an attempt to copy a mapping in ms
+ // from srcAR to dstAR. For most Mappables, this is equivalent to
+ // AddMapping. Note that it is possible that srcAR.Length() != dstAR.Length(),
+ // and also that srcAR.Length() == 0.
+ //
+ // CopyMapping is only called when a mapping is copied within a given
+ // MappingSpace; it is analogous to Linux's vm_operations_struct::mremap.
+ //
+ // Preconditions: offset+srcAR.Length() and offset+dstAR.Length() do not
+ // overflow. The mapping at srcAR must exist. writable must match the
+ // corresponding call to AddMapping.
+ CopyMapping(ctx context.Context, ms MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error
+
+ // Translate returns the Mappable's current mappings for at least the range
+ // of offsets specified by required, and at most the range of offsets
+ // specified by optional. at is the set of access types that may be
+ // performed using the returned Translations. If not all required offsets
+ // are translated, it returns a non-nil error explaining why.
+ //
+ // Translations are valid until invalidated by a callback to
+ // MappingSpace.Invalidate or until the caller removes its mapping of the
+ // translated range. Mappable implementations must ensure that at least one
+ // reference is held on all pages in a platform.File that may be the result
+ // of a valid Translation.
+ //
+ // Preconditions: required.Length() > 0. optional.IsSupersetOf(required).
+ // required and optional must be page-aligned. The caller must have
+ // established a mapping for all of the queried offsets via a previous call
+ // to AddMapping. The caller is responsible for ensuring that calls to
+ // Translate synchronize with invalidation.
+ //
+ // Postconditions: See CheckTranslateResult.
+ Translate(ctx context.Context, required, optional MappableRange, at usermem.AccessType) ([]Translation, error)
+
+ // InvalidateUnsavable requests that the Mappable invalidate Translations
+ // that cannot be preserved across save/restore.
+ //
+ // Invariant: InvalidateUnsavable never races with concurrent calls to any
+ // other Mappable methods.
+ InvalidateUnsavable(ctx context.Context) error
+}
+
+// Translations are returned by Mappable.Translate.
+type Translation struct {
+ // Source is the translated range in the Mappable.
+ Source MappableRange
+
+ // File is the mapped file.
+ File platform.File
+
+ // Offset is the offset into File at which this Translation begins.
+ Offset uint64
+
+ // Perms is the set of permissions for which platform.AddressSpace.MapFile
+ // and platform.AddressSpace.MapInternal on this Translation is permitted.
+ Perms usermem.AccessType
+}
+
+// FileRange returns the platform.FileRange represented by t.
+func (t Translation) FileRange() platform.FileRange {
+ return platform.FileRange{t.Offset, t.Offset + t.Source.Length()}
+}
+
+// CheckTranslateResult returns an error if (ts, terr) does not satisfy all
+// postconditions for Mappable.Translate(required, optional, at).
+//
+// Preconditions: As for Mappable.Translate.
+func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error {
+ // Verify that the inputs to Mappable.Translate were valid.
+ if !required.WellFormed() || required.Length() <= 0 {
+ panic(fmt.Sprintf("invalid required range: %v", required))
+ }
+ if !usermem.Addr(required.Start).IsPageAligned() || !usermem.Addr(required.End).IsPageAligned() {
+ panic(fmt.Sprintf("unaligned required range: %v", required))
+ }
+ if !optional.IsSupersetOf(required) {
+ panic(fmt.Sprintf("optional range %v is not a superset of required range %v", optional, required))
+ }
+ if !usermem.Addr(optional.Start).IsPageAligned() || !usermem.Addr(optional.End).IsPageAligned() {
+ panic(fmt.Sprintf("unaligned optional range: %v", optional))
+ }
+
+ // The first Translation must include required.Start.
+ if len(ts) != 0 && !ts[0].Source.Contains(required.Start) {
+ return fmt.Errorf("first Translation %+v does not cover start of required range %v", ts[0], required)
+ }
+ for i, t := range ts {
+ if !t.Source.WellFormed() || t.Source.Length() <= 0 {
+ return fmt.Errorf("Translation %+v has invalid Source", t)
+ }
+ if !usermem.Addr(t.Source.Start).IsPageAligned() || !usermem.Addr(t.Source.End).IsPageAligned() {
+ return fmt.Errorf("Translation %+v has unaligned Source", t)
+ }
+ if t.File == nil {
+ return fmt.Errorf("Translation %+v has nil File", t)
+ }
+ if !usermem.Addr(t.Offset).IsPageAligned() {
+ return fmt.Errorf("Translation %+v has unaligned Offset", t)
+ }
+ // Translations must be contiguous and in increasing order of
+ // Translation.Source.
+ if i > 0 && ts[i-1].Source.End != t.Source.Start {
+ return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t)
+ }
+ // At least part of each Translation must be required.
+ if t.Source.Intersect(required).Length() == 0 {
+ return fmt.Errorf("Translation %+v lies entirely outside required range %v", t, required)
+ }
+ // Translations must be constrained to the optional range.
+ if !optional.IsSupersetOf(t.Source) {
+ return fmt.Errorf("Translation %+v lies outside optional range %v", t, optional)
+ }
+ // Each Translation must permit a superset of requested accesses.
+ if !t.Perms.SupersetOf(at) {
+ return fmt.Errorf("Translation %+v does not permit all requested access types %v", t, at)
+ }
+ }
+ // If the set of Translations does not cover the entire required range,
+ // Translate must return a non-nil error explaining why.
+ if terr == nil {
+ if len(ts) == 0 {
+ return fmt.Errorf("no Translations and no error")
+ }
+ if t := ts[len(ts)-1]; !t.Source.Contains(required.End - 1) {
+ return fmt.Errorf("last Translation %+v does not reach end of required range %v, but Translate returned no error", t, required)
+ }
+ }
+ return nil
+}
+
+// BusError may be returned by implementations of Mappable.Translate for errors
+// that should result in SIGBUS delivery if they cause application page fault
+// handling to fail.
+type BusError struct {
+ // Err is the original error.
+ Err error
+}
+
+// Error implements error.Error.
+func (b *BusError) Error() string {
+ return fmt.Sprintf("BusError: %v", b.Err.Error())
+}
+
+// MappableRange represents a range of uint64 offsets into a Mappable.
+//
+// type MappableRange <generated using go_generics>
+
+// String implements fmt.Stringer.String.
+func (mr MappableRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", mr.Start, mr.End)
+}
+
+// MappingSpace represents a mutable mapping from usermem.Addrs to (Mappable,
+// uint64 offset) pairs.
+type MappingSpace interface {
+ // Invalidate is called to notify the MappingSpace that values returned by
+ // previous calls to Mappable.Translate for offsets mapped by addresses in
+ // ar are no longer valid.
+ //
+ // Invalidate must not take any locks preceding mm.MemoryManager.activeMu
+ // in the lock order.
+ //
+ // Preconditions: ar.Length() != 0. ar must be page-aligned.
+ Invalidate(ar usermem.AddrRange, opts InvalidateOpts)
+}
+
+// InvalidateOpts holds options to MappingSpace.Invalidate.
+type InvalidateOpts struct {
+ // InvalidatePrivate is true if private pages in the invalidated region
+ // should also be discarded, causing their data to be lost.
+ InvalidatePrivate bool
+}
+
+// MappingIdentity controls the lifetime of a Mappable, and provides
+// information about the Mappable for /proc/[pid]/maps. It is distinct from
+// Mappable because all Mappables that are coherent must compare equal to
+// support the implementation of shared futexes, but different
+// MappingIdentities may represent the same Mappable, in the same way that
+// multiple fs.Files may represent the same fs.Inode. (This similarity is not
+// coincidental; fs.File implements MappingIdentity, and some
+// fs.InodeOperations implement Mappable.)
+type MappingIdentity interface {
+ // MappingIdentity is reference-counted.
+ refs.RefCounter
+
+ // MappedName returns the application-visible name shown in
+ // /proc/[pid]/maps.
+ MappedName(ctx context.Context) string
+
+ // DeviceID returns the device number shown in /proc/[pid]/maps.
+ DeviceID() uint64
+
+ // InodeID returns the inode number shown in /proc/[pid]/maps.
+ InodeID() uint64
+
+ // Msync has the same semantics as fs.FileOperations.Fsync(ctx,
+ // int64(mr.Start), int64(mr.End-1), fs.SyncData).
+ // (fs.FileOperations.Fsync() takes an inclusive end, but mr.End is
+ // exclusive, hence mr.End-1.) It is defined rather than Fsync so that
+ // implementors don't need to depend on the fs package for fs.SyncType.
+ Msync(ctx context.Context, mr MappableRange) error
+}
+
+// MLockMode specifies the memory locking behavior of a memory mapping.
+type MLockMode int
+
+// Note that the ordering of MLockModes is significant; see
+// mm.MemoryManager.defMLockMode.
+const (
+ // MLockNone specifies that a mapping has no memory locking behavior.
+ //
+ // This must be the zero value for MLockMode.
+ MLockNone MLockMode = iota
+
+ // MLockEager specifies that a mapping is memory-locked, as by mlock() or
+ // similar. Pages in the mapping should be made, and kept, resident in
+ // physical memory as soon as possible.
+ //
+ // As of this writing, MLockEager does not cause memory-locking to be
+ // requested from the host; it only affects the sentry's memory management
+ // behavior.
+ //
+ // MLockEager is analogous to Linux's VM_LOCKED.
+ MLockEager
+
+ // MLockLazy specifies that a mapping is memory-locked, as by mlock() or
+ // similar. Pages in the mapping should be kept resident in physical memory
+ // once they have been made resident due to e.g. a page fault.
+ //
+ // As of this writing, MLockLazy does not cause memory-locking to be
+ // requested from the host; in fact, it has virtually no effect, except for
+ // interactions between mlocked pages and other syscalls.
+ //
+ // MLockLazy is analogous to Linux's VM_LOCKED | VM_LOCKONFAULT.
+ MLockLazy
+)
+
+// MMapOpts specifies a request to create a memory mapping.
+type MMapOpts struct {
+ // Length is the length of the mapping.
+ Length uint64
+
+ // MappingIdentity controls the lifetime of Mappable, and provides
+ // properties of the mapping shown in /proc/[pid]/maps. If MMapOpts is used
+ // to successfully create a memory mapping, a reference is taken on
+ // MappingIdentity.
+ MappingIdentity MappingIdentity
+
+ // Mappable is the Mappable to be mapped. If Mappable is nil, the mapping
+ // is anonymous. If Mappable is not nil, it must remain valid as long as a
+ // reference is held on MappingIdentity.
+ Mappable Mappable
+
+ // Offset is the offset into Mappable to map. If Mappable is nil, Offset is
+ // ignored.
+ Offset uint64
+
+ // Addr is the suggested address for the mapping.
+ Addr usermem.Addr
+
+ // Fixed specifies whether this is a fixed mapping (it must be located at
+ // Addr).
+ Fixed bool
+
+ // Unmap specifies whether existing mappings in the range being mapped may
+ // be replaced. If Unmap is true, Fixed must be true.
+ Unmap bool
+
+ // If Map32Bit is true, all addresses in the created mapping must fit in a
+ // 32-bit integer. (Note that the "end address" of the mapping, i.e. the
+ // address of the first byte *after* the mapping, need not fit in a 32-bit
+ // integer.) Map32Bit is ignored if Fixed is true.
+ Map32Bit bool
+
+ // Perms is the set of permissions to the applied to this mapping.
+ Perms usermem.AccessType
+
+ // MaxPerms limits the set of permissions that may ever apply to this
+ // mapping. If Mappable is not nil, all memmap.Translations returned by
+ // Mappable.Translate must support all accesses in MaxPerms.
+ //
+ // Preconditions: MaxAccessType should be an effective AccessType, as
+ // access cannot be limited beyond effective AccessTypes.
+ MaxPerms usermem.AccessType
+
+ // Private is true if writes to the mapping should be propagated to a copy
+ // that is exclusive to the MemoryManager.
+ Private bool
+
+ // GrowsDown is true if the mapping should be automatically expanded
+ // downward on guard page faults.
+ GrowsDown bool
+
+ // Precommit is true if the platform should eagerly commit resources to the
+ // mapping (see platform.AddressSpace.MapFile).
+ Precommit bool
+
+ // MLockMode specifies the memory locking behavior of the mapping.
+ MLockMode MLockMode
+
+ // Hint is the name used for the mapping in /proc/[pid]/maps. If Hint is
+ // empty, MappingIdentity.MappedName() will be used instead.
+ //
+ // TODO(jamieliu): Replace entirely with MappingIdentity?
+ Hint string
+}
diff --git a/pkg/sentry/memmap/memmap_state_autogen.go b/pkg/sentry/memmap/memmap_state_autogen.go
new file mode 100755
index 000000000..42009f82a
--- /dev/null
+++ b/pkg/sentry/memmap/memmap_state_autogen.go
@@ -0,0 +1,93 @@
+// automatically generated by stateify.
+
+package memmap
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *MappableRange) beforeSave() {}
+func (x *MappableRange) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *MappableRange) afterLoad() {}
+func (x *MappableRange) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func (x *MappingOfRange) beforeSave() {}
+func (x *MappingOfRange) save(m state.Map) {
+ x.beforeSave()
+ m.Save("MappingSpace", &x.MappingSpace)
+ m.Save("AddrRange", &x.AddrRange)
+ m.Save("Writable", &x.Writable)
+}
+
+func (x *MappingOfRange) afterLoad() {}
+func (x *MappingOfRange) load(m state.Map) {
+ m.Load("MappingSpace", &x.MappingSpace)
+ m.Load("AddrRange", &x.AddrRange)
+ m.Load("Writable", &x.Writable)
+}
+
+func (x *MappingSet) beforeSave() {}
+func (x *MappingSet) save(m state.Map) {
+ x.beforeSave()
+ var root *MappingSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *MappingSet) afterLoad() {}
+func (x *MappingSet) load(m state.Map) {
+ m.LoadValue("root", new(*MappingSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*MappingSegmentDataSlices)) })
+}
+
+func (x *Mappingnode) beforeSave() {}
+func (x *Mappingnode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *Mappingnode) afterLoad() {}
+func (x *Mappingnode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *MappingSegmentDataSlices) beforeSave() {}
+func (x *MappingSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *MappingSegmentDataSlices) afterLoad() {}
+func (x *MappingSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func init() {
+ state.Register("memmap.MappableRange", (*MappableRange)(nil), state.Fns{Save: (*MappableRange).save, Load: (*MappableRange).load})
+ state.Register("memmap.MappingOfRange", (*MappingOfRange)(nil), state.Fns{Save: (*MappingOfRange).save, Load: (*MappingOfRange).load})
+ state.Register("memmap.MappingSet", (*MappingSet)(nil), state.Fns{Save: (*MappingSet).save, Load: (*MappingSet).load})
+ state.Register("memmap.Mappingnode", (*Mappingnode)(nil), state.Fns{Save: (*Mappingnode).save, Load: (*Mappingnode).load})
+ state.Register("memmap.MappingSegmentDataSlices", (*MappingSegmentDataSlices)(nil), state.Fns{Save: (*MappingSegmentDataSlices).save, Load: (*MappingSegmentDataSlices).load})
+}
diff --git a/pkg/sentry/memutil/memutil.go b/pkg/sentry/memutil/memutil.go
new file mode 100644
index 000000000..a4154c42a
--- /dev/null
+++ b/pkg/sentry/memutil/memutil.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memutil contains the utility functions for memory operations.
+package memutil
diff --git a/pkg/sentry/memutil/memutil_state_autogen.go b/pkg/sentry/memutil/memutil_state_autogen.go
new file mode 100755
index 000000000..52f337963
--- /dev/null
+++ b/pkg/sentry/memutil/memutil_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package memutil
+
diff --git a/pkg/sentry/memutil/memutil_unsafe.go b/pkg/sentry/memutil/memutil_unsafe.go
new file mode 100644
index 000000000..92eab8a26
--- /dev/null
+++ b/pkg/sentry/memutil/memutil_unsafe.go
@@ -0,0 +1,39 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memutil
+
+import (
+ "fmt"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// CreateMemFD creates a memfd file and returns the fd.
+func CreateMemFD(name string, flags int) (int, error) {
+ p, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return -1, err
+ }
+ fd, _, e := syscall.Syscall(unix.SYS_MEMFD_CREATE, uintptr(unsafe.Pointer(p)), uintptr(flags), 0)
+ if e != 0 {
+ if e == syscall.ENOSYS {
+ return -1, fmt.Errorf("memfd_create(2) is not implemented. Check that you have Linux 3.17 or higher")
+ }
+ return -1, e
+ }
+ return int(fd), nil
+}
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
new file mode 100644
index 000000000..06f587fde
--- /dev/null
+++ b/pkg/sentry/mm/address_space.go
@@ -0,0 +1,216 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/atomicbitops"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// AddressSpace returns the platform.AddressSpace bound to mm.
+//
+// Preconditions: The caller must have called mm.Activate().
+func (mm *MemoryManager) AddressSpace() platform.AddressSpace {
+ if atomic.LoadInt32(&mm.active) == 0 {
+ panic("trying to use inactive address space?")
+ }
+ return mm.as
+}
+
+// Activate ensures this MemoryManager has a platform.AddressSpace.
+//
+// The caller must not hold any locks when calling Activate.
+//
+// When this MemoryManager is no longer needed by a task, it should call
+// Deactivate to release the reference.
+func (mm *MemoryManager) Activate() error {
+ // Fast path: the MemoryManager already has an active
+ // platform.AddressSpace, and we just need to indicate that we need it too.
+ if atomicbitops.IncUnlessZeroInt32(&mm.active) {
+ return nil
+ }
+
+ for {
+ // Slow path: may need to synchronize with other goroutines changing
+ // mm.active to or from zero.
+ mm.activeMu.Lock()
+ // Inline Unlock instead of using a defer for performance since this
+ // method is commonly in the hot-path.
+
+ // Check if we raced with another goroutine performing activation.
+ if atomic.LoadInt32(&mm.active) > 0 {
+ // This can't race; Deactivate can't decrease mm.active from 1 to 0
+ // without holding activeMu.
+ atomic.AddInt32(&mm.active, 1)
+ mm.activeMu.Unlock()
+ return nil
+ }
+
+ // Do we have a context? If so, then we never unmapped it. This can
+ // only be the case if !mm.p.CooperativelySchedulesAddressSpace().
+ if mm.as != nil {
+ atomic.StoreInt32(&mm.active, 1)
+ mm.activeMu.Unlock()
+ return nil
+ }
+
+ // Get a new address space. We must force unmapping by passing nil to
+ // NewAddressSpace if requested. (As in the nil interface object, not a
+ // typed nil.)
+ mappingsID := (interface{})(mm)
+ if mm.unmapAllOnActivate {
+ mappingsID = nil
+ }
+ as, c, err := mm.p.NewAddressSpace(mappingsID)
+ if err != nil {
+ mm.activeMu.Unlock()
+ return err
+ }
+ if as == nil {
+ // AddressSpace is unavailable, we must wait.
+ //
+ // activeMu must not be held while waiting, as the user
+ // of the address space we are waiting on may attempt
+ // to take activeMu.
+ //
+ // Don't call UninterruptibleSleepStart to register the
+ // wait to allow the watchdog stuck task to trigger in
+ // case a process is starved waiting for the address
+ // space.
+ mm.activeMu.Unlock()
+ <-c
+ continue
+ }
+
+ // Okay, we could restore all mappings at this point.
+ // But forget that. Let's just let them fault in.
+ mm.as = as
+
+ // Unmapping is done, if necessary.
+ mm.unmapAllOnActivate = false
+
+ // Now that m.as has been assigned, we can set m.active to a non-zero value
+ // to enable the fast path.
+ atomic.StoreInt32(&mm.active, 1)
+
+ mm.activeMu.Unlock()
+ return nil
+ }
+}
+
+// Deactivate releases a reference to the MemoryManager.
+func (mm *MemoryManager) Deactivate() {
+ // Fast path: this is not the last goroutine to deactivate the
+ // MemoryManager.
+ if atomicbitops.DecUnlessOneInt32(&mm.active) {
+ return
+ }
+
+ mm.activeMu.Lock()
+ // Same as Activate.
+
+ // Still active?
+ if atomic.AddInt32(&mm.active, -1) > 0 {
+ mm.activeMu.Unlock()
+ return
+ }
+
+ // Can we hold on to the address space?
+ if !mm.p.CooperativelySchedulesAddressSpace() {
+ mm.activeMu.Unlock()
+ return
+ }
+
+ // Release the address space.
+ mm.as.Release()
+
+ // Lost it.
+ mm.as = nil
+ mm.activeMu.Unlock()
+}
+
+// mapASLocked maps addresses in ar into mm.as. If precommit is true, mappings
+// for all addresses in ar should be precommitted.
+//
+// Preconditions: mm.activeMu must be locked. mm.as != nil. ar.Length() != 0.
+// ar must be page-aligned. pseg == mm.pmas.LowerBoundSegment(ar.Start).
+func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, precommit bool) error {
+ // By default, map entire pmas at a time, under the assumption that there
+ // is no cost to mapping more of a pma than necessary.
+ mapAR := usermem.AddrRange{0, ^usermem.Addr(usermem.PageSize - 1)}
+ if precommit {
+ // When explicitly precommitting, only map ar, since overmapping may
+ // incur unexpected resource usage.
+ mapAR = ar
+ } else if mapUnit := mm.p.MapUnit(); mapUnit != 0 {
+ // Limit the range we map to ar, aligned to mapUnit.
+ mapMask := usermem.Addr(mapUnit - 1)
+ mapAR.Start = ar.Start &^ mapMask
+ // If rounding ar.End up overflows, just keep the existing mapAR.End.
+ if end := (ar.End + mapMask) &^ mapMask; end >= ar.End {
+ mapAR.End = end
+ }
+ }
+ if checkInvariants {
+ if !mapAR.IsSupersetOf(ar) {
+ panic(fmt.Sprintf("mapAR %#v is not a superset of ar %#v", mapAR, ar))
+ }
+ }
+
+ // Since this checks ar.End and not mapAR.End, we will never map a pma that
+ // is not required.
+ for pseg.Ok() && pseg.Start() < ar.End {
+ pma := pseg.ValuePtr()
+ pmaAR := pseg.Range()
+ pmaMapAR := pmaAR.Intersect(mapAR)
+ perms := pma.effectivePerms
+ if pma.needCOW {
+ perms.Write = false
+ }
+ if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil {
+ return err
+ }
+ pseg = pseg.NextSegment()
+ }
+ return nil
+}
+
+// unmapASLocked removes all AddressSpace mappings for addresses in ar.
+//
+// Preconditions: mm.activeMu must be locked.
+func (mm *MemoryManager) unmapASLocked(ar usermem.AddrRange) {
+ if mm.as == nil {
+ // No AddressSpace? Force all mappings to be unmapped on the next
+ // Activate.
+ mm.unmapAllOnActivate = true
+ return
+ }
+
+ // unmapASLocked doesn't require vmas or pmas to exist for ar, so it can be
+ // passed ranges that include addresses that can't be mapped by the
+ // application.
+ ar = ar.Intersect(mm.applicationAddrRange())
+
+ // Note that this AddressSpace may or may not be active. If the
+ // platform does not require cooperative sharing of AddressSpaces, they
+ // are retained between Deactivate/Activate calls. Despite not being
+ // active, it is still valid to perform operations on these address
+ // spaces.
+ mm.as.Unmap(ar.Start, uint64(ar.Length()))
+}
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
new file mode 100644
index 000000000..5c61acf36
--- /dev/null
+++ b/pkg/sentry/mm/aio_context.go
@@ -0,0 +1,387 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// aioManager creates and manages asynchronous I/O contexts.
+//
+// +stateify savable
+type aioManager struct {
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // aioContexts is the set of asynchronous I/O contexts.
+ contexts map[uint64]*AIOContext
+}
+
+func (a *aioManager) destroy() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ for _, ctx := range a.contexts {
+ ctx.destroy()
+ }
+}
+
+// newAIOContext creates a new context for asynchronous I/O.
+//
+// Returns false if 'id' is currently in use.
+func (a *aioManager) newAIOContext(events uint32, id uint64) bool {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if _, ok := a.contexts[id]; ok {
+ return false
+ }
+
+ a.contexts[id] = &AIOContext{
+ done: make(chan struct{}, 1),
+ maxOutstanding: events,
+ }
+ return true
+}
+
+// destroyAIOContext destroys an asynchronous I/O context.
+//
+// False is returned if the context does not exist.
+func (a *aioManager) destroyAIOContext(id uint64) bool {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ctx, ok := a.contexts[id]
+ if !ok {
+ return false
+ }
+ delete(a.contexts, id)
+ ctx.destroy()
+ return true
+}
+
+// lookupAIOContext looks up the given context.
+//
+// Returns false if context does not exist.
+func (a *aioManager) lookupAIOContext(id uint64) (*AIOContext, bool) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ctx, ok := a.contexts[id]
+ return ctx, ok
+}
+
+// ioResult is a completed I/O operation.
+//
+// +stateify savable
+type ioResult struct {
+ data interface{}
+ ioEntry
+}
+
+// AIOContext is a single asynchronous I/O context.
+//
+// +stateify savable
+type AIOContext struct {
+ // done is the notification channel used for all requests.
+ done chan struct{} `state:"nosave"`
+
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // results is the set of completed requests.
+ results ioList
+
+ // maxOutstanding is the maximum number of outstanding entries; this value
+ // is immutable.
+ maxOutstanding uint32
+
+ // outstanding is the number of requests outstanding; this will effectively
+ // be the number of entries in the result list or that are expected to be
+ // added to the result list.
+ outstanding uint32
+
+ // dead is set when the context is destroyed.
+ dead bool `state:"zerovalue"`
+}
+
+// destroy marks the context dead.
+func (ctx *AIOContext) destroy() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ ctx.dead = true
+ if ctx.outstanding == 0 {
+ close(ctx.done)
+ }
+}
+
+// Prepare reserves space for a new request, returning true if available.
+// Returns false if the context is busy.
+func (ctx *AIOContext) Prepare() bool {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ if ctx.outstanding >= ctx.maxOutstanding {
+ return false
+ }
+ ctx.outstanding++
+ return true
+}
+
+// PopRequest pops a completed request if available, this function does not do
+// any blocking. Returns false if no request is available.
+func (ctx *AIOContext) PopRequest() (interface{}, bool) {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ // Is there anything ready?
+ if e := ctx.results.Front(); e != nil {
+ ctx.results.Remove(e)
+ ctx.outstanding--
+ if ctx.outstanding == 0 && ctx.dead {
+ close(ctx.done)
+ }
+ return e.data, true
+ }
+ return nil, false
+}
+
+// FinishRequest finishes a pending request. It queues up the data
+// and notifies listeners.
+func (ctx *AIOContext) FinishRequest(data interface{}) {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ // Push to the list and notify opportunistically. The channel notify
+ // here is guaranteed to be safe because outstanding must be non-zero.
+ // The done channel is only closed when outstanding reaches zero.
+ ctx.results.PushBack(&ioResult{data: data})
+
+ select {
+ case ctx.done <- struct{}{}:
+ default:
+ }
+}
+
+// WaitChannel returns a channel that is notified when an AIO request is
+// completed.
+//
+// The boolean return value indicates whether or not the context is active.
+func (ctx *AIOContext) WaitChannel() (chan struct{}, bool) {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ if ctx.outstanding == 0 && ctx.dead {
+ return nil, false
+ }
+ return ctx.done, true
+}
+
+// aioMappable implements memmap.MappingIdentity and memmap.Mappable for AIO
+// ring buffers.
+//
+// +stateify savable
+type aioMappable struct {
+ refs.AtomicRefCount
+
+ mfp pgalloc.MemoryFileProvider
+ fr platform.FileRange
+}
+
+var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp())
+
+func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) {
+ fr, err := mfp.MemoryFile().Allocate(aioRingBufferSize, usage.Anonymous)
+ if err != nil {
+ return nil, err
+ }
+ return &aioMappable{mfp: mfp, fr: fr}, nil
+}
+
+// DecRef implements refs.RefCounter.DecRef.
+func (m *aioMappable) DecRef() {
+ m.AtomicRefCount.DecRefWithDestructor(func() {
+ m.mfp.MemoryFile().DecRef(m.fr)
+ })
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (m *aioMappable) MappedName(ctx context.Context) string {
+ return "[aio]"
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (m *aioMappable) DeviceID() uint64 {
+ return 0
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (m *aioMappable) InodeID() uint64 {
+ return 0
+}
+
+// Msync implements memmap.MappingIdentity.Msync.
+func (m *aioMappable) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ // Linux: aio_ring_fops.fsync == NULL
+ return syserror.EINVAL
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (m *aioMappable) AddMapping(_ context.Context, _ memmap.MappingSpace, ar usermem.AddrRange, offset uint64, _ bool) error {
+ // Don't allow mappings to be expanded (in Linux, fs/aio.c:aio_ring_mmap()
+ // sets VM_DONTEXPAND).
+ if offset != 0 || uint64(ar.Length()) != aioRingBufferSize {
+ return syserror.EFAULT
+ }
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (m *aioMappable) RemoveMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) {
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, _ bool) error {
+ // Don't allow mappings to be expanded (in Linux, fs/aio.c:aio_ring_mmap()
+ // sets VM_DONTEXPAND).
+ if offset != 0 || uint64(dstAR.Length()) != aioRingBufferSize {
+ return syserror.EFAULT
+ }
+ // Require that the mapping correspond to a live AIOContext. Compare
+ // Linux's fs/aio.c:aio_ring_mremap().
+ mm, ok := ms.(*MemoryManager)
+ if !ok {
+ return syserror.EINVAL
+ }
+ am := &mm.aioManager
+ am.mu.Lock()
+ defer am.mu.Unlock()
+ oldID := uint64(srcAR.Start)
+ aioCtx, ok := am.contexts[oldID]
+ if !ok {
+ return syserror.EINVAL
+ }
+ aioCtx.mu.Lock()
+ defer aioCtx.mu.Unlock()
+ if aioCtx.dead {
+ return syserror.EINVAL
+ }
+ // Use the new ID for the AIOContext.
+ am.contexts[uint64(dstAR.Start)] = aioCtx
+ delete(am.contexts, oldID)
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (m *aioMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ var err error
+ if required.End > m.fr.Length() {
+ err = &memmap.BusError{syserror.EFAULT}
+ }
+ if source := optional.Intersect(memmap.MappableRange{0, m.fr.Length()}); source.Length() != 0 {
+ return []memmap.Translation{
+ {
+ Source: source,
+ File: m.mfp.MemoryFile(),
+ Offset: m.fr.Start + source.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, err
+ }
+ return nil, err
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (m *aioMappable) InvalidateUnsavable(ctx context.Context) error {
+ return nil
+}
+
+// NewAIOContext creates a new context for asynchronous I/O.
+//
+// NewAIOContext is analogous to Linux's fs/aio.c:ioctx_alloc().
+func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint64, error) {
+ // libaio get_ioevents() expects context "handle" to be a valid address.
+ // libaio peeks inside looking for a magic number. This function allocates
+ // a page per context and keeps it set to zeroes to ensure it will not
+ // match AIO_RING_MAGIC and make libaio happy.
+ m, err := newAIOMappable(mm.mfp)
+ if err != nil {
+ return 0, err
+ }
+ defer m.DecRef()
+ addr, err := mm.MMap(ctx, memmap.MMapOpts{
+ Length: aioRingBufferSize,
+ MappingIdentity: m,
+ Mappable: m,
+ // TODO(fvoznika): Linux does "do_mmap_pgoff(..., PROT_READ |
+ // PROT_WRITE, ...)" in fs/aio.c:aio_setup_ring(); why do we make this
+ // mapping read-only?
+ Perms: usermem.Read,
+ MaxPerms: usermem.Read,
+ })
+ if err != nil {
+ return 0, err
+ }
+ id := uint64(addr)
+ if !mm.aioManager.newAIOContext(events, id) {
+ mm.MUnmap(ctx, addr, aioRingBufferSize)
+ return 0, syserror.EINVAL
+ }
+ return id, nil
+}
+
+// DestroyAIOContext destroys an asynchronous I/O context. It returns false if
+// the context does not exist.
+func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) bool {
+ if _, ok := mm.LookupAIOContext(ctx, id); !ok {
+ return false
+ }
+
+ // Only unmaps after it assured that the address is a valid aio context to
+ // prevent random memory from been unmapped.
+ //
+ // Note: It's possible to unmap this address and map something else into
+ // the same address. Then it would be unmapping memory that it doesn't own.
+ // This is, however, the way Linux implements AIO. Keeps the same [weird]
+ // semantics in case anyone relies on it.
+ mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
+
+ return mm.aioManager.destroyAIOContext(id)
+}
+
+// LookupAIOContext looks up the given context. It returns false if the context
+// does not exist.
+func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOContext, bool) {
+ aioCtx, ok := mm.aioManager.lookupAIOContext(id)
+ if !ok {
+ return nil, false
+ }
+
+ // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes
+ // from id).
+ var buf [4]byte
+ _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
+ if err != nil {
+ return nil, false
+ }
+
+ return aioCtx, true
+}
diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go
new file mode 100644
index 000000000..c37fc9f7b
--- /dev/null
+++ b/pkg/sentry/mm/aio_context_state.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+// afterLoad is invoked by stateify.
+func (a *AIOContext) afterLoad() {
+ a.done = make(chan struct{}, 1)
+}
diff --git a/pkg/sentry/mm/debug.go b/pkg/sentry/mm/debug.go
new file mode 100644
index 000000000..fe58cfc4c
--- /dev/null
+++ b/pkg/sentry/mm/debug.go
@@ -0,0 +1,98 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+const (
+ // If checkInvariants is true, perform runtime checks for invariants
+ // expected by the mm package. This is normally disabled since MM is a
+ // significant hot path in general, and some such checks (notably
+ // memmap.CheckTranslateResult) are very expensive.
+ checkInvariants = false
+
+ // If logIOErrors is true, log I/O errors that originate from MM before
+ // converting them to EFAULT.
+ logIOErrors = false
+)
+
+// String implements fmt.Stringer.String.
+func (mm *MemoryManager) String() string {
+ return mm.DebugString(context.Background())
+}
+
+// DebugString returns a string containing information about mm for debugging.
+func (mm *MemoryManager) DebugString(ctx context.Context) string {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ mm.activeMu.RLock()
+ defer mm.activeMu.RUnlock()
+ return mm.debugStringLocked(ctx)
+}
+
+// Preconditions: mm.mappingMu and mm.activeMu must be locked.
+func (mm *MemoryManager) debugStringLocked(ctx context.Context) string {
+ var b bytes.Buffer
+ b.WriteString("VMAs:\n")
+ for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() {
+ b.Write(mm.vmaMapsEntryLocked(ctx, vseg))
+ }
+ b.WriteString("PMAs:\n")
+ for pseg := mm.pmas.FirstSegment(); pseg.Ok(); pseg = pseg.NextSegment() {
+ b.Write(pseg.debugStringEntryLocked())
+ }
+ return string(b.Bytes())
+}
+
+// Preconditions: mm.activeMu must be locked.
+func (pseg pmaIterator) debugStringEntryLocked() []byte {
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "%08x-%08x ", pseg.Start(), pseg.End())
+
+ pma := pseg.ValuePtr()
+ if pma.effectivePerms.Read {
+ b.WriteByte('r')
+ } else {
+ b.WriteByte('-')
+ }
+ if pma.effectivePerms.Write {
+ if pma.needCOW {
+ b.WriteByte('c')
+ } else {
+ b.WriteByte('w')
+ }
+ } else {
+ b.WriteByte('-')
+ }
+ if pma.effectivePerms.Execute {
+ b.WriteByte('x')
+ } else {
+ b.WriteByte('-')
+ }
+ if pma.private {
+ b.WriteByte('p')
+ } else {
+ b.WriteByte('s')
+ }
+
+ fmt.Fprintf(&b, " %08x %T\n", pma.off, pma.file)
+ return b.Bytes()
+}
diff --git a/pkg/sentry/mm/file_refcount_set.go b/pkg/sentry/mm/file_refcount_set.go
new file mode 100755
index 000000000..99c088c83
--- /dev/null
+++ b/pkg/sentry/mm/file_refcount_set.go
@@ -0,0 +1,1274 @@
+package mm
+
+import (
+ __generics_imported0 "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+)
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ fileRefcountminDegree = 3
+
+ fileRefcountmaxDegree = 2 * fileRefcountminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type fileRefcountSet struct {
+ root fileRefcountnode `state:".(*fileRefcountSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *fileRefcountSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *fileRefcountSet) IsEmptyRange(r __generics_imported0.FileRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *fileRefcountSet) Span() uint64 {
+ var sz uint64
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *fileRefcountSet) SpanRange(r __generics_imported0.FileRange) uint64 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint64
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *fileRefcountSet) FirstSegment() fileRefcountIterator {
+ if s.root.nrSegments == 0 {
+ return fileRefcountIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *fileRefcountSet) LastSegment() fileRefcountIterator {
+ if s.root.nrSegments == 0 {
+ return fileRefcountIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *fileRefcountSet) FirstGap() fileRefcountGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return fileRefcountGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *fileRefcountSet) LastGap() fileRefcountGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return fileRefcountGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *fileRefcountSet) Find(key uint64) (fileRefcountIterator, fileRefcountGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return fileRefcountIterator{n, i}, fileRefcountGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return fileRefcountIterator{}, fileRefcountGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *fileRefcountSet) FindSegment(key uint64) fileRefcountIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *fileRefcountSet) LowerBoundSegment(min uint64) fileRefcountIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *fileRefcountSet) UpperBoundSegment(max uint64) fileRefcountIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *fileRefcountSet) FindGap(key uint64) fileRefcountGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *fileRefcountSet) LowerBoundGap(min uint64) fileRefcountGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *fileRefcountSet) UpperBoundGap(max uint64) fileRefcountGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *fileRefcountSet) Add(r __generics_imported0.FileRange, val int32) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *fileRefcountSet) AddWithoutMerging(r __generics_imported0.FileRange, val int32) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *fileRefcountSet) Insert(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (fileRefcountSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (fileRefcountSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (fileRefcountSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *fileRefcountSet) InsertWithoutMerging(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *fileRefcountSet) InsertWithoutMergingUnchecked(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return fileRefcountIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *fileRefcountSet) Remove(seg fileRefcountIterator) fileRefcountGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ fileRefcountSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(fileRefcountGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *fileRefcountSet) RemoveAll() {
+ s.root = fileRefcountnode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *fileRefcountSet) RemoveRange(r __generics_imported0.FileRange) fileRefcountGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *fileRefcountSet) Merge(first, second fileRefcountIterator) fileRefcountIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *fileRefcountSet) MergeUnchecked(first, second fileRefcountIterator) fileRefcountIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (fileRefcountSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return fileRefcountIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *fileRefcountSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *fileRefcountSet) MergeRange(r __generics_imported0.FileRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *fileRefcountSet) MergeAdjacent(r __generics_imported0.FileRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *fileRefcountSet) Split(seg fileRefcountIterator, split uint64) (fileRefcountIterator, fileRefcountIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *fileRefcountSet) SplitUnchecked(seg fileRefcountIterator, split uint64) (fileRefcountIterator, fileRefcountIterator) {
+ val1, val2 := (fileRefcountSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *fileRefcountSet) SplitAt(split uint64) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *fileRefcountSet) Isolate(seg fileRefcountIterator, r __generics_imported0.FileRange) fileRefcountIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *fileRefcountSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg fileRefcountIterator)) fileRefcountGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return fileRefcountGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return fileRefcountGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type fileRefcountnode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *fileRefcountnode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [fileRefcountmaxDegree - 1]__generics_imported0.FileRange
+ values [fileRefcountmaxDegree - 1]int32
+ children [fileRefcountmaxDegree]*fileRefcountnode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *fileRefcountnode) firstSegment() fileRefcountIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return fileRefcountIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *fileRefcountnode) lastSegment() fileRefcountIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return fileRefcountIterator{n, n.nrSegments - 1}
+}
+
+func (n *fileRefcountnode) prevSibling() *fileRefcountnode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *fileRefcountnode) nextSibling() *fileRefcountnode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *fileRefcountnode) rebalanceBeforeInsert(gap fileRefcountGapIterator) fileRefcountGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < fileRefcountmaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &fileRefcountnode{
+ nrSegments: fileRefcountminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &fileRefcountnode{
+ nrSegments: fileRefcountminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:fileRefcountminDegree-1], n.keys[:fileRefcountminDegree-1])
+ copy(left.values[:fileRefcountminDegree-1], n.values[:fileRefcountminDegree-1])
+ copy(right.keys[:fileRefcountminDegree-1], n.keys[fileRefcountminDegree:])
+ copy(right.values[:fileRefcountminDegree-1], n.values[fileRefcountminDegree:])
+ n.keys[0], n.values[0] = n.keys[fileRefcountminDegree-1], n.values[fileRefcountminDegree-1]
+ fileRefcountzeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:fileRefcountminDegree], n.children[:fileRefcountminDegree])
+ copy(right.children[:fileRefcountminDegree], n.children[fileRefcountminDegree:])
+ fileRefcountzeroNodeSlice(n.children[2:])
+ for i := 0; i < fileRefcountminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < fileRefcountminDegree {
+ return fileRefcountGapIterator{left, gap.index}
+ }
+ return fileRefcountGapIterator{right, gap.index - fileRefcountminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[fileRefcountminDegree-1], n.values[fileRefcountminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &fileRefcountnode{
+ nrSegments: fileRefcountminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:fileRefcountminDegree-1], n.keys[fileRefcountminDegree:])
+ copy(sibling.values[:fileRefcountminDegree-1], n.values[fileRefcountminDegree:])
+ fileRefcountzeroValueSlice(n.values[fileRefcountminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:fileRefcountminDegree], n.children[fileRefcountminDegree:])
+ fileRefcountzeroNodeSlice(n.children[fileRefcountminDegree:])
+ for i := 0; i < fileRefcountminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = fileRefcountminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < fileRefcountminDegree {
+ return gap
+ }
+ return fileRefcountGapIterator{sibling, gap.index - fileRefcountminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *fileRefcountnode) rebalanceAfterRemove(gap fileRefcountGapIterator) fileRefcountGapIterator {
+ for {
+ if n.nrSegments >= fileRefcountminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= fileRefcountminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ fileRefcountSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return fileRefcountGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return fileRefcountGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= fileRefcountminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ fileRefcountSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return fileRefcountGapIterator{n, n.nrSegments}
+ }
+ return fileRefcountGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return fileRefcountGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return fileRefcountGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *fileRefcountnode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = fileRefcountGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ fileRefcountSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type fileRefcountIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *fileRefcountnode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg fileRefcountIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg fileRefcountIterator) Range() __generics_imported0.FileRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg fileRefcountIterator) Start() uint64 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg fileRefcountIterator) End() uint64 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg fileRefcountIterator) SetRangeUnchecked(r __generics_imported0.FileRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg fileRefcountIterator) SetRange(r __generics_imported0.FileRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg fileRefcountIterator) SetStartUnchecked(start uint64) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg fileRefcountIterator) SetStart(start uint64) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg fileRefcountIterator) SetEndUnchecked(end uint64) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg fileRefcountIterator) SetEnd(end uint64) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg fileRefcountIterator) Value() int32 {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg fileRefcountIterator) ValuePtr() *int32 {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg fileRefcountIterator) SetValue(val int32) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg fileRefcountIterator) PrevSegment() fileRefcountIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return fileRefcountIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return fileRefcountIterator{}
+ }
+ return fileRefcountsegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg fileRefcountIterator) NextSegment() fileRefcountIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return fileRefcountIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return fileRefcountIterator{}
+ }
+ return fileRefcountsegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg fileRefcountIterator) PrevGap() fileRefcountGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return fileRefcountGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg fileRefcountIterator) NextGap() fileRefcountGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return fileRefcountGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg fileRefcountIterator) PrevNonEmpty() (fileRefcountIterator, fileRefcountGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return fileRefcountIterator{}, gap
+ }
+ return gap.PrevSegment(), fileRefcountGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg fileRefcountIterator) NextNonEmpty() (fileRefcountIterator, fileRefcountGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return fileRefcountIterator{}, gap
+ }
+ return gap.NextSegment(), fileRefcountGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type fileRefcountGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *fileRefcountnode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap fileRefcountGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap fileRefcountGapIterator) Range() __generics_imported0.FileRange {
+ return __generics_imported0.FileRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap fileRefcountGapIterator) Start() uint64 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return fileRefcountSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap fileRefcountGapIterator) End() uint64 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return fileRefcountSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap fileRefcountGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap fileRefcountGapIterator) PrevSegment() fileRefcountIterator {
+ return fileRefcountsegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap fileRefcountGapIterator) NextSegment() fileRefcountIterator {
+ return fileRefcountsegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap fileRefcountGapIterator) PrevGap() fileRefcountGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return fileRefcountGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap fileRefcountGapIterator) NextGap() fileRefcountGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return fileRefcountGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func fileRefcountsegmentBeforePosition(n *fileRefcountnode, i int) fileRefcountIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return fileRefcountIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return fileRefcountIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func fileRefcountsegmentAfterPosition(n *fileRefcountnode, i int) fileRefcountIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return fileRefcountIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return fileRefcountIterator{n, i}
+}
+
+func fileRefcountzeroValueSlice(slice []int32) {
+
+ for i := range slice {
+ fileRefcountSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func fileRefcountzeroNodeSlice(slice []*fileRefcountnode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *fileRefcountSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *fileRefcountnode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *fileRefcountnode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type fileRefcountSegmentDataSlices struct {
+ Start []uint64
+ End []uint64
+ Values []int32
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *fileRefcountSet) ExportSortedSlices() *fileRefcountSegmentDataSlices {
+ var sds fileRefcountSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *fileRefcountSet) ImportSortedSlices(sds *fileRefcountSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *fileRefcountSet) saveRoot() *fileRefcountSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *fileRefcountSet) loadRoot(sds *fileRefcountSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/mm/io.go b/pkg/sentry/mm/io.go
new file mode 100644
index 000000000..e4c057d28
--- /dev/null
+++ b/pkg/sentry/mm/io.go
@@ -0,0 +1,639 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// There are two supported ways to copy data to/from application virtual
+// memory:
+//
+// 1. Internally-mapped copying: Determine the platform.File that backs the
+// copied-to/from virtual address, obtain a mapping of its pages, and read or
+// write to the mapping.
+//
+// 2. AddressSpace copying: If platform.Platform.SupportsAddressSpaceIO() is
+// true, AddressSpace permissions are applicable, and an AddressSpace is
+// available, copy directly through the AddressSpace, handling faults as
+// needed.
+//
+// (Given that internally-mapped copying requires that backing memory is always
+// implemented using a host file descriptor, we could also preadv/pwritev to it
+// instead. But this would incur a host syscall for each use of the mapped
+// page, whereas mmap is a one-time cost.)
+//
+// The fixed overhead of internally-mapped copying is expected to be higher
+// than that of AddressSpace copying since the former always needs to translate
+// addresses, whereas the latter only needs to do so when faults occur.
+// However, the throughput of internally-mapped copying is expected to be
+// somewhat higher than that of AddressSpace copying due to the high cost of
+// page faults and because implementations of the latter usually rely on
+// safecopy, which doesn't use AVX registers. So we prefer to use AddressSpace
+// copying (when available) for smaller copies, and switch to internally-mapped
+// copying once a size threshold is exceeded.
+const (
+ // copyMapMinBytes is the size threshold for switching to internally-mapped
+ // copying in CopyOut, CopyIn, and ZeroOut.
+ copyMapMinBytes = 32 << 10 // 32 KB
+
+ // rwMapMinBytes is the size threshold for switching to internally-mapped
+ // copying in CopyOutFrom and CopyInTo. It's lower than copyMapMinBytes
+ // since AddressSpace copying in this case requires additional buffering;
+ // see CopyOutFrom for details.
+ rwMapMinBytes = 512
+)
+
+// CheckIORange is similar to usermem.Addr.ToRange, but applies bounds checks
+// consistent with Linux's arch/x86/include/asm/uaccess.h:access_ok().
+//
+// Preconditions: length >= 0.
+func (mm *MemoryManager) CheckIORange(addr usermem.Addr, length int64) (usermem.AddrRange, bool) {
+ // Note that access_ok() constrains end even if length == 0.
+ ar, ok := addr.ToRange(uint64(length))
+ return ar, (ok && ar.End <= mm.layout.MaxAddr)
+}
+
+// checkIOVec applies bound checks consistent with Linux's
+// arch/x86/include/asm/uaccess.h:access_ok() to ars.
+func (mm *MemoryManager) checkIOVec(ars usermem.AddrRangeSeq) bool {
+ for !ars.IsEmpty() {
+ ar := ars.Head()
+ if _, ok := mm.CheckIORange(ar.Start, int64(ar.Length())); !ok {
+ return false
+ }
+ ars = ars.Tail()
+ }
+ return true
+}
+
+func (mm *MemoryManager) asioEnabled(opts usermem.IOOpts) bool {
+ return mm.haveASIO && !opts.IgnorePermissions && opts.AddressSpaceActive
+}
+
+// translateIOError converts errors to EFAULT, as is usually reported for all
+// I/O errors originating from MM in Linux.
+func translateIOError(ctx context.Context, err error) error {
+ if err == nil {
+ return nil
+ }
+ if logIOErrors {
+ ctx.Debugf("MM I/O error: %v", err)
+ }
+ return syserror.EFAULT
+}
+
+// CopyOut implements usermem.IO.CopyOut.
+func (mm *MemoryManager) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) {
+ ar, ok := mm.CheckIORange(addr, int64(len(src)))
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ if len(src) == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && len(src) < copyMapMinBytes {
+ return mm.asCopyOut(ctx, addr, src)
+ }
+
+ // Go through internal mappings.
+ n64, err := mm.withInternalMappings(ctx, ar, usermem.Write, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ n, err := safemem.CopySeq(ims, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(src)))
+ return n, translateIOError(ctx, err)
+ })
+ return int(n64), err
+}
+
+func (mm *MemoryManager) asCopyOut(ctx context.Context, addr usermem.Addr, src []byte) (int, error) {
+ var done int
+ for {
+ n, err := mm.as.CopyOut(addr+usermem.Addr(done), src[done:])
+ done += n
+ if err == nil {
+ return done, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ ar, _ := addr.ToRange(uint64(len(src)))
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Write); err != nil {
+ return done, err
+ }
+ continue
+ }
+ return done, translateIOError(ctx, err)
+ }
+}
+
+// CopyIn implements usermem.IO.CopyIn.
+func (mm *MemoryManager) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
+ ar, ok := mm.CheckIORange(addr, int64(len(dst)))
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ if len(dst) == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && len(dst) < copyMapMinBytes {
+ return mm.asCopyIn(ctx, addr, dst)
+ }
+
+ // Go through internal mappings.
+ n64, err := mm.withInternalMappings(ctx, ar, usermem.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ n, err := safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(dst)), ims)
+ return n, translateIOError(ctx, err)
+ })
+ return int(n64), err
+}
+
+func (mm *MemoryManager) asCopyIn(ctx context.Context, addr usermem.Addr, dst []byte) (int, error) {
+ var done int
+ for {
+ n, err := mm.as.CopyIn(addr+usermem.Addr(done), dst[done:])
+ done += n
+ if err == nil {
+ return done, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ ar, _ := addr.ToRange(uint64(len(dst)))
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Read); err != nil {
+ return done, err
+ }
+ continue
+ }
+ return done, translateIOError(ctx, err)
+ }
+}
+
+// ZeroOut implements usermem.IO.ZeroOut.
+func (mm *MemoryManager) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
+ ar, ok := mm.CheckIORange(addr, toZero)
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ if toZero == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && toZero < copyMapMinBytes {
+ return mm.asZeroOut(ctx, addr, toZero)
+ }
+
+ // Go through internal mappings.
+ return mm.withInternalMappings(ctx, ar, usermem.Write, opts.IgnorePermissions, func(dsts safemem.BlockSeq) (uint64, error) {
+ n, err := safemem.ZeroSeq(dsts)
+ return n, translateIOError(ctx, err)
+ })
+}
+
+func (mm *MemoryManager) asZeroOut(ctx context.Context, addr usermem.Addr, toZero int64) (int64, error) {
+ var done int64
+ for {
+ n, err := mm.as.ZeroOut(addr+usermem.Addr(done), uintptr(toZero-done))
+ done += int64(n)
+ if err == nil {
+ return done, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ ar, _ := addr.ToRange(uint64(toZero))
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Write); err != nil {
+ return done, err
+ }
+ continue
+ }
+ return done, translateIOError(ctx, err)
+ }
+}
+
+// CopyOutFrom implements usermem.IO.CopyOutFrom.
+func (mm *MemoryManager) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
+ if !mm.checkIOVec(ars) {
+ return 0, syserror.EFAULT
+ }
+
+ if ars.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && ars.NumBytes() < rwMapMinBytes {
+ // We have to introduce a buffered copy, instead of just passing a
+ // safemem.BlockSeq representing addresses in the AddressSpace to src.
+ // This is because usermem.IO.CopyOutFrom() guarantees that it calls
+ // src.ReadToBlocks() at most once, which is incompatible with handling
+ // faults between calls. In the future, this is probably best resolved
+ // by introducing a CopyOutFrom variant or option that allows it to
+ // call src.ReadToBlocks() any number of times.
+ //
+ // This issue applies to CopyInTo as well.
+ buf := make([]byte, int(ars.NumBytes()))
+ bufN, bufErr := src.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)))
+ var done int64
+ for done < int64(bufN) {
+ ar := ars.Head()
+ cplen := int64(ar.Length())
+ if cplen > int64(bufN)-done {
+ cplen = int64(bufN) - done
+ }
+ n, err := mm.asCopyOut(ctx, ar.Start, buf[int(done):int(done+cplen)])
+ done += int64(n)
+ if err != nil {
+ return done, err
+ }
+ ars = ars.Tail()
+ }
+ // Do not convert errors returned by src to EFAULT.
+ return done, bufErr
+ }
+
+ // Go through internal mappings.
+ return mm.withVecInternalMappings(ctx, ars, usermem.Write, opts.IgnorePermissions, src.ReadToBlocks)
+}
+
+// CopyInTo implements usermem.IO.CopyInTo.
+func (mm *MemoryManager) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
+ if !mm.checkIOVec(ars) {
+ return 0, syserror.EFAULT
+ }
+
+ if ars.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.asioEnabled(opts) && ars.NumBytes() < rwMapMinBytes {
+ buf := make([]byte, int(ars.NumBytes()))
+ var done int
+ var bufErr error
+ for !ars.IsEmpty() {
+ ar := ars.Head()
+ var n int
+ n, bufErr = mm.asCopyIn(ctx, ar.Start, buf[done:done+int(ar.Length())])
+ done += n
+ if bufErr != nil {
+ break
+ }
+ ars = ars.Tail()
+ }
+ n, err := dst.WriteFromBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:done])))
+ if err != nil {
+ return int64(n), err
+ }
+ // Do not convert errors returned by dst to EFAULT.
+ return int64(n), bufErr
+ }
+
+ // Go through internal mappings.
+ return mm.withVecInternalMappings(ctx, ars, usermem.Read, opts.IgnorePermissions, dst.WriteFromBlocks)
+}
+
+// SwapUint32 implements usermem.IO.SwapUint32.
+func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
+ ar, ok := mm.CheckIORange(addr, 4)
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.haveASIO && opts.AddressSpaceActive && !opts.IgnorePermissions {
+ for {
+ old, err := mm.as.SwapUint32(addr, new)
+ if err == nil {
+ return old, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.ReadWrite); err != nil {
+ return 0, err
+ }
+ continue
+ }
+ return 0, translateIOError(ctx, err)
+ }
+ }
+
+ // Go through internal mappings.
+ var old uint32
+ _, err := mm.withInternalMappings(ctx, ar, usermem.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ if ims.NumBlocks() != 1 || ims.NumBytes() != 4 {
+ // Atomicity is unachievable across mappings.
+ return 0, syserror.EFAULT
+ }
+ im := ims.Head()
+ var err error
+ old, err = safemem.SwapUint32(im, new)
+ if err != nil {
+ return 0, translateIOError(ctx, err)
+ }
+ // Return the number of bytes read.
+ return 4, nil
+ })
+ return old, err
+}
+
+// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32.
+func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
+ ar, ok := mm.CheckIORange(addr, 4)
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.haveASIO && opts.AddressSpaceActive && !opts.IgnorePermissions {
+ for {
+ prev, err := mm.as.CompareAndSwapUint32(addr, old, new)
+ if err == nil {
+ return prev, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.ReadWrite); err != nil {
+ return 0, err
+ }
+ continue
+ }
+ return 0, translateIOError(ctx, err)
+ }
+ }
+
+ // Go through internal mappings.
+ var prev uint32
+ _, err := mm.withInternalMappings(ctx, ar, usermem.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ if ims.NumBlocks() != 1 || ims.NumBytes() != 4 {
+ // Atomicity is unachievable across mappings.
+ return 0, syserror.EFAULT
+ }
+ im := ims.Head()
+ var err error
+ prev, err = safemem.CompareAndSwapUint32(im, old, new)
+ if err != nil {
+ return 0, translateIOError(ctx, err)
+ }
+ // Return the number of bytes read.
+ return 4, nil
+ })
+ return prev, err
+}
+
+// LoadUint32 implements usermem.IO.LoadUint32.
+func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) {
+ ar, ok := mm.CheckIORange(addr, 4)
+ if !ok {
+ return 0, syserror.EFAULT
+ }
+
+ // Do AddressSpace IO if applicable.
+ if mm.haveASIO && opts.AddressSpaceActive && !opts.IgnorePermissions {
+ for {
+ val, err := mm.as.LoadUint32(addr)
+ if err == nil {
+ return val, nil
+ }
+ if f, ok := err.(platform.SegmentationFault); ok {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Read); err != nil {
+ return 0, err
+ }
+ continue
+ }
+ return 0, translateIOError(ctx, err)
+ }
+ }
+
+ // Go through internal mappings.
+ var val uint32
+ _, err := mm.withInternalMappings(ctx, ar, usermem.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ if ims.NumBlocks() != 1 || ims.NumBytes() != 4 {
+ // Atomicity is unachievable across mappings.
+ return 0, syserror.EFAULT
+ }
+ im := ims.Head()
+ var err error
+ val, err = safemem.LoadUint32(im)
+ if err != nil {
+ return 0, translateIOError(ctx, err)
+ }
+ // Return the number of bytes read.
+ return 4, nil
+ })
+ return val, err
+}
+
+// handleASIOFault handles a page fault at address addr for an AddressSpaceIO
+// operation spanning ioar.
+//
+// Preconditions: mm.as != nil. ioar.Length() != 0. ioar.Contains(addr).
+func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr usermem.Addr, ioar usermem.AddrRange, at usermem.AccessType) error {
+ // Try to map all remaining pages in the I/O operation. This RoundUp can't
+ // overflow because otherwise it would have been caught by CheckIORange.
+ end, _ := ioar.End.RoundUp()
+ ar := usermem.AddrRange{addr.RoundDown(), end}
+
+ // Don't bother trying existingPMAsLocked; in most cases, if we did have
+ // existing pmas, we wouldn't have faulted.
+
+ // Ensure that we have usable vmas. Here and below, only return early if we
+ // can't map the first (faulting) page; failure to map later pages are
+ // silently ignored. This maximizes partial success.
+ mm.mappingMu.RLock()
+ vseg, vend, err := mm.getVMAsLocked(ctx, ar, at, false)
+ if vendaddr := vend.Start(); vendaddr < ar.End {
+ if vendaddr <= ar.Start {
+ mm.mappingMu.RUnlock()
+ return translateIOError(ctx, err)
+ }
+ ar.End = vendaddr
+ }
+
+ // Ensure that we have usable pmas.
+ mm.activeMu.Lock()
+ pseg, pend, err := mm.getPMAsLocked(ctx, vseg, ar, at)
+ mm.mappingMu.RUnlock()
+ if pendaddr := pend.Start(); pendaddr < ar.End {
+ if pendaddr <= ar.Start {
+ mm.activeMu.Unlock()
+ return translateIOError(ctx, err)
+ }
+ ar.End = pendaddr
+ }
+
+ // Downgrade to a read-lock on activeMu since we don't need to mutate pmas
+ // anymore.
+ mm.activeMu.DowngradeLock()
+
+ err = mm.mapASLocked(pseg, ar, false)
+ mm.activeMu.RUnlock()
+ return translateIOError(ctx, err)
+}
+
+// withInternalMappings ensures that pmas exist for all addresses in ar,
+// support access of type (at, ignorePermissions), and have internal mappings
+// cached. It then calls f with mm.activeMu locked for reading, passing
+// internal mappings for the subrange of ar for which this property holds.
+//
+// withInternalMappings takes a function returning uint64 since many safemem
+// functions have this property, but returns an int64 since this is usually
+// more useful for usermem.IO methods.
+//
+// Preconditions: 0 < ar.Length() <= math.MaxInt64.
+func (mm *MemoryManager) withInternalMappings(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
+ // If pmas are already available, we can do IO without touching mm.vmas or
+ // mm.mappingMu.
+ mm.activeMu.RLock()
+ if pseg := mm.existingPMAsLocked(ar, at, ignorePermissions, true /* needInternalMappings */); pseg.Ok() {
+ n, err := f(mm.internalMappingsLocked(pseg, ar))
+ mm.activeMu.RUnlock()
+ // Do not convert errors returned by f to EFAULT.
+ return int64(n), err
+ }
+ mm.activeMu.RUnlock()
+
+ // Ensure that we have usable vmas.
+ mm.mappingMu.RLock()
+ vseg, vend, verr := mm.getVMAsLocked(ctx, ar, at, ignorePermissions)
+ if vendaddr := vend.Start(); vendaddr < ar.End {
+ if vendaddr <= ar.Start {
+ mm.mappingMu.RUnlock()
+ return 0, translateIOError(ctx, verr)
+ }
+ ar.End = vendaddr
+ }
+
+ // Ensure that we have usable pmas.
+ mm.activeMu.Lock()
+ pseg, pend, perr := mm.getPMAsLocked(ctx, vseg, ar, at)
+ mm.mappingMu.RUnlock()
+ if pendaddr := pend.Start(); pendaddr < ar.End {
+ if pendaddr <= ar.Start {
+ mm.activeMu.Unlock()
+ return 0, translateIOError(ctx, perr)
+ }
+ ar.End = pendaddr
+ }
+ imend, imerr := mm.getPMAInternalMappingsLocked(pseg, ar)
+ mm.activeMu.DowngradeLock()
+ if imendaddr := imend.Start(); imendaddr < ar.End {
+ if imendaddr <= ar.Start {
+ mm.activeMu.RUnlock()
+ return 0, translateIOError(ctx, imerr)
+ }
+ ar.End = imendaddr
+ }
+
+ // Do I/O.
+ un, err := f(mm.internalMappingsLocked(pseg, ar))
+ mm.activeMu.RUnlock()
+ n := int64(un)
+
+ // Return the first error in order of progress through ar.
+ if err != nil {
+ // Do not convert errors returned by f to EFAULT.
+ return n, err
+ }
+ if imerr != nil {
+ return n, translateIOError(ctx, imerr)
+ }
+ if perr != nil {
+ return n, translateIOError(ctx, perr)
+ }
+ return n, translateIOError(ctx, verr)
+}
+
+// withVecInternalMappings ensures that pmas exist for all addresses in ars,
+// support access of type (at, ignorePermissions), and have internal mappings
+// cached. It then calls f with mm.activeMu locked for reading, passing
+// internal mappings for the subset of ars for which this property holds.
+//
+// Preconditions: !ars.IsEmpty().
+func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
+ // withInternalMappings is faster than withVecInternalMappings because of
+ // iterator plumbing (this isn't generally practical in the vector case due
+ // to iterator invalidation between AddrRanges). Use it if possible.
+ if ars.NumRanges() == 1 {
+ return mm.withInternalMappings(ctx, ars.Head(), at, ignorePermissions, f)
+ }
+
+ // If pmas are already available, we can do IO without touching mm.vmas or
+ // mm.mappingMu.
+ mm.activeMu.RLock()
+ if mm.existingVecPMAsLocked(ars, at, ignorePermissions, true /* needInternalMappings */) {
+ n, err := f(mm.vecInternalMappingsLocked(ars))
+ mm.activeMu.RUnlock()
+ // Do not convert errors returned by f to EFAULT.
+ return int64(n), err
+ }
+ mm.activeMu.RUnlock()
+
+ // Ensure that we have usable vmas.
+ mm.mappingMu.RLock()
+ vars, verr := mm.getVecVMAsLocked(ctx, ars, at, ignorePermissions)
+ if vars.NumBytes() == 0 {
+ mm.mappingMu.RUnlock()
+ return 0, translateIOError(ctx, verr)
+ }
+
+ // Ensure that we have usable pmas.
+ mm.activeMu.Lock()
+ pars, perr := mm.getVecPMAsLocked(ctx, vars, at)
+ mm.mappingMu.RUnlock()
+ if pars.NumBytes() == 0 {
+ mm.activeMu.Unlock()
+ return 0, translateIOError(ctx, perr)
+ }
+ imars, imerr := mm.getVecPMAInternalMappingsLocked(pars)
+ mm.activeMu.DowngradeLock()
+ if imars.NumBytes() == 0 {
+ mm.activeMu.RUnlock()
+ return 0, translateIOError(ctx, imerr)
+ }
+
+ // Do I/O.
+ un, err := f(mm.vecInternalMappingsLocked(imars))
+ mm.activeMu.RUnlock()
+ n := int64(un)
+
+ // Return the first error in order of progress through ars.
+ if err != nil {
+ // Do not convert errors from f to EFAULT.
+ return n, err
+ }
+ if imerr != nil {
+ return n, translateIOError(ctx, imerr)
+ }
+ if perr != nil {
+ return n, translateIOError(ctx, perr)
+ }
+ return n, translateIOError(ctx, verr)
+}
+
+// truncatedAddrRangeSeq returns a copy of ars, but with the end truncated to
+// at most address end on AddrRange arsit.Head(). It is used in vector I/O paths to
+// truncate usermem.AddrRangeSeq when errors occur.
+//
+// Preconditions: !arsit.IsEmpty(). end <= arsit.Head().End.
+func truncatedAddrRangeSeq(ars, arsit usermem.AddrRangeSeq, end usermem.Addr) usermem.AddrRangeSeq {
+ ar := arsit.Head()
+ if end <= ar.Start {
+ return ars.TakeFirst64(ars.NumBytes() - arsit.NumBytes())
+ }
+ return ars.TakeFirst64(ars.NumBytes() - arsit.NumBytes() + int64(end-ar.Start))
+}
diff --git a/pkg/sentry/mm/io_list.go b/pkg/sentry/mm/io_list.go
new file mode 100755
index 000000000..99c83c4b9
--- /dev/null
+++ b/pkg/sentry/mm/io_list.go
@@ -0,0 +1,173 @@
+package mm
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type ioElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (ioElementMapper) linkerFor(elem *ioResult) *ioResult { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type ioList struct {
+ head *ioResult
+ tail *ioResult
+}
+
+// Reset resets list l to the empty state.
+func (l *ioList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *ioList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *ioList) Front() *ioResult {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *ioList) Back() *ioResult {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *ioList) PushFront(e *ioResult) {
+ ioElementMapper{}.linkerFor(e).SetNext(l.head)
+ ioElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ ioElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *ioList) PushBack(e *ioResult) {
+ ioElementMapper{}.linkerFor(e).SetNext(nil)
+ ioElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ ioElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *ioList) PushBackList(m *ioList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ ioElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ ioElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *ioList) InsertAfter(b, e *ioResult) {
+ a := ioElementMapper{}.linkerFor(b).Next()
+ ioElementMapper{}.linkerFor(e).SetNext(a)
+ ioElementMapper{}.linkerFor(e).SetPrev(b)
+ ioElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ ioElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *ioList) InsertBefore(a, e *ioResult) {
+ b := ioElementMapper{}.linkerFor(a).Prev()
+ ioElementMapper{}.linkerFor(e).SetNext(a)
+ ioElementMapper{}.linkerFor(e).SetPrev(b)
+ ioElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ ioElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *ioList) Remove(e *ioResult) {
+ prev := ioElementMapper{}.linkerFor(e).Prev()
+ next := ioElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ ioElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ ioElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type ioEntry struct {
+ next *ioResult
+ prev *ioResult
+}
+
+// Next returns the entry that follows e in the list.
+func (e *ioEntry) Next() *ioResult {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *ioEntry) Prev() *ioResult {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *ioEntry) SetNext(elem *ioResult) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *ioEntry) SetPrev(elem *ioResult) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
new file mode 100644
index 000000000..7a65a62a2
--- /dev/null
+++ b/pkg/sentry/mm/lifecycle.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/atomicbitops"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// NewMemoryManager returns a new MemoryManager with no mappings and 1 user.
+func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *MemoryManager {
+ return &MemoryManager{
+ p: p,
+ mfp: mfp,
+ haveASIO: p.SupportsAddressSpaceIO(),
+ privateRefs: &privateRefs{},
+ users: 1,
+ auxv: arch.Auxv{},
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ }
+}
+
+// SetMmapLayout initializes mm's layout from the given arch.Context.
+//
+// Preconditions: mm contains no mappings and is not used concurrently.
+func (mm *MemoryManager) SetMmapLayout(ac arch.Context, r *limits.LimitSet) (arch.MmapLayout, error) {
+ layout, err := ac.NewMmapLayout(mm.p.MinUserAddress(), mm.p.MaxUserAddress(), r)
+ if err != nil {
+ return arch.MmapLayout{}, err
+ }
+ mm.layout = layout
+ return layout, nil
+}
+
+// Fork creates a copy of mm with 1 user, as for Linux syscalls fork() or
+// clone() (without CLONE_VM).
+func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ mm2 := &MemoryManager{
+ p: mm.p,
+ mfp: mm.mfp,
+ haveASIO: mm.haveASIO,
+ layout: mm.layout,
+ privateRefs: mm.privateRefs,
+ users: 1,
+ brk: mm.brk,
+ usageAS: mm.usageAS,
+ dataAS: mm.dataAS,
+ // "The child does not inherit its parent's memory locks (mlock(2),
+ // mlockall(2))." - fork(2). So lockedAS is 0 and defMLockMode is
+ // MLockNone, both of which are zero values. vma.mlockMode is reset
+ // when copied below.
+ captureInvalidations: true,
+ argv: mm.argv,
+ envv: mm.envv,
+ auxv: append(arch.Auxv(nil), mm.auxv...),
+ // IncRef'd below, once we know that there isn't an error.
+ executable: mm.executable,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ }
+
+ // Copy vmas.
+ dstvgap := mm2.vmas.FirstGap()
+ for srcvseg := mm.vmas.FirstSegment(); srcvseg.Ok(); srcvseg = srcvseg.NextSegment() {
+ vma := srcvseg.Value() // makes a copy of the vma
+ vmaAR := srcvseg.Range()
+ // Inform the Mappable, if any, of the new mapping.
+ if vma.mappable != nil {
+ if err := vma.mappable.AddMapping(ctx, mm2, vmaAR, vma.off, vma.canWriteMappableLocked()); err != nil {
+ mm2.removeVMAsLocked(ctx, mm2.applicationAddrRange())
+ return nil, err
+ }
+ }
+ if vma.id != nil {
+ vma.id.IncRef()
+ }
+ vma.mlockMode = memmap.MLockNone
+ dstvgap = mm2.vmas.Insert(dstvgap, vmaAR, vma).NextGap()
+ // We don't need to update mm2.usageAS since we copied it from mm
+ // above.
+ }
+
+ // Copy pmas. We have to lock mm.activeMu for writing to make existing
+ // private pmas copy-on-write. We also have to lock mm2.activeMu since
+ // after copying vmas above, memmap.Mappables may call mm2.Invalidate. We
+ // only copy private pmas, since in the common case where fork(2) is
+ // immediately followed by execve(2), copying non-private pmas that can be
+ // regenerated by calling memmap.Mappable.Translate is a waste of time.
+ // (Linux does the same; compare kernel/fork.c:dup_mmap() =>
+ // mm/memory.c:copy_page_range().)
+ mm2.activeMu.Lock()
+ defer mm2.activeMu.Unlock()
+ mm.activeMu.Lock()
+ defer mm.activeMu.Unlock()
+ dstpgap := mm2.pmas.FirstGap()
+ var unmapAR usermem.AddrRange
+ for srcpseg := mm.pmas.FirstSegment(); srcpseg.Ok(); srcpseg = srcpseg.NextSegment() {
+ pma := srcpseg.ValuePtr()
+ if !pma.private {
+ continue
+ }
+ if !pma.needCOW {
+ pma.needCOW = true
+ if pma.effectivePerms.Write {
+ // We don't want to unmap the whole address space, even though
+ // doing so would reduce calls to unmapASLocked(), because mm
+ // will most likely continue to be used after the fork, so
+ // unmapping pmas unnecessarily will result in extra page
+ // faults. But we do want to merge consecutive AddrRanges
+ // across pma boundaries.
+ if unmapAR.End == srcpseg.Start() {
+ unmapAR.End = srcpseg.End()
+ } else {
+ if unmapAR.Length() != 0 {
+ mm.unmapASLocked(unmapAR)
+ }
+ unmapAR = srcpseg.Range()
+ }
+ pma.effectivePerms.Write = false
+ }
+ pma.maxPerms.Write = false
+ }
+ fr := srcpseg.fileRange()
+ mm2.incPrivateRef(fr)
+ srcpseg.ValuePtr().file.IncRef(fr)
+ addrRange := srcpseg.Range()
+ mm2.addRSSLocked(addrRange)
+ dstpgap = mm2.pmas.Insert(dstpgap, addrRange, *pma).NextGap()
+ }
+ if unmapAR.Length() != 0 {
+ mm.unmapASLocked(unmapAR)
+ }
+
+ // Between when we call memmap.Mappable.AddMapping while copying vmas and
+ // when we lock mm2.activeMu to copy pmas, calls to mm2.Invalidate() are
+ // ineffective because the pmas they invalidate haven't yet been copied,
+ // possibly allowing mm2 to get invalidated translations:
+ //
+ // Invalidating Mappable mm.Fork
+ // --------------------- -------
+ //
+ // mm2.Invalidate()
+ // mm.activeMu.Lock()
+ // mm.Invalidate() /* blocks */
+ // mm2.activeMu.Lock()
+ // (mm copies invalidated pma to mm2)
+ //
+ // This would technically be both safe (since we only copy private pmas,
+ // which will still hold a reference on their memory) and consistent with
+ // Linux, but we avoid it anyway by setting mm2.captureInvalidations during
+ // construction, causing calls to mm2.Invalidate() to be captured in
+ // mm2.capturedInvalidations, to be replayed after pmas are copied - i.e.
+ // here.
+ mm2.captureInvalidations = false
+ for _, invArgs := range mm2.capturedInvalidations {
+ mm2.invalidateLocked(invArgs.ar, invArgs.opts.InvalidatePrivate, true)
+ }
+ mm2.capturedInvalidations = nil
+
+ if mm2.executable != nil {
+ mm2.executable.IncRef()
+ }
+ return mm2, nil
+}
+
+// IncUsers increments mm's user count and returns true. If the user count is
+// already 0, IncUsers does nothing and returns false.
+func (mm *MemoryManager) IncUsers() bool {
+ return atomicbitops.IncUnlessZeroInt32(&mm.users)
+}
+
+// DecUsers decrements mm's user count. If the user count reaches 0, all
+// mappings in mm are unmapped.
+func (mm *MemoryManager) DecUsers(ctx context.Context) {
+ if users := atomic.AddInt32(&mm.users, -1); users > 0 {
+ return
+ } else if users < 0 {
+ panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users))
+ }
+
+ mm.aioManager.destroy()
+
+ mm.metadataMu.Lock()
+ exe := mm.executable
+ mm.executable = nil
+ mm.metadataMu.Unlock()
+ if exe != nil {
+ exe.DecRef()
+ }
+
+ mm.activeMu.Lock()
+ // Sanity check.
+ if atomic.LoadInt32(&mm.active) != 0 {
+ panic("active address space lost?")
+ }
+ // Make sure the AddressSpace is returned.
+ if mm.as != nil {
+ mm.as.Release()
+ mm.as = nil
+ }
+ mm.activeMu.Unlock()
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ // If mm is being dropped before mm.SetMmapLayout was called,
+ // mm.applicationAddrRange() will be empty.
+ if ar := mm.applicationAddrRange(); ar.Length() != 0 {
+ mm.unmapLocked(ctx, ar)
+ }
+}
diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go
new file mode 100644
index 000000000..9768e51f1
--- /dev/null
+++ b/pkg/sentry/mm/metadata.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// ArgvStart returns the start of the application argument vector.
+//
+// There is no guarantee that this value is sensible w.r.t. ArgvEnd.
+func (mm *MemoryManager) ArgvStart() usermem.Addr {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.argv.Start
+}
+
+// SetArgvStart sets the start of the application argument vector.
+func (mm *MemoryManager) SetArgvStart(a usermem.Addr) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.argv.Start = a
+}
+
+// ArgvEnd returns the end of the application argument vector.
+//
+// There is no guarantee that this value is sensible w.r.t. ArgvStart.
+func (mm *MemoryManager) ArgvEnd() usermem.Addr {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.argv.End
+}
+
+// SetArgvEnd sets the end of the application argument vector.
+func (mm *MemoryManager) SetArgvEnd(a usermem.Addr) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.argv.End = a
+}
+
+// EnvvStart returns the start of the application environment vector.
+//
+// There is no guarantee that this value is sensible w.r.t. EnvvEnd.
+func (mm *MemoryManager) EnvvStart() usermem.Addr {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.envv.Start
+}
+
+// SetEnvvStart sets the start of the application environment vector.
+func (mm *MemoryManager) SetEnvvStart(a usermem.Addr) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.envv.Start = a
+}
+
+// EnvvEnd returns the end of the application environment vector.
+//
+// There is no guarantee that this value is sensible w.r.t. EnvvStart.
+func (mm *MemoryManager) EnvvEnd() usermem.Addr {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.envv.End
+}
+
+// SetEnvvEnd sets the end of the application environment vector.
+func (mm *MemoryManager) SetEnvvEnd(a usermem.Addr) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.envv.End = a
+}
+
+// Auxv returns the current map of auxiliary vectors.
+func (mm *MemoryManager) Auxv() arch.Auxv {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return append(arch.Auxv(nil), mm.auxv...)
+}
+
+// SetAuxv sets the entire map of auxiliary vectors.
+func (mm *MemoryManager) SetAuxv(auxv arch.Auxv) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.auxv = append(arch.Auxv(nil), auxv...)
+}
+
+// Executable returns the executable, if available.
+//
+// An additional reference will be taken in the case of a non-nil executable,
+// which must be released by the caller.
+func (mm *MemoryManager) Executable() *fs.Dirent {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+
+ if mm.executable == nil {
+ return nil
+ }
+
+ mm.executable.IncRef()
+ return mm.executable
+}
+
+// SetExecutable sets the executable.
+//
+// This takes a reference on d.
+func (mm *MemoryManager) SetExecutable(d *fs.Dirent) {
+ mm.metadataMu.Lock()
+
+ // Grab a new reference.
+ d.IncRef()
+
+ // Set the executable.
+ orig := mm.executable
+ mm.executable = d
+
+ mm.metadataMu.Unlock()
+
+ // Release the old reference.
+ //
+ // Do this without holding the lock, since it may wind up doing some
+ // I/O to sync the dirent, etc.
+ if orig != nil {
+ orig.DecRef()
+ }
+}
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
new file mode 100644
index 000000000..eb6defa2b
--- /dev/null
+++ b/pkg/sentry/mm/mm.go
@@ -0,0 +1,456 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package mm provides a memory management subsystem. See README.md for a
+// detailed overview.
+//
+// Lock order:
+//
+// fs locks, except for memmap.Mappable locks
+// mm.MemoryManager.metadataMu
+// mm.MemoryManager.mappingMu
+// Locks taken by memmap.Mappable methods other than Translate
+// mm.MemoryManager.activeMu
+// Locks taken by memmap.Mappable.Translate
+// mm.privateRefs.mu
+// platform.AddressSpace locks
+// platform.File locks
+// mm.aioManager.mu
+// mm.AIOContext.mu
+//
+// Only mm.MemoryManager.Fork is permitted to lock mm.MemoryManager.activeMu in
+// multiple mm.MemoryManagers, as it does so in a well-defined order (forked
+// child first).
+package mm
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/third_party/gvsync"
+)
+
+// MemoryManager implements a virtual address space.
+//
+// +stateify savable
+type MemoryManager struct {
+ // p and mfp are immutable.
+ p platform.Platform
+ mfp pgalloc.MemoryFileProvider
+
+ // haveASIO is the cached result of p.SupportsAddressSpaceIO(). Aside from
+ // eliminating an indirect call in the hot I/O path, this makes
+ // MemoryManager.asioEnabled() a leaf function, allowing it to be inlined.
+ //
+ // haveASIO is immutable.
+ haveASIO bool `state:"nosave"`
+
+ // layout is the memory layout.
+ //
+ // layout is set by the binary loader before the MemoryManager can be used.
+ layout arch.MmapLayout
+
+ // privateRefs stores reference counts for private memory (memory whose
+ // ownership is shared by one or more pmas instead of being owned by a
+ // memmap.Mappable).
+ //
+ // privateRefs is immutable.
+ privateRefs *privateRefs
+
+ // users is the number of dependences on the mappings in the MemoryManager.
+ // When the number of references in users reaches zero, all mappings are
+ // unmapped.
+ //
+ // users is accessed using atomic memory operations.
+ users int32
+
+ // mappingMu is analogous to Linux's struct mm_struct::mmap_sem.
+ mappingMu gvsync.DowngradableRWMutex `state:"nosave"`
+
+ // vmas stores virtual memory areas. Since vmas are stored by value,
+ // clients should usually use vmaIterator.ValuePtr() instead of
+ // vmaIterator.Value() to get a pointer to the vma rather than a copy.
+ //
+ // Invariants: vmas are always page-aligned.
+ //
+ // vmas is protected by mappingMu.
+ vmas vmaSet
+
+ // brk is the mm's brk, which is manipulated using the brk(2) system call.
+ // The brk is initially set up by the loader which maps an executable
+ // binary into the mm.
+ //
+ // brk is protected by mappingMu.
+ brk usermem.AddrRange
+
+ // usageAS is vmas.Span(), cached to accelerate RLIMIT_AS checks.
+ //
+ // usageAS is protected by mappingMu.
+ usageAS uint64
+
+ // lockedAS is the combined size in bytes of all vmas with vma.mlockMode !=
+ // memmap.MLockNone.
+ //
+ // lockedAS is protected by mappingMu.
+ lockedAS uint64
+
+ // dataAS is the size of private data segments, like mm_struct->data_vm.
+ // It means the vma which is private, writable, not stack.
+ //
+ // dataAS is protected by mappingMu.
+ dataAS uint64
+
+ // New VMAs created by MMap use whichever of memmap.MMapOpts.MLockMode or
+ // defMLockMode is greater.
+ //
+ // defMLockMode is protected by mappingMu.
+ defMLockMode memmap.MLockMode
+
+ // activeMu is loosely analogous to Linux's struct
+ // mm_struct::page_table_lock.
+ activeMu gvsync.DowngradableRWMutex `state:"nosave"`
+
+ // pmas stores platform mapping areas used to implement vmas. Since pmas
+ // are stored by value, clients should usually use pmaIterator.ValuePtr()
+ // instead of pmaIterator.Value() to get a pointer to the pma rather than
+ // a copy.
+ //
+ // Inserting or removing segments from pmas should happen along with a
+ // call to mm.insertRSS or mm.removeRSS.
+ //
+ // Invariants: pmas are always page-aligned. If a pma exists for a given
+ // address, a vma must also exist for that address.
+ //
+ // pmas is protected by activeMu.
+ pmas pmaSet
+
+ // curRSS is pmas.Span(), cached to accelerate updates to maxRSS. It is
+ // reported as the MemoryManager's RSS.
+ //
+ // maxRSS should be modified only via insertRSS and removeRSS, not
+ // directly.
+ //
+ // maxRSS is protected by activeMu.
+ curRSS uint64
+
+ // maxRSS is the maximum resident set size in bytes of a MemoryManager.
+ // It is tracked as the application adds and removes mappings to pmas.
+ //
+ // maxRSS should be modified only via insertRSS, not directly.
+ //
+ // maxRSS is protected by activeMu.
+ maxRSS uint64
+
+ // as is the platform.AddressSpace that pmas are mapped into. active is the
+ // number of contexts that require as to be non-nil; if active == 0, as may
+ // be nil.
+ //
+ // as is protected by activeMu. active is manipulated with atomic memory
+ // operations; transitions to and from zero are additionally protected by
+ // activeMu. (This is because such transitions may need to be atomic with
+ // changes to as.)
+ as platform.AddressSpace `state:"nosave"`
+ active int32 `state:"zerovalue"`
+
+ // unmapAllOnActivate indicates that the next Activate call should activate
+ // an empty AddressSpace.
+ //
+ // This is used to ensure that an AddressSpace cached in
+ // NewAddressSpace is not used after some change in the MemoryManager
+ // or VMAs has made that AddressSpace stale.
+ //
+ // unmapAllOnActivate is protected by activeMu. It must only be set when
+ // there is no active or cached AddressSpace. If as != nil, then
+ // invalidations should be propagated immediately.
+ unmapAllOnActivate bool `state:"nosave"`
+
+ // If captureInvalidations is true, calls to MM.Invalidate() are recorded
+ // in capturedInvalidations rather than being applied immediately to pmas.
+ // This is to avoid a race condition in MM.Fork(); see that function for
+ // details.
+ //
+ // Both captureInvalidations and capturedInvalidations are protected by
+ // activeMu. Neither need to be saved since captureInvalidations is only
+ // enabled during MM.Fork(), during which saving can't occur.
+ captureInvalidations bool `state:"zerovalue"`
+ capturedInvalidations []invalidateArgs `state:"nosave"`
+
+ metadataMu sync.Mutex `state:"nosave"`
+
+ // argv is the application argv. This is set up by the loader and may be
+ // modified by prctl(PR_SET_MM_ARG_START/PR_SET_MM_ARG_END). No
+ // requirements apply to argv; we do not require that argv.WellFormed().
+ //
+ // argv is protected by metadataMu.
+ argv usermem.AddrRange
+
+ // envv is the application envv. This is set up by the loader and may be
+ // modified by prctl(PR_SET_MM_ENV_START/PR_SET_MM_ENV_END). No
+ // requirements apply to envv; we do not require that envv.WellFormed().
+ //
+ // envv is protected by metadataMu.
+ envv usermem.AddrRange
+
+ // auxv is the ELF's auxiliary vector.
+ //
+ // auxv is protected by metadataMu.
+ auxv arch.Auxv
+
+ // executable is the executable for this MemoryManager. If executable
+ // is not nil, it holds a reference on the Dirent.
+ //
+ // executable is protected by metadataMu.
+ executable *fs.Dirent
+
+ // aioManager keeps track of AIOContexts used for async IOs. AIOManager
+ // must be cloned when CLONE_VM is used.
+ aioManager aioManager
+}
+
+// vma represents a virtual memory area.
+//
+// +stateify savable
+type vma struct {
+ // mappable is the virtual memory object mapped by this vma. If mappable is
+ // nil, the vma represents a private anonymous mapping.
+ mappable memmap.Mappable
+
+ // off is the offset into mappable at which this vma begins. If mappable is
+ // nil, off is meaningless.
+ off uint64
+
+ // To speedup VMA save/restore, we group and save the following booleans
+ // as a single integer.
+
+ // realPerms are the memory permissions on this vma, as defined by the
+ // application.
+ realPerms usermem.AccessType `state:".(int)"`
+
+ // effectivePerms are the memory permissions on this vma which are
+ // actually used to control access.
+ //
+ // Invariant: effectivePerms == realPerms.Effective().
+ effectivePerms usermem.AccessType `state:"manual"`
+
+ // maxPerms limits the set of permissions that may ever apply to this
+ // memory, as well as accesses for which usermem.IOOpts.IgnorePermissions
+ // is true (e.g. ptrace(PTRACE_POKEDATA)).
+ //
+ // Invariant: maxPerms == maxPerms.Effective().
+ maxPerms usermem.AccessType `state:"manual"`
+
+ // private is true if this is a MAP_PRIVATE mapping, such that writes to
+ // the mapping are propagated to a copy.
+ private bool `state:"manual"`
+
+ // growsDown is true if the mapping may be automatically extended downward
+ // under certain conditions. If growsDown is true, mappable must be nil.
+ //
+ // There is currently no corresponding growsUp flag; in Linux, the only
+ // architectures that can have VM_GROWSUP mappings are ia64, parisc, and
+ // metag, none of which we currently support.
+ growsDown bool `state:"manual"`
+
+ mlockMode memmap.MLockMode
+
+ // If id is not nil, it controls the lifecycle of mappable and provides vma
+ // metadata shown in /proc/[pid]/maps, and the vma holds a reference.
+ id memmap.MappingIdentity
+
+ // If hint is non-empty, it is a description of the vma printed in
+ // /proc/[pid]/maps. hint takes priority over id.MappedName().
+ hint string
+}
+
+const (
+ vmaRealPermsRead = 1 << iota
+ vmaRealPermsWrite
+ vmaRealPermsExecute
+ vmaEffectivePermsRead
+ vmaEffectivePermsWrite
+ vmaEffectivePermsExecute
+ vmaMaxPermsRead
+ vmaMaxPermsWrite
+ vmaMaxPermsExecute
+ vmaPrivate
+ vmaGrowsDown
+)
+
+func (v *vma) saveRealPerms() int {
+ var b int
+ if v.realPerms.Read {
+ b |= vmaRealPermsRead
+ }
+ if v.realPerms.Write {
+ b |= vmaRealPermsWrite
+ }
+ if v.realPerms.Execute {
+ b |= vmaRealPermsExecute
+ }
+ if v.effectivePerms.Read {
+ b |= vmaEffectivePermsRead
+ }
+ if v.effectivePerms.Write {
+ b |= vmaEffectivePermsWrite
+ }
+ if v.effectivePerms.Execute {
+ b |= vmaEffectivePermsExecute
+ }
+ if v.maxPerms.Read {
+ b |= vmaMaxPermsRead
+ }
+ if v.maxPerms.Write {
+ b |= vmaMaxPermsWrite
+ }
+ if v.maxPerms.Execute {
+ b |= vmaMaxPermsExecute
+ }
+ if v.private {
+ b |= vmaPrivate
+ }
+ if v.growsDown {
+ b |= vmaGrowsDown
+ }
+ return b
+}
+
+func (v *vma) loadRealPerms(b int) {
+ if b&vmaRealPermsRead > 0 {
+ v.realPerms.Read = true
+ }
+ if b&vmaRealPermsWrite > 0 {
+ v.realPerms.Write = true
+ }
+ if b&vmaRealPermsExecute > 0 {
+ v.realPerms.Execute = true
+ }
+ if b&vmaEffectivePermsRead > 0 {
+ v.effectivePerms.Read = true
+ }
+ if b&vmaEffectivePermsWrite > 0 {
+ v.effectivePerms.Write = true
+ }
+ if b&vmaEffectivePermsExecute > 0 {
+ v.effectivePerms.Execute = true
+ }
+ if b&vmaMaxPermsRead > 0 {
+ v.maxPerms.Read = true
+ }
+ if b&vmaMaxPermsWrite > 0 {
+ v.maxPerms.Write = true
+ }
+ if b&vmaMaxPermsExecute > 0 {
+ v.maxPerms.Execute = true
+ }
+ if b&vmaPrivate > 0 {
+ v.private = true
+ }
+ if b&vmaGrowsDown > 0 {
+ v.growsDown = true
+ }
+}
+
+// pma represents a platform mapping area.
+//
+// +stateify savable
+type pma struct {
+ // file is the file mapped by this pma. Only pmas for which file ==
+ // MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to
+ // the corresponding file range while they exist.
+ file platform.File `state:"nosave"`
+
+ // off is the offset into file at which this pma begins.
+ //
+ // Note that pmas do *not* hold references on offsets in file! If private
+ // is true, MemoryManager.privateRefs holds the reference instead. If
+ // private is false, the corresponding memmap.Mappable holds the reference
+ // instead (per memmap.Mappable.Translate requirement).
+ off uint64
+
+ // translatePerms is the permissions returned by memmap.Mappable.Translate.
+ // If private is true, translatePerms is usermem.AnyAccess.
+ translatePerms usermem.AccessType
+
+ // effectivePerms is the permissions allowed for non-ignorePermissions
+ // accesses. maxPerms is the permissions allowed for ignorePermissions
+ // accesses. These are vma.effectivePerms and vma.maxPerms respectively,
+ // masked by pma.translatePerms and with Write disallowed if pma.needCOW is
+ // true.
+ //
+ // These are stored in the pma so that the IO implementation can avoid
+ // iterating mm.vmas when pmas already exist.
+ effectivePerms usermem.AccessType
+ maxPerms usermem.AccessType
+
+ // needCOW is true if writes to the mapping must be propagated to a copy.
+ needCOW bool
+
+ // private is true if this pma represents private memory.
+ //
+ // If private is true, file must be MemoryManager.mfp.MemoryFile(), the pma
+ // holds a reference on the mapped memory that is tracked in privateRefs,
+ // and calls to Invalidate for which
+ // memmap.InvalidateOpts.InvalidatePrivate is false should ignore the pma.
+ //
+ // If private is false, this pma caches a translation from the
+ // corresponding vma's memmap.Mappable.Translate.
+ private bool
+
+ // If internalMappings is not empty, it is the cached return value of
+ // file.MapInternal for the platform.FileRange mapped by this pma.
+ internalMappings safemem.BlockSeq `state:"nosave"`
+}
+
+// +stateify savable
+type privateRefs struct {
+ mu sync.Mutex `state:"nosave"`
+
+ // refs maps offsets into MemoryManager.mfp.MemoryFile() to the number of
+ // pmas (or, equivalently, MemoryManagers) that share ownership of the
+ // memory at that offset.
+ refs fileRefcountSet
+}
+
+type invalidateArgs struct {
+ ar usermem.AddrRange
+ opts memmap.InvalidateOpts
+}
+
+// fileRefcountSetFunctions implements segment.Functions for fileRefcountSet.
+type fileRefcountSetFunctions struct{}
+
+func (fileRefcountSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (fileRefcountSetFunctions) MaxKey() uint64 {
+ return ^uint64(0)
+}
+
+func (fileRefcountSetFunctions) ClearValue(_ *int32) {
+}
+
+func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) {
+ return rc1, rc1 == rc2
+}
+
+func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) {
+ return rc, rc
+}
diff --git a/pkg/sentry/mm/mm_state_autogen.go b/pkg/sentry/mm/mm_state_autogen.go
new file mode 100755
index 000000000..160f347f8
--- /dev/null
+++ b/pkg/sentry/mm/mm_state_autogen.go
@@ -0,0 +1,380 @@
+// automatically generated by stateify.
+
+package mm
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *aioManager) beforeSave() {}
+func (x *aioManager) save(m state.Map) {
+ x.beforeSave()
+ m.Save("contexts", &x.contexts)
+}
+
+func (x *aioManager) afterLoad() {}
+func (x *aioManager) load(m state.Map) {
+ m.Load("contexts", &x.contexts)
+}
+
+func (x *ioResult) beforeSave() {}
+func (x *ioResult) save(m state.Map) {
+ x.beforeSave()
+ m.Save("data", &x.data)
+ m.Save("ioEntry", &x.ioEntry)
+}
+
+func (x *ioResult) afterLoad() {}
+func (x *ioResult) load(m state.Map) {
+ m.Load("data", &x.data)
+ m.Load("ioEntry", &x.ioEntry)
+}
+
+func (x *AIOContext) beforeSave() {}
+func (x *AIOContext) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.dead) { m.Failf("dead is %v, expected zero", x.dead) }
+ m.Save("results", &x.results)
+ m.Save("maxOutstanding", &x.maxOutstanding)
+ m.Save("outstanding", &x.outstanding)
+}
+
+func (x *AIOContext) load(m state.Map) {
+ m.Load("results", &x.results)
+ m.Load("maxOutstanding", &x.maxOutstanding)
+ m.Load("outstanding", &x.outstanding)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *aioMappable) beforeSave() {}
+func (x *aioMappable) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("mfp", &x.mfp)
+ m.Save("fr", &x.fr)
+}
+
+func (x *aioMappable) afterLoad() {}
+func (x *aioMappable) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("mfp", &x.mfp)
+ m.Load("fr", &x.fr)
+}
+
+func (x *fileRefcountSet) beforeSave() {}
+func (x *fileRefcountSet) save(m state.Map) {
+ x.beforeSave()
+ var root *fileRefcountSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *fileRefcountSet) afterLoad() {}
+func (x *fileRefcountSet) load(m state.Map) {
+ m.LoadValue("root", new(*fileRefcountSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*fileRefcountSegmentDataSlices)) })
+}
+
+func (x *fileRefcountnode) beforeSave() {}
+func (x *fileRefcountnode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *fileRefcountnode) afterLoad() {}
+func (x *fileRefcountnode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *fileRefcountSegmentDataSlices) beforeSave() {}
+func (x *fileRefcountSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *fileRefcountSegmentDataSlices) afterLoad() {}
+func (x *fileRefcountSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func (x *ioList) beforeSave() {}
+func (x *ioList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *ioList) afterLoad() {}
+func (x *ioList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *ioEntry) beforeSave() {}
+func (x *ioEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *ioEntry) afterLoad() {}
+func (x *ioEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *MemoryManager) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.active) { m.Failf("active is %v, expected zero", x.active) }
+ if !state.IsZeroValue(x.captureInvalidations) { m.Failf("captureInvalidations is %v, expected zero", x.captureInvalidations) }
+ m.Save("p", &x.p)
+ m.Save("mfp", &x.mfp)
+ m.Save("layout", &x.layout)
+ m.Save("privateRefs", &x.privateRefs)
+ m.Save("users", &x.users)
+ m.Save("vmas", &x.vmas)
+ m.Save("brk", &x.brk)
+ m.Save("usageAS", &x.usageAS)
+ m.Save("lockedAS", &x.lockedAS)
+ m.Save("dataAS", &x.dataAS)
+ m.Save("defMLockMode", &x.defMLockMode)
+ m.Save("pmas", &x.pmas)
+ m.Save("curRSS", &x.curRSS)
+ m.Save("maxRSS", &x.maxRSS)
+ m.Save("argv", &x.argv)
+ m.Save("envv", &x.envv)
+ m.Save("auxv", &x.auxv)
+ m.Save("executable", &x.executable)
+ m.Save("aioManager", &x.aioManager)
+}
+
+func (x *MemoryManager) load(m state.Map) {
+ m.Load("p", &x.p)
+ m.Load("mfp", &x.mfp)
+ m.Load("layout", &x.layout)
+ m.Load("privateRefs", &x.privateRefs)
+ m.Load("users", &x.users)
+ m.Load("vmas", &x.vmas)
+ m.Load("brk", &x.brk)
+ m.Load("usageAS", &x.usageAS)
+ m.Load("lockedAS", &x.lockedAS)
+ m.Load("dataAS", &x.dataAS)
+ m.Load("defMLockMode", &x.defMLockMode)
+ m.Load("pmas", &x.pmas)
+ m.Load("curRSS", &x.curRSS)
+ m.Load("maxRSS", &x.maxRSS)
+ m.Load("argv", &x.argv)
+ m.Load("envv", &x.envv)
+ m.Load("auxv", &x.auxv)
+ m.Load("executable", &x.executable)
+ m.Load("aioManager", &x.aioManager)
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *vma) beforeSave() {}
+func (x *vma) save(m state.Map) {
+ x.beforeSave()
+ var realPerms int = x.saveRealPerms()
+ m.SaveValue("realPerms", realPerms)
+ m.Save("mappable", &x.mappable)
+ m.Save("off", &x.off)
+ m.Save("mlockMode", &x.mlockMode)
+ m.Save("id", &x.id)
+ m.Save("hint", &x.hint)
+}
+
+func (x *vma) afterLoad() {}
+func (x *vma) load(m state.Map) {
+ m.Load("mappable", &x.mappable)
+ m.Load("off", &x.off)
+ m.Load("mlockMode", &x.mlockMode)
+ m.Load("id", &x.id)
+ m.Load("hint", &x.hint)
+ m.LoadValue("realPerms", new(int), func(y interface{}) { x.loadRealPerms(y.(int)) })
+}
+
+func (x *pma) beforeSave() {}
+func (x *pma) save(m state.Map) {
+ x.beforeSave()
+ m.Save("off", &x.off)
+ m.Save("translatePerms", &x.translatePerms)
+ m.Save("effectivePerms", &x.effectivePerms)
+ m.Save("maxPerms", &x.maxPerms)
+ m.Save("needCOW", &x.needCOW)
+ m.Save("private", &x.private)
+}
+
+func (x *pma) afterLoad() {}
+func (x *pma) load(m state.Map) {
+ m.Load("off", &x.off)
+ m.Load("translatePerms", &x.translatePerms)
+ m.Load("effectivePerms", &x.effectivePerms)
+ m.Load("maxPerms", &x.maxPerms)
+ m.Load("needCOW", &x.needCOW)
+ m.Load("private", &x.private)
+}
+
+func (x *privateRefs) beforeSave() {}
+func (x *privateRefs) save(m state.Map) {
+ x.beforeSave()
+ m.Save("refs", &x.refs)
+}
+
+func (x *privateRefs) afterLoad() {}
+func (x *privateRefs) load(m state.Map) {
+ m.Load("refs", &x.refs)
+}
+
+func (x *pmaSet) beforeSave() {}
+func (x *pmaSet) save(m state.Map) {
+ x.beforeSave()
+ var root *pmaSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *pmaSet) afterLoad() {}
+func (x *pmaSet) load(m state.Map) {
+ m.LoadValue("root", new(*pmaSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*pmaSegmentDataSlices)) })
+}
+
+func (x *pmanode) beforeSave() {}
+func (x *pmanode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *pmanode) afterLoad() {}
+func (x *pmanode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *pmaSegmentDataSlices) beforeSave() {}
+func (x *pmaSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *pmaSegmentDataSlices) afterLoad() {}
+func (x *pmaSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func (x *SpecialMappable) beforeSave() {}
+func (x *SpecialMappable) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("mfp", &x.mfp)
+ m.Save("fr", &x.fr)
+ m.Save("name", &x.name)
+}
+
+func (x *SpecialMappable) afterLoad() {}
+func (x *SpecialMappable) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("mfp", &x.mfp)
+ m.Load("fr", &x.fr)
+ m.Load("name", &x.name)
+}
+
+func (x *vmaSet) beforeSave() {}
+func (x *vmaSet) save(m state.Map) {
+ x.beforeSave()
+ var root *vmaSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *vmaSet) afterLoad() {}
+func (x *vmaSet) load(m state.Map) {
+ m.LoadValue("root", new(*vmaSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*vmaSegmentDataSlices)) })
+}
+
+func (x *vmanode) beforeSave() {}
+func (x *vmanode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *vmanode) afterLoad() {}
+func (x *vmanode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *vmaSegmentDataSlices) beforeSave() {}
+func (x *vmaSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *vmaSegmentDataSlices) afterLoad() {}
+func (x *vmaSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func init() {
+ state.Register("mm.aioManager", (*aioManager)(nil), state.Fns{Save: (*aioManager).save, Load: (*aioManager).load})
+ state.Register("mm.ioResult", (*ioResult)(nil), state.Fns{Save: (*ioResult).save, Load: (*ioResult).load})
+ state.Register("mm.AIOContext", (*AIOContext)(nil), state.Fns{Save: (*AIOContext).save, Load: (*AIOContext).load})
+ state.Register("mm.aioMappable", (*aioMappable)(nil), state.Fns{Save: (*aioMappable).save, Load: (*aioMappable).load})
+ state.Register("mm.fileRefcountSet", (*fileRefcountSet)(nil), state.Fns{Save: (*fileRefcountSet).save, Load: (*fileRefcountSet).load})
+ state.Register("mm.fileRefcountnode", (*fileRefcountnode)(nil), state.Fns{Save: (*fileRefcountnode).save, Load: (*fileRefcountnode).load})
+ state.Register("mm.fileRefcountSegmentDataSlices", (*fileRefcountSegmentDataSlices)(nil), state.Fns{Save: (*fileRefcountSegmentDataSlices).save, Load: (*fileRefcountSegmentDataSlices).load})
+ state.Register("mm.ioList", (*ioList)(nil), state.Fns{Save: (*ioList).save, Load: (*ioList).load})
+ state.Register("mm.ioEntry", (*ioEntry)(nil), state.Fns{Save: (*ioEntry).save, Load: (*ioEntry).load})
+ state.Register("mm.MemoryManager", (*MemoryManager)(nil), state.Fns{Save: (*MemoryManager).save, Load: (*MemoryManager).load})
+ state.Register("mm.vma", (*vma)(nil), state.Fns{Save: (*vma).save, Load: (*vma).load})
+ state.Register("mm.pma", (*pma)(nil), state.Fns{Save: (*pma).save, Load: (*pma).load})
+ state.Register("mm.privateRefs", (*privateRefs)(nil), state.Fns{Save: (*privateRefs).save, Load: (*privateRefs).load})
+ state.Register("mm.pmaSet", (*pmaSet)(nil), state.Fns{Save: (*pmaSet).save, Load: (*pmaSet).load})
+ state.Register("mm.pmanode", (*pmanode)(nil), state.Fns{Save: (*pmanode).save, Load: (*pmanode).load})
+ state.Register("mm.pmaSegmentDataSlices", (*pmaSegmentDataSlices)(nil), state.Fns{Save: (*pmaSegmentDataSlices).save, Load: (*pmaSegmentDataSlices).load})
+ state.Register("mm.SpecialMappable", (*SpecialMappable)(nil), state.Fns{Save: (*SpecialMappable).save, Load: (*SpecialMappable).load})
+ state.Register("mm.vmaSet", (*vmaSet)(nil), state.Fns{Save: (*vmaSet).save, Load: (*vmaSet).load})
+ state.Register("mm.vmanode", (*vmanode)(nil), state.Fns{Save: (*vmanode).save, Load: (*vmanode).load})
+ state.Register("mm.vmaSegmentDataSlices", (*vmaSegmentDataSlices)(nil), state.Fns{Save: (*vmaSegmentDataSlices).save, Load: (*vmaSegmentDataSlices).load})
+}
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
new file mode 100644
index 000000000..ece561ff0
--- /dev/null
+++ b/pkg/sentry/mm/pma.go
@@ -0,0 +1,1036 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// existingPMAsLocked checks that pmas exist for all addresses in ar, and
+// support access of type (at, ignorePermissions). If so, it returns an
+// iterator to the pma containing ar.Start. Otherwise it returns a terminal
+// iterator.
+//
+// Preconditions: mm.activeMu must be locked. ar.Length() != 0.
+func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ first := mm.pmas.FindSegment(ar.Start)
+ pseg := first
+ for pseg.Ok() {
+ pma := pseg.ValuePtr()
+ perms := pma.effectivePerms
+ if ignorePermissions {
+ perms = pma.maxPerms
+ }
+ if !perms.SupersetOf(at) {
+ return pmaIterator{}
+ }
+ if needInternalMappings && pma.internalMappings.IsEmpty() {
+ return pmaIterator{}
+ }
+
+ if ar.End <= pseg.End() {
+ return first
+ }
+ pseg, _ = pseg.NextNonEmpty()
+ }
+
+ // Ran out of pmas before reaching ar.End.
+ return pmaIterator{}
+}
+
+// existingVecPMAsLocked returns true if pmas exist for all addresses in ars,
+// and support access of type (at, ignorePermissions).
+//
+// Preconditions: mm.activeMu must be locked.
+func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) bool {
+ for ; !ars.IsEmpty(); ars = ars.Tail() {
+ if ar := ars.Head(); ar.Length() != 0 && !mm.existingPMAsLocked(ar, at, ignorePermissions, needInternalMappings).Ok() {
+ return false
+ }
+ }
+ return true
+}
+
+// getPMAsLocked ensures that pmas exist for all addresses in ar, and support
+// access of type at. It returns:
+//
+// - An iterator to the pma containing ar.Start. If no pma contains ar.Start,
+// the iterator is unspecified.
+//
+// - An iterator to the gap after the last pma containing an address in ar. If
+// pmas exist for no addresses in ar, the iterator is to a gap that begins
+// before ar.Start.
+//
+// - An error that is non-nil if pmas exist for only a subset of ar.
+//
+// Preconditions: mm.mappingMu must be locked. mm.activeMu must be locked for
+// writing. ar.Length() != 0. vseg.Range().Contains(ar.Start). vmas must exist
+// for all addresses in ar, and support accesses of type at (i.e. permission
+// checks must have been performed against vmas).
+func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if !vseg.Range().Contains(ar.Start) {
+ panic(fmt.Sprintf("initial vma %v does not cover start of ar %v", vseg.Range(), ar))
+ }
+ }
+
+ // Page-align ar so that all AddrRanges are aligned.
+ end, ok := ar.End.RoundUp()
+ var alignerr error
+ if !ok {
+ end = ar.End.RoundDown()
+ alignerr = syserror.EFAULT
+ }
+ ar = usermem.AddrRange{ar.Start.RoundDown(), end}
+
+ pstart, pend, perr := mm.getPMAsInternalLocked(ctx, vseg, ar, at)
+ if pend.Start() <= ar.Start {
+ return pmaIterator{}, pend, perr
+ }
+ // getPMAsInternalLocked may not have returned pstart due to iterator
+ // invalidation.
+ if !pstart.Ok() {
+ pstart = mm.findOrSeekPrevUpperBoundPMA(ar.Start, pend)
+ }
+ if perr != nil {
+ return pstart, pend, perr
+ }
+ return pstart, pend, alignerr
+}
+
+// getVecPMAsLocked ensures that pmas exist for all addresses in ars, and
+// support access of type at. It returns the subset of ars for which pmas
+// exist. If this is not equal to ars, it returns a non-nil error explaining
+// why.
+//
+// Preconditions: mm.mappingMu must be locked. mm.activeMu must be locked for
+// writing. vmas must exist for all addresses in ars, and support accesses of
+// type at (i.e. permission checks must have been performed against vmas).
+func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType) (usermem.AddrRangeSeq, error) {
+ for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() {
+ ar := arsit.Head()
+ if ar.Length() == 0 {
+ continue
+ }
+ if checkInvariants {
+ if !ar.WellFormed() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ // Page-align ar so that all AddrRanges are aligned.
+ end, ok := ar.End.RoundUp()
+ var alignerr error
+ if !ok {
+ end = ar.End.RoundDown()
+ alignerr = syserror.EFAULT
+ }
+ ar = usermem.AddrRange{ar.Start.RoundDown(), end}
+
+ _, pend, perr := mm.getPMAsInternalLocked(ctx, mm.vmas.FindSegment(ar.Start), ar, at)
+ if perr != nil {
+ return truncatedAddrRangeSeq(ars, arsit, pend.Start()), perr
+ }
+ if alignerr != nil {
+ return truncatedAddrRangeSeq(ars, arsit, pend.Start()), alignerr
+ }
+ }
+
+ return ars, nil
+}
+
+// getPMAsInternalLocked is equivalent to getPMAsLocked, with the following
+// exceptions:
+//
+// - getPMAsInternalLocked returns a pmaIterator on a best-effort basis (that
+// is, the returned iterator may be terminal, even if a pma that contains
+// ar.Start exists). Returning this iterator on a best-effort basis allows
+// callers that require it to use it when it's cheaply available, while also
+// avoiding the overhead of retrieving it when it's not.
+//
+// - getPMAsInternalLocked additionally requires that ar is page-aligned.
+//
+// getPMAsInternalLocked is an implementation helper for getPMAsLocked and
+// getVecPMAsLocked; other clients should call one of those instead.
+func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if !vseg.Range().Contains(ar.Start) {
+ panic(fmt.Sprintf("initial vma %v does not cover start of ar %v", vseg.Range(), ar))
+ }
+ }
+
+ mf := mm.mfp.MemoryFile()
+ // Limit the range we allocate to ar, aligned to privateAllocUnit.
+ maskAR := privateAligned(ar)
+ didUnmapAS := false
+ // The range in which we iterate vmas and pmas is still limited to ar, to
+ // ensure that we don't allocate or COW-break a pma we don't need.
+ pseg, pgap := mm.pmas.Find(ar.Start)
+ pstart := pseg
+ for {
+ // Get pmas for this vma.
+ vsegAR := vseg.Range().Intersect(ar)
+ vma := vseg.ValuePtr()
+ pmaLoop:
+ for {
+ switch {
+ case pgap.Ok() && pgap.Start() < vsegAR.End:
+ // Need a pma here.
+ optAR := vseg.Range().Intersect(pgap.Range())
+ if checkInvariants {
+ if optAR.Length() <= 0 {
+ panic(fmt.Sprintf("vseg %v and pgap %v do not overlap", vseg, pgap))
+ }
+ }
+ if vma.mappable == nil {
+ // Private anonymous mappings get pmas by allocating.
+ allocAR := optAR.Intersect(maskAR)
+ fr, err := mf.Allocate(uint64(allocAR.Length()), usage.Anonymous)
+ if err != nil {
+ return pstart, pgap, err
+ }
+ if checkInvariants {
+ if !fr.WellFormed() || fr.Length() != uint64(allocAR.Length()) {
+ panic(fmt.Sprintf("Allocate(%v) returned invalid FileRange %v", allocAR.Length(), fr))
+ }
+ }
+ mm.addRSSLocked(allocAR)
+ mm.incPrivateRef(fr)
+ mf.IncRef(fr)
+ pseg, pgap = mm.pmas.Insert(pgap, allocAR, pma{
+ file: mf,
+ off: fr.Start,
+ translatePerms: usermem.AnyAccess,
+ effectivePerms: vma.effectivePerms,
+ maxPerms: vma.maxPerms,
+ // Since we just allocated this memory and have the
+ // only reference, the new pma does not need
+ // copy-on-write.
+ private: true,
+ }).NextNonEmpty()
+ pstart = pmaIterator{} // iterators invalidated
+ } else {
+ // Other mappings get pmas by translating.
+ optMR := vseg.mappableRangeOf(optAR)
+ reqAR := optAR.Intersect(ar)
+ reqMR := vseg.mappableRangeOf(reqAR)
+ perms := at
+ if vma.private {
+ // This pma will be copy-on-write; don't require write
+ // permission, but do require read permission to
+ // facilitate the copy.
+ //
+ // If at.Write is true, we will need to break
+ // copy-on-write immediately, which occurs after
+ // translation below.
+ perms.Read = true
+ perms.Write = false
+ }
+ ts, err := vma.mappable.Translate(ctx, reqMR, optMR, perms)
+ if checkInvariants {
+ if err := memmap.CheckTranslateResult(reqMR, optMR, perms, ts, err); err != nil {
+ panic(fmt.Sprintf("Mappable(%T).Translate(%v, %v, %v): %v", vma.mappable, reqMR, optMR, perms, err))
+ }
+ }
+ // Install a pma for each translation.
+ if len(ts) == 0 {
+ return pstart, pgap, err
+ }
+ pstart = pmaIterator{} // iterators invalidated
+ for _, t := range ts {
+ newpmaAR := vseg.addrRangeOf(t.Source)
+ newpma := pma{
+ file: t.File,
+ off: t.Offset,
+ translatePerms: t.Perms,
+ effectivePerms: vma.effectivePerms.Intersect(t.Perms),
+ maxPerms: vma.maxPerms.Intersect(t.Perms),
+ }
+ if vma.private {
+ newpma.effectivePerms.Write = false
+ newpma.maxPerms.Write = false
+ newpma.needCOW = true
+ }
+ mm.addRSSLocked(newpmaAR)
+ t.File.IncRef(t.FileRange())
+ // This is valid because memmap.Mappable.Translate is
+ // required to return Translations in increasing
+ // Translation.Source order.
+ pseg = mm.pmas.Insert(pgap, newpmaAR, newpma)
+ pgap = pseg.NextGap()
+ }
+ // The error returned by Translate is only significant if
+ // it occurred before ar.End.
+ if err != nil && vseg.addrRangeOf(ts[len(ts)-1].Source).End < ar.End {
+ return pstart, pgap, err
+ }
+ // Rewind pseg to the first pma inserted and continue the
+ // loop to check if we need to break copy-on-write.
+ pseg, pgap = mm.findOrSeekPrevUpperBoundPMA(vseg.addrRangeOf(ts[0].Source).Start, pgap), pmaGapIterator{}
+ continue
+ }
+
+ case pseg.Ok() && pseg.Start() < vsegAR.End:
+ oldpma := pseg.ValuePtr()
+ if at.Write && mm.isPMACopyOnWriteLocked(vseg, pseg) {
+ // Break copy-on-write by copying.
+ if checkInvariants {
+ if !oldpma.maxPerms.Read {
+ panic(fmt.Sprintf("pma %v needs to be copied for writing, but is not readable: %v", pseg.Range(), oldpma))
+ }
+ }
+ // The majority of copy-on-write breaks on executable pages
+ // come from:
+ //
+ // - The ELF loader, which must zero out bytes on the last
+ // page of each segment after the end of the segment.
+ //
+ // - gdb's use of ptrace to insert breakpoints.
+ //
+ // Neither of these cases has enough spatial locality to
+ // benefit from copying nearby pages, so if the vma is
+ // executable, only copy the pages required.
+ var copyAR usermem.AddrRange
+ if vseg.ValuePtr().effectivePerms.Execute {
+ copyAR = pseg.Range().Intersect(ar)
+ } else {
+ copyAR = pseg.Range().Intersect(maskAR)
+ }
+ // Get internal mappings from the pma to copy from.
+ if err := pseg.getInternalMappingsLocked(); err != nil {
+ return pstart, pseg.PrevGap(), err
+ }
+ // Copy contents.
+ fr, err := mf.AllocateAndFill(uint64(copyAR.Length()), usage.Anonymous, &safemem.BlockSeqReader{mm.internalMappingsLocked(pseg, copyAR)})
+ if _, ok := err.(safecopy.BusError); ok {
+ // If we got SIGBUS during the copy, deliver SIGBUS to
+ // userspace (instead of SIGSEGV) if we're breaking
+ // copy-on-write due to application page fault.
+ err = &memmap.BusError{err}
+ }
+ if fr.Length() == 0 {
+ return pstart, pseg.PrevGap(), err
+ }
+ // Unmap all of maskAR, not just copyAR, to minimize host
+ // syscalls. AddressSpace mappings must be removed before
+ // mm.decPrivateRef().
+ if !didUnmapAS {
+ mm.unmapASLocked(maskAR)
+ didUnmapAS = true
+ }
+ // Replace the pma with a copy in the part of the address
+ // range where copying was successful. This doesn't change
+ // RSS.
+ copyAR.End = copyAR.Start + usermem.Addr(fr.Length())
+ if copyAR != pseg.Range() {
+ pseg = mm.pmas.Isolate(pseg, copyAR)
+ pstart = pmaIterator{} // iterators invalidated
+ }
+ oldpma = pseg.ValuePtr()
+ if oldpma.private {
+ mm.decPrivateRef(pseg.fileRange())
+ }
+ oldpma.file.DecRef(pseg.fileRange())
+ mm.incPrivateRef(fr)
+ mf.IncRef(fr)
+ oldpma.file = mf
+ oldpma.off = fr.Start
+ oldpma.translatePerms = usermem.AnyAccess
+ oldpma.effectivePerms = vma.effectivePerms
+ oldpma.maxPerms = vma.maxPerms
+ oldpma.needCOW = false
+ oldpma.private = true
+ oldpma.internalMappings = safemem.BlockSeq{}
+ // Try to merge the pma with its neighbors.
+ if prev := pseg.PrevSegment(); prev.Ok() {
+ if merged := mm.pmas.Merge(prev, pseg); merged.Ok() {
+ pseg = merged
+ pstart = pmaIterator{} // iterators invalidated
+ }
+ }
+ if next := pseg.NextSegment(); next.Ok() {
+ if merged := mm.pmas.Merge(pseg, next); merged.Ok() {
+ pseg = merged
+ pstart = pmaIterator{} // iterators invalidated
+ }
+ }
+ // The error returned by AllocateAndFill is only
+ // significant if it occurred before ar.End.
+ if err != nil && pseg.End() < ar.End {
+ return pstart, pseg.NextGap(), err
+ }
+ // Ensure pseg and pgap are correct for the next iteration
+ // of the loop.
+ pseg, pgap = pseg.NextNonEmpty()
+ } else if !oldpma.translatePerms.SupersetOf(at) {
+ // Get new pmas (with sufficient permissions) by calling
+ // memmap.Mappable.Translate again.
+ if checkInvariants {
+ if oldpma.private {
+ panic(fmt.Sprintf("private pma %v has non-maximal pma.translatePerms: %v", pseg.Range(), oldpma))
+ }
+ }
+ // Allow the entire pma to be replaced.
+ optAR := pseg.Range()
+ optMR := vseg.mappableRangeOf(optAR)
+ reqAR := optAR.Intersect(ar)
+ reqMR := vseg.mappableRangeOf(reqAR)
+ perms := oldpma.translatePerms.Union(at)
+ ts, err := vma.mappable.Translate(ctx, reqMR, optMR, perms)
+ if checkInvariants {
+ if err := memmap.CheckTranslateResult(reqMR, optMR, perms, ts, err); err != nil {
+ panic(fmt.Sprintf("Mappable(%T).Translate(%v, %v, %v): %v", vma.mappable, reqMR, optMR, perms, err))
+ }
+ }
+ // Remove the part of the existing pma covered by new
+ // Translations, then insert new pmas. This doesn't change
+ // RSS. Note that we don't need to call unmapASLocked: any
+ // existing AddressSpace mappings are still valid (though
+ // less permissive than the new pmas indicate) until
+ // Invalidate is called, and will be replaced by future
+ // calls to mapASLocked.
+ if len(ts) == 0 {
+ return pstart, pseg.PrevGap(), err
+ }
+ transMR := memmap.MappableRange{ts[0].Source.Start, ts[len(ts)-1].Source.End}
+ transAR := vseg.addrRangeOf(transMR)
+ pseg = mm.pmas.Isolate(pseg, transAR)
+ pseg.ValuePtr().file.DecRef(pseg.fileRange())
+ pgap = mm.pmas.Remove(pseg)
+ pstart = pmaIterator{} // iterators invalidated
+ for _, t := range ts {
+ newpmaAR := vseg.addrRangeOf(t.Source)
+ newpma := pma{
+ file: t.File,
+ off: t.Offset,
+ translatePerms: t.Perms,
+ effectivePerms: vma.effectivePerms.Intersect(t.Perms),
+ maxPerms: vma.maxPerms.Intersect(t.Perms),
+ }
+ if vma.private {
+ newpma.effectivePerms.Write = false
+ newpma.maxPerms.Write = false
+ newpma.needCOW = true
+ }
+ t.File.IncRef(t.FileRange())
+ pseg = mm.pmas.Insert(pgap, newpmaAR, newpma)
+ pgap = pseg.NextGap()
+ }
+ // The error returned by Translate is only significant if
+ // it occurred before ar.End.
+ if err != nil && pseg.End() < ar.End {
+ return pstart, pgap, err
+ }
+ // Ensure pseg and pgap are correct for the next iteration
+ // of the loop.
+ if pgap.Range().Length() == 0 {
+ pseg, pgap = pgap.NextSegment(), pmaGapIterator{}
+ } else {
+ pseg = pmaIterator{}
+ }
+ } else {
+ // We have a usable pma; continue.
+ pseg, pgap = pseg.NextNonEmpty()
+ }
+
+ default:
+ break pmaLoop
+ }
+ }
+ // Go to the next vma.
+ if ar.End <= vseg.End() {
+ if pgap.Ok() {
+ return pstart, pgap, nil
+ }
+ return pstart, pseg.PrevGap(), nil
+ }
+ vseg = vseg.NextSegment()
+ }
+}
+
+const (
+ // When memory is allocated for a private pma, align the allocated address
+ // range to a privateAllocUnit boundary when possible. Larger values of
+ // privateAllocUnit may reduce page faults by allowing fewer, larger pmas
+ // to be mapped, but may result in larger amounts of wasted memory in the
+ // presence of fragmentation. privateAllocUnit must be a power-of-2
+ // multiple of usermem.PageSize.
+ privateAllocUnit = usermem.HugePageSize
+
+ privateAllocMask = privateAllocUnit - 1
+)
+
+func privateAligned(ar usermem.AddrRange) usermem.AddrRange {
+ aligned := usermem.AddrRange{ar.Start &^ privateAllocMask, ar.End}
+ if end := (ar.End + privateAllocMask) &^ privateAllocMask; end >= ar.End {
+ aligned.End = end
+ }
+ if checkInvariants {
+ if !aligned.IsSupersetOf(ar) {
+ panic(fmt.Sprintf("aligned AddrRange %#v is not a superset of ar %#v", aligned, ar))
+ }
+ }
+ return aligned
+}
+
+// isPMACopyOnWriteLocked returns true if the contents of the pma represented
+// by pseg must be copied to a new private pma to be written to.
+//
+// If the pma is a copy-on-write private pma, and holds the only reference on
+// the memory it maps, isPMACopyOnWriteLocked will take ownership of the memory
+// and update the pma to indicate that it does not require copy-on-write.
+//
+// Preconditions: vseg.Range().IsSupersetOf(pseg.Range()). mm.mappingMu must be
+// locked. mm.activeMu must be locked for writing.
+func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterator) bool {
+ pma := pseg.ValuePtr()
+ if !pma.needCOW {
+ return false
+ }
+ if !pma.private {
+ return true
+ }
+ // If we have the only reference on private memory to be copied, just take
+ // ownership of it instead of copying. If we do hold the only reference,
+ // additional references can only be taken by mm.Fork(), which is excluded
+ // by mm.activeMu, so this isn't racy.
+ mm.privateRefs.mu.Lock()
+ defer mm.privateRefs.mu.Unlock()
+ fr := pseg.fileRange()
+ // This check relies on mm.privateRefs.refs being kept fully merged.
+ rseg := mm.privateRefs.refs.FindSegment(fr.Start)
+ if rseg.Ok() && rseg.Value() == 1 && fr.End <= rseg.End() {
+ pma.needCOW = false
+ // pma.private => pma.translatePerms == usermem.AnyAccess
+ vma := vseg.ValuePtr()
+ pma.effectivePerms = vma.effectivePerms
+ pma.maxPerms = vma.maxPerms
+ return false
+ }
+ return true
+}
+
+// Invalidate implements memmap.MappingSpace.Invalidate.
+func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ mm.activeMu.Lock()
+ defer mm.activeMu.Unlock()
+ if mm.captureInvalidations {
+ mm.capturedInvalidations = append(mm.capturedInvalidations, invalidateArgs{ar, opts})
+ return
+ }
+ mm.invalidateLocked(ar, opts.InvalidatePrivate, true)
+}
+
+// invalidateLocked removes pmas and AddressSpace mappings of those pmas for
+// addresses in ar.
+//
+// Preconditions: mm.activeMu must be locked for writing. ar.Length() != 0. ar
+// must be page-aligned.
+func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ var didUnmapAS bool
+ pseg := mm.pmas.LowerBoundSegment(ar.Start)
+ for pseg.Ok() && pseg.Start() < ar.End {
+ pma := pseg.ValuePtr()
+ if (invalidatePrivate && pma.private) || (invalidateShared && !pma.private) {
+ pseg = mm.pmas.Isolate(pseg, ar)
+ pma = pseg.ValuePtr()
+ if !didUnmapAS {
+ // Unmap all of ar, not just pseg.Range(), to minimize host
+ // syscalls. AddressSpace mappings must be removed before
+ // mm.decPrivateRef().
+ mm.unmapASLocked(ar)
+ didUnmapAS = true
+ }
+ if pma.private {
+ mm.decPrivateRef(pseg.fileRange())
+ }
+ mm.removeRSSLocked(pseg.Range())
+ pma.file.DecRef(pseg.fileRange())
+ pseg = mm.pmas.Remove(pseg).NextSegment()
+ } else {
+ pseg = pseg.NextSegment()
+ }
+ }
+}
+
+// Pin returns the platform.File ranges currently mapped by addresses in ar in
+// mm, acquiring a reference on the returned ranges which the caller must
+// release by calling Unpin. If not all addresses are mapped, Pin returns a
+// non-nil error. Note that Pin may return both a non-empty slice of
+// PinnedRanges and a non-nil error.
+//
+// Pin does not prevent mapped ranges from changing, making it unsuitable for
+// most I/O. It should only be used in contexts that would use get_user_pages()
+// in the Linux kernel.
+//
+// Preconditions: ar.Length() != 0. ar must be page-aligned.
+func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ // Ensure that we have usable vmas.
+ mm.mappingMu.RLock()
+ vseg, vend, verr := mm.getVMAsLocked(ctx, ar, at, ignorePermissions)
+ if vendaddr := vend.Start(); vendaddr < ar.End {
+ if vendaddr <= ar.Start {
+ mm.mappingMu.RUnlock()
+ return nil, verr
+ }
+ ar.End = vendaddr
+ }
+
+ // Ensure that we have usable pmas.
+ mm.activeMu.Lock()
+ pseg, pend, perr := mm.getPMAsLocked(ctx, vseg, ar, at)
+ mm.mappingMu.RUnlock()
+ if pendaddr := pend.Start(); pendaddr < ar.End {
+ if pendaddr <= ar.Start {
+ mm.activeMu.Unlock()
+ return nil, perr
+ }
+ ar.End = pendaddr
+ }
+
+ // Gather pmas.
+ var prs []PinnedRange
+ for pseg.Ok() && pseg.Start() < ar.End {
+ psar := pseg.Range().Intersect(ar)
+ f := pseg.ValuePtr().file
+ fr := pseg.fileRangeOf(psar)
+ f.IncRef(fr)
+ prs = append(prs, PinnedRange{
+ Source: psar,
+ File: f,
+ Offset: fr.Start,
+ })
+ pseg = pseg.NextSegment()
+ }
+ mm.activeMu.Unlock()
+
+ // Return the first error in order of progress through ar.
+ if perr != nil {
+ return prs, perr
+ }
+ return prs, verr
+}
+
+// PinnedRanges are returned by MemoryManager.Pin.
+type PinnedRange struct {
+ // Source is the corresponding range of addresses.
+ Source usermem.AddrRange
+
+ // File is the mapped file.
+ File platform.File
+
+ // Offset is the offset into File at which this PinnedRange begins.
+ Offset uint64
+}
+
+// FileRange returns the platform.File offsets mapped by pr.
+func (pr PinnedRange) FileRange() platform.FileRange {
+ return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
+}
+
+// Unpin releases the reference held by prs.
+func Unpin(prs []PinnedRange) {
+ for i := range prs {
+ prs[i].File.DecRef(prs[i].FileRange())
+ }
+}
+
+// movePMAsLocked moves all pmas in oldAR to newAR.
+//
+// Preconditions: mm.activeMu must be locked for writing. oldAR.Length() != 0.
+// oldAR.Length() <= newAR.Length(). !oldAR.Overlaps(newAR).
+// mm.pmas.IsEmptyRange(newAR). oldAR and newAR must be page-aligned.
+func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
+ if checkInvariants {
+ if !oldAR.WellFormed() || oldAR.Length() <= 0 || !oldAR.IsPageAligned() {
+ panic(fmt.Sprintf("invalid oldAR: %v", oldAR))
+ }
+ if !newAR.WellFormed() || newAR.Length() <= 0 || !newAR.IsPageAligned() {
+ panic(fmt.Sprintf("invalid newAR: %v", newAR))
+ }
+ if oldAR.Length() > newAR.Length() {
+ panic(fmt.Sprintf("old address range %v may contain pmas that will not fit in new address range %v", oldAR, newAR))
+ }
+ if oldAR.Overlaps(newAR) {
+ panic(fmt.Sprintf("old and new address ranges overlap: %v, %v", oldAR, newAR))
+ }
+ // mm.pmas.IsEmptyRange is checked by mm.pmas.Insert.
+ }
+
+ type movedPMA struct {
+ oldAR usermem.AddrRange
+ pma pma
+ }
+ var movedPMAs []movedPMA
+ pseg := mm.pmas.LowerBoundSegment(oldAR.Start)
+ for pseg.Ok() && pseg.Start() < oldAR.End {
+ pseg = mm.pmas.Isolate(pseg, oldAR)
+ movedPMAs = append(movedPMAs, movedPMA{
+ oldAR: pseg.Range(),
+ pma: pseg.Value(),
+ })
+ pseg = mm.pmas.Remove(pseg).NextSegment()
+ // No RSS change is needed since we're re-inserting the same pmas
+ // below.
+ }
+
+ off := newAR.Start - oldAR.Start
+ pgap := mm.pmas.FindGap(newAR.Start)
+ for i := range movedPMAs {
+ mpma := &movedPMAs[i]
+ pmaNewAR := usermem.AddrRange{mpma.oldAR.Start + off, mpma.oldAR.End + off}
+ pgap = mm.pmas.Insert(pgap, pmaNewAR, mpma.pma).NextGap()
+ }
+
+ mm.unmapASLocked(oldAR)
+}
+
+// getPMAInternalMappingsLocked ensures that pmas for all addresses in ar have
+// cached internal mappings. It returns:
+//
+// - An iterator to the gap after the last pma with internal mappings
+// containing an address in ar. If internal mappings exist for no addresses in
+// ar, the iterator is to a gap that begins before ar.Start.
+//
+// - An error that is non-nil if internal mappings exist for only a subset of
+// ar.
+//
+// Preconditions: mm.activeMu must be locked for writing.
+// pseg.Range().Contains(ar.Start). pmas must exist for all addresses in ar.
+// ar.Length() != 0.
+//
+// Postconditions: getPMAInternalMappingsLocked does not invalidate iterators
+// into mm.pmas.
+func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) (pmaGapIterator, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !pseg.Range().Contains(ar.Start) {
+ panic(fmt.Sprintf("initial pma %v does not cover start of ar %v", pseg.Range(), ar))
+ }
+ }
+
+ for {
+ if err := pseg.getInternalMappingsLocked(); err != nil {
+ return pseg.PrevGap(), err
+ }
+ if ar.End <= pseg.End() {
+ return pseg.NextGap(), nil
+ }
+ pseg, _ = pseg.NextNonEmpty()
+ }
+}
+
+// getVecPMAInternalMappingsLocked ensures that pmas for all addresses in ars
+// have cached internal mappings. It returns the subset of ars for which
+// internal mappings exist. If this is not equal to ars, it returns a non-nil
+// error explaining why.
+//
+// Preconditions: mm.activeMu must be locked for writing. pmas must exist for
+// all addresses in ar.
+//
+// Postconditions: getVecPMAInternalMappingsLocked does not invalidate iterators
+// into mm.pmas.
+func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSeq) (usermem.AddrRangeSeq, error) {
+ for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() {
+ ar := arsit.Head()
+ if ar.Length() == 0 {
+ continue
+ }
+ if pend, err := mm.getPMAInternalMappingsLocked(mm.pmas.FindSegment(ar.Start), ar); err != nil {
+ return truncatedAddrRangeSeq(ars, arsit, pend.Start()), err
+ }
+ }
+ return ars, nil
+}
+
+// internalMappingsLocked returns internal mappings for addresses in ar.
+//
+// Preconditions: mm.activeMu must be locked. Internal mappings must have been
+// previously established for all addresses in ar. ar.Length() != 0.
+// pseg.Range().Contains(ar.Start).
+func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !pseg.Range().Contains(ar.Start) {
+ panic(fmt.Sprintf("initial pma %v does not cover start of ar %v", pseg.Range(), ar))
+ }
+ }
+
+ if ar.End <= pseg.End() {
+ // Since only one pma is involved, we can use pma.internalMappings
+ // directly, avoiding a slice allocation.
+ offset := uint64(ar.Start - pseg.Start())
+ return pseg.ValuePtr().internalMappings.DropFirst64(offset).TakeFirst64(uint64(ar.Length()))
+ }
+
+ var ims []safemem.Block
+ for {
+ pr := pseg.Range().Intersect(ar)
+ for pims := pseg.ValuePtr().internalMappings.DropFirst64(uint64(pr.Start - pseg.Start())).TakeFirst64(uint64(pr.Length())); !pims.IsEmpty(); pims = pims.Tail() {
+ ims = append(ims, pims.Head())
+ }
+ if ar.End <= pseg.End() {
+ break
+ }
+ pseg = pseg.NextSegment()
+ }
+ return safemem.BlockSeqFromSlice(ims)
+}
+
+// vecInternalMappingsLocked returns internal mappings for addresses in ars.
+//
+// Preconditions: mm.activeMu must be locked. Internal mappings must have been
+// previously established for all addresses in ars.
+func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) safemem.BlockSeq {
+ var ims []safemem.Block
+ for ; !ars.IsEmpty(); ars = ars.Tail() {
+ ar := ars.Head()
+ if ar.Length() == 0 {
+ continue
+ }
+ for pims := mm.internalMappingsLocked(mm.pmas.FindSegment(ar.Start), ar); !pims.IsEmpty(); pims = pims.Tail() {
+ ims = append(ims, pims.Head())
+ }
+ }
+ return safemem.BlockSeqFromSlice(ims)
+}
+
+// incPrivateRef acquires a reference on private pages in fr.
+func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
+ mm.privateRefs.mu.Lock()
+ defer mm.privateRefs.mu.Unlock()
+ refSet := &mm.privateRefs.refs
+ seg, gap := refSet.Find(fr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < fr.End:
+ seg = refSet.Isolate(seg, fr)
+ seg.SetValue(seg.Value() + 1)
+ seg, gap = seg.NextNonEmpty()
+ case gap.Ok() && gap.Start() < fr.End:
+ seg, gap = refSet.InsertWithoutMerging(gap, gap.Range().Intersect(fr), 1).NextNonEmpty()
+ default:
+ refSet.MergeAdjacent(fr)
+ return
+ }
+ }
+}
+
+// decPrivateRef releases a reference on private pages in fr.
+func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) {
+ var freed []platform.FileRange
+
+ mm.privateRefs.mu.Lock()
+ refSet := &mm.privateRefs.refs
+ seg := refSet.LowerBoundSegment(fr.Start)
+ for seg.Ok() && seg.Start() < fr.End {
+ seg = refSet.Isolate(seg, fr)
+ if old := seg.Value(); old == 1 {
+ freed = append(freed, seg.Range())
+ seg = refSet.Remove(seg).NextSegment()
+ } else {
+ seg.SetValue(old - 1)
+ seg = seg.NextSegment()
+ }
+ }
+ refSet.MergeAdjacent(fr)
+ mm.privateRefs.mu.Unlock()
+
+ mf := mm.mfp.MemoryFile()
+ for _, fr := range freed {
+ mf.DecRef(fr)
+ }
+}
+
+// addRSSLocked updates the current and maximum resident set size of a
+// MemoryManager to reflect the insertion of a pma at ar.
+//
+// Preconditions: mm.activeMu must be locked for writing.
+func (mm *MemoryManager) addRSSLocked(ar usermem.AddrRange) {
+ mm.curRSS += uint64(ar.Length())
+ if mm.curRSS > mm.maxRSS {
+ mm.maxRSS = mm.curRSS
+ }
+}
+
+// removeRSSLocked updates the current resident set size of a MemoryManager to
+// reflect the removal of a pma at ar.
+//
+// Preconditions: mm.activeMu must be locked for writing.
+func (mm *MemoryManager) removeRSSLocked(ar usermem.AddrRange) {
+ mm.curRSS -= uint64(ar.Length())
+}
+
+// pmaSetFunctions implements segment.Functions for pmaSet.
+type pmaSetFunctions struct{}
+
+func (pmaSetFunctions) MinKey() usermem.Addr {
+ return 0
+}
+
+func (pmaSetFunctions) MaxKey() usermem.Addr {
+ return ^usermem.Addr(0)
+}
+
+func (pmaSetFunctions) ClearValue(pma *pma) {
+ pma.file = nil
+ pma.internalMappings = safemem.BlockSeq{}
+}
+
+func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRange, pma2 pma) (pma, bool) {
+ if pma1.file != pma2.file ||
+ pma1.off+uint64(ar1.Length()) != pma2.off ||
+ pma1.translatePerms != pma2.translatePerms ||
+ pma1.effectivePerms != pma2.effectivePerms ||
+ pma1.maxPerms != pma2.maxPerms ||
+ pma1.needCOW != pma2.needCOW ||
+ pma1.private != pma2.private {
+ return pma{}, false
+ }
+
+ // Discard internal mappings instead of trying to merge them, since merging
+ // them requires an allocation and getting them again from the
+ // platform.File might not.
+ pma1.internalMappings = safemem.BlockSeq{}
+ return pma1, true
+}
+
+func (pmaSetFunctions) Split(ar usermem.AddrRange, p pma, split usermem.Addr) (pma, pma) {
+ newlen1 := uint64(split - ar.Start)
+ p2 := p
+ p2.off += newlen1
+ if !p.internalMappings.IsEmpty() {
+ p.internalMappings = p.internalMappings.TakeFirst64(newlen1)
+ p2.internalMappings = p2.internalMappings.DropFirst64(newlen1)
+ }
+ return p, p2
+}
+
+// findOrSeekPrevUpperBoundPMA returns mm.pmas.UpperBoundSegment(addr), but may do
+// so by scanning linearly backward from pgap.
+//
+// Preconditions: mm.activeMu must be locked. addr <= pgap.Start().
+func (mm *MemoryManager) findOrSeekPrevUpperBoundPMA(addr usermem.Addr, pgap pmaGapIterator) pmaIterator {
+ if checkInvariants {
+ if !pgap.Ok() {
+ panic("terminal pma iterator")
+ }
+ if addr > pgap.Start() {
+ panic(fmt.Sprintf("can't seek backward to %#x from %#x", addr, pgap.Start()))
+ }
+ }
+ // Optimistically check if pgap.PrevSegment() is the PMA we're looking for,
+ // which is the case if findOrSeekPrevUpperBoundPMA is called to find the
+ // start of a range containing only a single PMA.
+ if pseg := pgap.PrevSegment(); pseg.Start() <= addr {
+ return pseg
+ }
+ return mm.pmas.UpperBoundSegment(addr)
+}
+
+// getInternalMappingsLocked ensures that pseg.ValuePtr().internalMappings is
+// non-empty.
+//
+// Preconditions: mm.activeMu must be locked for writing.
+func (pseg pmaIterator) getInternalMappingsLocked() error {
+ pma := pseg.ValuePtr()
+ if pma.internalMappings.IsEmpty() {
+ // This must use maxPerms (instead of perms) because some permission
+ // constraints are only visible to vmas; for example, mappings of
+ // read-only files have vma.maxPerms.Write unset, but this may not be
+ // visible to the memmap.Mappable.
+ perms := pma.maxPerms
+ // We will never execute application code through an internal mapping.
+ perms.Execute = false
+ ims, err := pma.file.MapInternal(pseg.fileRange(), perms)
+ if err != nil {
+ return err
+ }
+ pma.internalMappings = ims
+ }
+ return nil
+}
+
+func (pseg pmaIterator) fileRange() platform.FileRange {
+ return pseg.fileRangeOf(pseg.Range())
+}
+
+// Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0.
+func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
+ if checkInvariants {
+ if !pseg.Ok() {
+ panic("terminal pma iterator")
+ }
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !pseg.Range().IsSupersetOf(ar) {
+ panic(fmt.Sprintf("ar %v out of bounds %v", ar, pseg.Range()))
+ }
+ }
+
+ pma := pseg.ValuePtr()
+ pstart := pseg.Start()
+ return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
+}
diff --git a/pkg/sentry/mm/pma_set.go b/pkg/sentry/mm/pma_set.go
new file mode 100755
index 000000000..6380d8619
--- /dev/null
+++ b/pkg/sentry/mm/pma_set.go
@@ -0,0 +1,1274 @@
+package mm
+
+import (
+ __generics_imported0 "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ pmaminDegree = 8
+
+ pmamaxDegree = 2 * pmaminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type pmaSet struct {
+ root pmanode `state:".(*pmaSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *pmaSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *pmaSet) IsEmptyRange(r __generics_imported0.AddrRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *pmaSet) Span() __generics_imported0.Addr {
+ var sz __generics_imported0.Addr
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *pmaSet) SpanRange(r __generics_imported0.AddrRange) __generics_imported0.Addr {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz __generics_imported0.Addr
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *pmaSet) FirstSegment() pmaIterator {
+ if s.root.nrSegments == 0 {
+ return pmaIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *pmaSet) LastSegment() pmaIterator {
+ if s.root.nrSegments == 0 {
+ return pmaIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *pmaSet) FirstGap() pmaGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return pmaGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *pmaSet) LastGap() pmaGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return pmaGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *pmaSet) Find(key __generics_imported0.Addr) (pmaIterator, pmaGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return pmaIterator{n, i}, pmaGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return pmaIterator{}, pmaGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *pmaSet) FindSegment(key __generics_imported0.Addr) pmaIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *pmaSet) LowerBoundSegment(min __generics_imported0.Addr) pmaIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *pmaSet) UpperBoundSegment(max __generics_imported0.Addr) pmaIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *pmaSet) FindGap(key __generics_imported0.Addr) pmaGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *pmaSet) LowerBoundGap(min __generics_imported0.Addr) pmaGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *pmaSet) UpperBoundGap(max __generics_imported0.Addr) pmaGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *pmaSet) Add(r __generics_imported0.AddrRange, val pma) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *pmaSet) AddWithoutMerging(r __generics_imported0.AddrRange, val pma) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *pmaSet) Insert(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (pmaSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *pmaSet) InsertWithoutMerging(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *pmaSet) InsertWithoutMergingUnchecked(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return pmaIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *pmaSet) Remove(seg pmaIterator) pmaGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ pmaSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(pmaGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *pmaSet) RemoveAll() {
+ s.root = pmanode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *pmaSet) RemoveRange(r __generics_imported0.AddrRange) pmaGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *pmaSet) Merge(first, second pmaIterator) pmaIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *pmaSet) MergeUnchecked(first, second pmaIterator) pmaIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (pmaSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return pmaIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *pmaSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *pmaSet) MergeRange(r __generics_imported0.AddrRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *pmaSet) MergeAdjacent(r __generics_imported0.AddrRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *pmaSet) Split(seg pmaIterator, split __generics_imported0.Addr) (pmaIterator, pmaIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *pmaSet) SplitUnchecked(seg pmaIterator, split __generics_imported0.Addr) (pmaIterator, pmaIterator) {
+ val1, val2 := (pmaSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.AddrRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *pmaSet) SplitAt(split __generics_imported0.Addr) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *pmaSet) Isolate(seg pmaIterator, r __generics_imported0.AddrRange) pmaIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *pmaSet) ApplyContiguous(r __generics_imported0.AddrRange, fn func(seg pmaIterator)) pmaGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return pmaGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return pmaGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type pmanode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *pmanode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [pmamaxDegree - 1]__generics_imported0.AddrRange
+ values [pmamaxDegree - 1]pma
+ children [pmamaxDegree]*pmanode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *pmanode) firstSegment() pmaIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return pmaIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *pmanode) lastSegment() pmaIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return pmaIterator{n, n.nrSegments - 1}
+}
+
+func (n *pmanode) prevSibling() *pmanode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *pmanode) nextSibling() *pmanode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *pmanode) rebalanceBeforeInsert(gap pmaGapIterator) pmaGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < pmamaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &pmanode{
+ nrSegments: pmaminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &pmanode{
+ nrSegments: pmaminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:pmaminDegree-1], n.keys[:pmaminDegree-1])
+ copy(left.values[:pmaminDegree-1], n.values[:pmaminDegree-1])
+ copy(right.keys[:pmaminDegree-1], n.keys[pmaminDegree:])
+ copy(right.values[:pmaminDegree-1], n.values[pmaminDegree:])
+ n.keys[0], n.values[0] = n.keys[pmaminDegree-1], n.values[pmaminDegree-1]
+ pmazeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:pmaminDegree], n.children[:pmaminDegree])
+ copy(right.children[:pmaminDegree], n.children[pmaminDegree:])
+ pmazeroNodeSlice(n.children[2:])
+ for i := 0; i < pmaminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < pmaminDegree {
+ return pmaGapIterator{left, gap.index}
+ }
+ return pmaGapIterator{right, gap.index - pmaminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[pmaminDegree-1], n.values[pmaminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &pmanode{
+ nrSegments: pmaminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:pmaminDegree-1], n.keys[pmaminDegree:])
+ copy(sibling.values[:pmaminDegree-1], n.values[pmaminDegree:])
+ pmazeroValueSlice(n.values[pmaminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:pmaminDegree], n.children[pmaminDegree:])
+ pmazeroNodeSlice(n.children[pmaminDegree:])
+ for i := 0; i < pmaminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = pmaminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < pmaminDegree {
+ return gap
+ }
+ return pmaGapIterator{sibling, gap.index - pmaminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *pmanode) rebalanceAfterRemove(gap pmaGapIterator) pmaGapIterator {
+ for {
+ if n.nrSegments >= pmaminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= pmaminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ pmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return pmaGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return pmaGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= pmaminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ pmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return pmaGapIterator{n, n.nrSegments}
+ }
+ return pmaGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return pmaGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return pmaGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *pmanode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = pmaGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ pmaSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type pmaIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *pmanode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg pmaIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg pmaIterator) Range() __generics_imported0.AddrRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg pmaIterator) Start() __generics_imported0.Addr {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg pmaIterator) End() __generics_imported0.Addr {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg pmaIterator) SetRangeUnchecked(r __generics_imported0.AddrRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg pmaIterator) SetRange(r __generics_imported0.AddrRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg pmaIterator) SetStartUnchecked(start __generics_imported0.Addr) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg pmaIterator) SetStart(start __generics_imported0.Addr) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg pmaIterator) SetEndUnchecked(end __generics_imported0.Addr) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg pmaIterator) SetEnd(end __generics_imported0.Addr) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg pmaIterator) Value() pma {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg pmaIterator) ValuePtr() *pma {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg pmaIterator) SetValue(val pma) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg pmaIterator) PrevSegment() pmaIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return pmaIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return pmaIterator{}
+ }
+ return pmasegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg pmaIterator) NextSegment() pmaIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return pmaIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return pmaIterator{}
+ }
+ return pmasegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg pmaIterator) PrevGap() pmaGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return pmaGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg pmaIterator) NextGap() pmaGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return pmaGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg pmaIterator) PrevNonEmpty() (pmaIterator, pmaGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return pmaIterator{}, gap
+ }
+ return gap.PrevSegment(), pmaGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg pmaIterator) NextNonEmpty() (pmaIterator, pmaGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return pmaIterator{}, gap
+ }
+ return gap.NextSegment(), pmaGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type pmaGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *pmanode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap pmaGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap pmaGapIterator) Range() __generics_imported0.AddrRange {
+ return __generics_imported0.AddrRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap pmaGapIterator) Start() __generics_imported0.Addr {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return pmaSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap pmaGapIterator) End() __generics_imported0.Addr {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return pmaSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap pmaGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap pmaGapIterator) PrevSegment() pmaIterator {
+ return pmasegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap pmaGapIterator) NextSegment() pmaIterator {
+ return pmasegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap pmaGapIterator) PrevGap() pmaGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return pmaGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap pmaGapIterator) NextGap() pmaGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return pmaGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func pmasegmentBeforePosition(n *pmanode, i int) pmaIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return pmaIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return pmaIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func pmasegmentAfterPosition(n *pmanode, i int) pmaIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return pmaIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return pmaIterator{n, i}
+}
+
+func pmazeroValueSlice(slice []pma) {
+
+ for i := range slice {
+ pmaSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func pmazeroNodeSlice(slice []*pmanode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *pmaSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *pmanode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *pmanode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type pmaSegmentDataSlices struct {
+ Start []__generics_imported0.Addr
+ End []__generics_imported0.Addr
+ Values []pma
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *pmaSet) ExportSortedSlices() *pmaSegmentDataSlices {
+ var sds pmaSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *pmaSet) ImportSortedSlices(sds *pmaSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := __generics_imported0.AddrRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *pmaSet) saveRoot() *pmaSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *pmaSet) loadRoot(sds *pmaSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
new file mode 100644
index 000000000..c8302a553
--- /dev/null
+++ b/pkg/sentry/mm/procfs.go
@@ -0,0 +1,289 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+const (
+ // devMinorBits is the number of minor bits in a device number. Linux:
+ // include/linux/kdev_t.h:MINORBITS
+ devMinorBits = 20
+
+ vsyscallEnd = usermem.Addr(0xffffffffff601000)
+ vsyscallMapsEntry = "ffffffffff600000-ffffffffff601000 r-xp 00000000 00:00 0 [vsyscall]\n"
+ vsyscallSmapsEntry = vsyscallMapsEntry +
+ "Size: 4 kB\n" +
+ "Rss: 0 kB\n" +
+ "Pss: 0 kB\n" +
+ "Shared_Clean: 0 kB\n" +
+ "Shared_Dirty: 0 kB\n" +
+ "Private_Clean: 0 kB\n" +
+ "Private_Dirty: 0 kB\n" +
+ "Referenced: 0 kB\n" +
+ "Anonymous: 0 kB\n" +
+ "AnonHugePages: 0 kB\n" +
+ "Shared_Hugetlb: 0 kB\n" +
+ "Private_Hugetlb: 0 kB\n" +
+ "Swap: 0 kB\n" +
+ "SwapPss: 0 kB\n" +
+ "KernelPageSize: 4 kB\n" +
+ "MMUPageSize: 4 kB\n" +
+ "Locked: 0 kB\n" +
+ "VmFlags: rd ex \n"
+)
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (mm *MemoryManager) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadMapsSeqFileData is called by fs/proc.mapsData.ReadSeqFileData to
+// implement /proc/[pid]/maps.
+func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ var data []seqfile.SeqData
+ var start usermem.Addr
+ if handle != nil {
+ start = *handle.(*usermem.Addr)
+ }
+ for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
+ // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
+ // "panic: autosave error: type usermem.Addr is not registered".
+ vmaAddr := vseg.End()
+ data = append(data, seqfile.SeqData{
+ Buf: mm.vmaMapsEntryLocked(ctx, vseg),
+ Handle: &vmaAddr,
+ })
+ }
+
+ // We always emulate vsyscall, so advertise it here. Everything about a
+ // vsyscall region is static, so just hard code the maps entry since we
+ // don't have a real vma backing it. The vsyscall region is at the end of
+ // the virtual address space so nothing should be mapped after it (if
+ // something is really mapped in the tiny ~10 MiB segment afterwards, we'll
+ // get the sorting on the maps file wrong at worst; but that's not possible
+ // on any current platform).
+ //
+ // Artifically adjust the seqfile handle so we only output vsyscall entry once.
+ if start != vsyscallEnd {
+ // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
+ vmaAddr := vsyscallEnd
+ data = append(data, seqfile.SeqData{
+ Buf: []byte(vsyscallMapsEntry),
+ Handle: &vmaAddr,
+ })
+ }
+ return data, 1
+}
+
+// vmaMapsEntryLocked returns a /proc/[pid]/maps entry for the vma iterated by
+// vseg, including the trailing newline.
+//
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) vmaMapsEntryLocked(ctx context.Context, vseg vmaIterator) []byte {
+ var b bytes.Buffer
+ mm.appendVMAMapsEntryLocked(ctx, vseg, &b)
+ return b.Bytes()
+}
+
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaIterator, b *bytes.Buffer) {
+ vma := vseg.ValuePtr()
+ private := "p"
+ if !vma.private {
+ private = "s"
+ }
+
+ var dev, ino uint64
+ if vma.id != nil {
+ dev = vma.id.DeviceID()
+ ino = vma.id.InodeID()
+ }
+ devMajor := uint32(dev >> devMinorBits)
+ devMinor := uint32(dev & ((1 << devMinorBits) - 1))
+
+ // Do not include the guard page: fs/proc/task_mmu.c:show_map_vma() =>
+ // stack_guard_page_start().
+ fmt.Fprintf(b, "%08x-%08x %s%s %08x %02x:%02x %d ",
+ vseg.Start(), vseg.End(), vma.realPerms, private, vma.off, devMajor, devMinor, ino)
+
+ // Figure out our filename or hint.
+ var s string
+ if vma.hint != "" {
+ s = vma.hint
+ } else if vma.id != nil {
+ // FIXME(jamieliu): We are holding mm.mappingMu here, which is
+ // consistent with Linux's holding mmap_sem in
+ // fs/proc/task_mmu.c:show_map_vma() => fs/seq_file.c:seq_file_path().
+ // However, it's not clear that fs.File.MappedName() is actually
+ // consistent with this lock order.
+ s = vma.id.MappedName(ctx)
+ }
+ if s != "" {
+ // Per linux, we pad until the 74th character.
+ if pad := 73 - b.Len(); pad > 0 {
+ b.WriteString(strings.Repeat(" ", pad))
+ }
+ b.WriteString(s)
+ }
+ b.WriteString("\n")
+}
+
+// ReadSmapsSeqFileData is called by fs/proc.smapsData.ReadSeqFileData to
+// implement /proc/[pid]/smaps.
+func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ var data []seqfile.SeqData
+ var start usermem.Addr
+ if handle != nil {
+ start = *handle.(*usermem.Addr)
+ }
+ for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
+ // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
+ // "panic: autosave error: type usermem.Addr is not registered".
+ vmaAddr := vseg.End()
+ data = append(data, seqfile.SeqData{
+ Buf: mm.vmaSmapsEntryLocked(ctx, vseg),
+ Handle: &vmaAddr,
+ })
+ }
+
+ // We always emulate vsyscall, so advertise it here. See
+ // ReadMapsSeqFileData for additional commentary.
+ if start != vsyscallEnd {
+ // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
+ vmaAddr := vsyscallEnd
+ data = append(data, seqfile.SeqData{
+ Buf: []byte(vsyscallSmapsEntry),
+ Handle: &vmaAddr,
+ })
+ }
+ return data, 1
+}
+
+// vmaSmapsEntryLocked returns a /proc/[pid]/smaps entry for the vma iterated
+// by vseg, including the trailing newline.
+//
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) vmaSmapsEntryLocked(ctx context.Context, vseg vmaIterator) []byte {
+ var b bytes.Buffer
+ mm.appendVMAMapsEntryLocked(ctx, vseg, &b)
+ vma := vseg.ValuePtr()
+
+ // We take mm.activeMu here in each call to vmaSmapsEntryLocked, instead of
+ // requiring it to be locked as a precondition, to reduce the latency
+ // impact of reading /proc/[pid]/smaps on concurrent performance-sensitive
+ // operations requiring activeMu for writing like faults.
+ mm.activeMu.RLock()
+ var rss uint64
+ var anon uint64
+ vsegAR := vseg.Range()
+ for pseg := mm.pmas.LowerBoundSegment(vsegAR.Start); pseg.Ok() && pseg.Start() < vsegAR.End; pseg = pseg.NextSegment() {
+ psegAR := pseg.Range().Intersect(vsegAR)
+ size := uint64(psegAR.Length())
+ rss += size
+ if pseg.ValuePtr().private {
+ anon += size
+ }
+ }
+ mm.activeMu.RUnlock()
+
+ fmt.Fprintf(&b, "Size: %8d kB\n", vseg.Range().Length()/1024)
+ fmt.Fprintf(&b, "Rss: %8d kB\n", rss/1024)
+ // Currently we report PSS = RSS, i.e. we pretend each page mapped by a pma
+ // is only mapped by that pma. This avoids having to query memmap.Mappables
+ // for reference count information on each page. As a corollary, all pages
+ // are accounted as "private" whether or not the vma is private; compare
+ // Linux's fs/proc/task_mmu.c:smaps_account().
+ fmt.Fprintf(&b, "Pss: %8d kB\n", rss/1024)
+ fmt.Fprintf(&b, "Shared_Clean: %8d kB\n", 0)
+ fmt.Fprintf(&b, "Shared_Dirty: %8d kB\n", 0)
+ // Pretend that all pages are dirty if the vma is writable, and clean otherwise.
+ clean := rss
+ if vma.effectivePerms.Write {
+ clean = 0
+ }
+ fmt.Fprintf(&b, "Private_Clean: %8d kB\n", clean/1024)
+ fmt.Fprintf(&b, "Private_Dirty: %8d kB\n", (rss-clean)/1024)
+ // Pretend that all pages are "referenced" (recently touched).
+ fmt.Fprintf(&b, "Referenced: %8d kB\n", rss/1024)
+ fmt.Fprintf(&b, "Anonymous: %8d kB\n", anon/1024)
+ // Hugepages (hugetlb and THP) are not implemented.
+ fmt.Fprintf(&b, "AnonHugePages: %8d kB\n", 0)
+ fmt.Fprintf(&b, "Shared_Hugetlb: %8d kB\n", 0)
+ fmt.Fprintf(&b, "Private_Hugetlb: %7d kB\n", 0)
+ // Swap is not implemented.
+ fmt.Fprintf(&b, "Swap: %8d kB\n", 0)
+ fmt.Fprintf(&b, "SwapPss: %8d kB\n", 0)
+ fmt.Fprintf(&b, "KernelPageSize: %8d kB\n", usermem.PageSize/1024)
+ fmt.Fprintf(&b, "MMUPageSize: %8d kB\n", usermem.PageSize/1024)
+ locked := rss
+ if vma.mlockMode == memmap.MLockNone {
+ locked = 0
+ }
+ fmt.Fprintf(&b, "Locked: %8d kB\n", locked/1024)
+
+ b.WriteString("VmFlags: ")
+ if vma.realPerms.Read {
+ b.WriteString("rd ")
+ }
+ if vma.realPerms.Write {
+ b.WriteString("wr ")
+ }
+ if vma.realPerms.Execute {
+ b.WriteString("ex ")
+ }
+ if vma.canWriteMappableLocked() { // VM_SHARED
+ b.WriteString("sh ")
+ }
+ if vma.maxPerms.Read {
+ b.WriteString("mr ")
+ }
+ if vma.maxPerms.Write {
+ b.WriteString("mw ")
+ }
+ if vma.maxPerms.Execute {
+ b.WriteString("me ")
+ }
+ if !vma.private { // VM_MAYSHARE
+ b.WriteString("ms ")
+ }
+ if vma.growsDown {
+ b.WriteString("gd ")
+ }
+ if vma.mlockMode != memmap.MLockNone { // VM_LOCKED
+ b.WriteString("lo ")
+ }
+ if vma.mlockMode == memmap.MLockLazy { // VM_LOCKONFAULT
+ b.WriteString("?? ") // no explicit encoding in fs/proc/task_mmu.c:show_smap_vma_flags()
+ }
+ if vma.private && vma.effectivePerms.Write { // VM_ACCOUNT
+ b.WriteString("ac ")
+ }
+ b.WriteString("\n")
+
+ return b.Bytes()
+}
diff --git a/pkg/sentry/mm/save_restore.go b/pkg/sentry/mm/save_restore.go
new file mode 100644
index 000000000..0385957bd
--- /dev/null
+++ b/pkg/sentry/mm/save_restore.go
@@ -0,0 +1,57 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// InvalidateUnsavable invokes memmap.Mappable.InvalidateUnsavable on all
+// Mappables mapped by mm.
+func (mm *MemoryManager) InvalidateUnsavable(ctx context.Context) error {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() {
+ if vma := vseg.ValuePtr(); vma.mappable != nil {
+ if err := vma.mappable.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// beforeSave is invoked by stateify.
+func (mm *MemoryManager) beforeSave() {
+ mf := mm.mfp.MemoryFile()
+ for pseg := mm.pmas.FirstSegment(); pseg.Ok(); pseg = pseg.NextSegment() {
+ if pma := pseg.ValuePtr(); pma.file != mf {
+ // InvalidateUnsavable should have caused all such pmas to be
+ // invalidated.
+ panic(fmt.Sprintf("Can't save pma %#v with non-MemoryFile of type %T:\n%s", pseg.Range(), pma.file, mm))
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (mm *MemoryManager) afterLoad() {
+ mm.haveASIO = mm.p.SupportsAddressSpaceIO()
+ mf := mm.mfp.MemoryFile()
+ for pseg := mm.pmas.FirstSegment(); pseg.Ok(); pseg = pseg.NextSegment() {
+ pseg.ValuePtr().file = mf
+ }
+}
diff --git a/pkg/sentry/mm/shm.go b/pkg/sentry/mm/shm.go
new file mode 100644
index 000000000..12913007b
--- /dev/null
+++ b/pkg/sentry/mm/shm.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/shm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// DetachShm unmaps a sysv shared memory segment.
+func (mm *MemoryManager) DetachShm(ctx context.Context, addr usermem.Addr) error {
+ if addr != addr.RoundDown() {
+ // "... shmaddr is not aligned on a page boundary." - man shmdt(2)
+ return syserror.EINVAL
+ }
+
+ var detached *shm.Shm
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+
+ // Find and remove the first vma containing an address >= addr that maps a
+ // segment originally attached at addr.
+ vseg := mm.vmas.LowerBoundSegment(addr)
+ for vseg.Ok() {
+ vma := vseg.ValuePtr()
+ if shm, ok := vma.mappable.(*shm.Shm); ok && vseg.Start() >= addr && uint64(vseg.Start()-addr) == vma.off {
+ detached = shm
+ vseg = mm.unmapLocked(ctx, vseg.Range()).NextSegment()
+ break
+ } else {
+ vseg = vseg.NextSegment()
+ }
+ }
+
+ if detached == nil {
+ // There is no shared memory segment attached at addr.
+ return syserror.EINVAL
+ }
+
+ // Remove all vmas that could have been created by the same attach.
+ end := addr + usermem.Addr(detached.EffectiveSize())
+ for vseg.Ok() && vseg.End() <= end {
+ vma := vseg.ValuePtr()
+ if vma.mappable == detached && uint64(vseg.Start()-addr) == vma.off {
+ vseg = mm.unmapLocked(ctx, vseg.Range()).NextSegment()
+ } else {
+ vseg = vseg.NextSegment()
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
new file mode 100644
index 000000000..687959005
--- /dev/null
+++ b/pkg/sentry/mm/special_mappable.go
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// SpecialMappable implements memmap.MappingIdentity and memmap.Mappable with
+// semantics similar to Linux's mm/mmap.c:_install_special_mapping(), except
+// that SpecialMappable takes ownership of the memory that it represents
+// (_install_special_mapping() does not.)
+//
+// +stateify savable
+type SpecialMappable struct {
+ refs.AtomicRefCount
+
+ mfp pgalloc.MemoryFileProvider
+ fr platform.FileRange
+ name string
+}
+
+// NewSpecialMappable returns a SpecialMappable that owns fr, which represents
+// offsets in mfp.MemoryFile() that contain the SpecialMappable's data. The
+// SpecialMappable will use the given name in /proc/[pid]/maps.
+//
+// Preconditions: fr.Length() != 0.
+func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable {
+ return &SpecialMappable{mfp: mfp, fr: fr, name: name}
+}
+
+// DecRef implements refs.RefCounter.DecRef.
+func (m *SpecialMappable) DecRef() {
+ m.AtomicRefCount.DecRefWithDestructor(func() {
+ m.mfp.MemoryFile().DecRef(m.fr)
+ })
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (m *SpecialMappable) MappedName(ctx context.Context) string {
+ return m.name
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (m *SpecialMappable) DeviceID() uint64 {
+ return 0
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (m *SpecialMappable) InodeID() uint64 {
+ return 0
+}
+
+// Msync implements memmap.MappingIdentity.Msync.
+func (m *SpecialMappable) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ // Linux: vm_file is NULL, causing msync to skip it entirely.
+ return nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (*SpecialMappable) AddMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) error {
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (*SpecialMappable) RemoveMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) {
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (*SpecialMappable) CopyMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, usermem.AddrRange, uint64, bool) error {
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (m *SpecialMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ var err error
+ if required.End > m.fr.Length() {
+ err = &memmap.BusError{syserror.EFAULT}
+ }
+ if source := optional.Intersect(memmap.MappableRange{0, m.fr.Length()}); source.Length() != 0 {
+ return []memmap.Translation{
+ {
+ Source: source,
+ File: m.mfp.MemoryFile(),
+ Offset: m.fr.Start + source.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, err
+ }
+ return nil, err
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (m *SpecialMappable) InvalidateUnsavable(ctx context.Context) error {
+ // Since data is stored in pgalloc.MemoryFile, the contents of which are
+ // preserved across save/restore, we don't need to do anything.
+ return nil
+}
+
+// MemoryFileProvider returns the MemoryFileProvider whose MemoryFile stores
+// the SpecialMappable's contents.
+func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider {
+ return m.mfp
+}
+
+// FileRange returns the offsets into MemoryFileProvider().MemoryFile() that
+// store the SpecialMappable's contents.
+func (m *SpecialMappable) FileRange() platform.FileRange {
+ return m.fr
+}
+
+// Length returns the length of the SpecialMappable.
+func (m *SpecialMappable) Length() uint64 {
+ return m.fr.Length()
+}
+
+// NewSharedAnonMappable returns a SpecialMappable that implements the
+// semantics of mmap(MAP_SHARED|MAP_ANONYMOUS) and mappings of /dev/zero.
+//
+// TODO(jamieliu): The use of SpecialMappable is a lazy code reuse hack. Linux
+// uses an ephemeral file created by mm/shmem.c:shmem_zero_setup(); we should
+// do the same to get non-zero device and inode IDs.
+func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*SpecialMappable, error) {
+ if length == 0 {
+ return nil, syserror.EINVAL
+ }
+ alignedLen, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return nil, syserror.EINVAL
+ }
+ fr, err := mfp.MemoryFile().Allocate(uint64(alignedLen), usage.Anonymous)
+ if err != nil {
+ return nil, err
+ }
+ return NewSpecialMappable("/dev/zero (deleted)", mfp, fr), nil
+}
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
new file mode 100644
index 000000000..0368c6794
--- /dev/null
+++ b/pkg/sentry/mm/syscalls.go
@@ -0,0 +1,1197 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+ mrand "math/rand"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/futex"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// HandleUserFault handles an application page fault. sp is the faulting
+// application thread's stack pointer.
+//
+// Preconditions: mm.as != nil.
+func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr usermem.Addr, at usermem.AccessType, sp usermem.Addr) error {
+ ar, ok := addr.RoundDown().ToRange(usermem.PageSize)
+ if !ok {
+ return syserror.EFAULT
+ }
+
+ // Don't bother trying existingPMAsLocked; in most cases, if we did have
+ // existing pmas, we wouldn't have faulted.
+
+ // Ensure that we have a usable vma. Here and below, since we are only
+ // asking for a single page, there is no possibility of partial success,
+ // and any error is immediately fatal.
+ mm.mappingMu.RLock()
+ vseg, _, err := mm.getVMAsLocked(ctx, ar, at, false)
+ if err != nil {
+ mm.mappingMu.RUnlock()
+ return err
+ }
+
+ // Ensure that we have a usable pma.
+ mm.activeMu.Lock()
+ pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, at)
+ mm.mappingMu.RUnlock()
+ if err != nil {
+ mm.activeMu.Unlock()
+ return err
+ }
+
+ // Downgrade to a read-lock on activeMu since we don't need to mutate pmas
+ // anymore.
+ mm.activeMu.DowngradeLock()
+
+ // Map the faulted page into the active AddressSpace.
+ err = mm.mapASLocked(pseg, ar, false)
+ mm.activeMu.RUnlock()
+ return err
+}
+
+// MMap establishes a memory mapping.
+func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (usermem.Addr, error) {
+ if opts.Length == 0 {
+ return 0, syserror.EINVAL
+ }
+ length, ok := usermem.Addr(opts.Length).RoundUp()
+ if !ok {
+ return 0, syserror.ENOMEM
+ }
+ opts.Length = uint64(length)
+
+ if opts.Mappable != nil {
+ // Offset must be aligned.
+ if usermem.Addr(opts.Offset).RoundDown() != usermem.Addr(opts.Offset) {
+ return 0, syserror.EINVAL
+ }
+ // Offset + length must not overflow.
+ if end := opts.Offset + opts.Length; end < opts.Offset {
+ return 0, syserror.ENOMEM
+ }
+ } else {
+ opts.Offset = 0
+ if !opts.Private {
+ if opts.MappingIdentity != nil {
+ return 0, syserror.EINVAL
+ }
+ m, err := NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx))
+ if err != nil {
+ return 0, err
+ }
+ defer m.DecRef()
+ opts.MappingIdentity = m
+ opts.Mappable = m
+ }
+ }
+
+ if opts.Addr.RoundDown() != opts.Addr {
+ // MAP_FIXED requires addr to be page-aligned; non-fixed mappings
+ // don't.
+ if opts.Fixed {
+ return 0, syserror.EINVAL
+ }
+ opts.Addr = opts.Addr.RoundDown()
+ }
+
+ if !opts.MaxPerms.SupersetOf(opts.Perms) {
+ return 0, syserror.EACCES
+ }
+ if opts.Unmap && !opts.Fixed {
+ return 0, syserror.EINVAL
+ }
+ if opts.GrowsDown && opts.Mappable != nil {
+ return 0, syserror.EINVAL
+ }
+
+ // Get the new vma.
+ mm.mappingMu.Lock()
+ if opts.MLockMode < mm.defMLockMode {
+ opts.MLockMode = mm.defMLockMode
+ }
+ vseg, ar, err := mm.createVMALocked(ctx, opts)
+ if err != nil {
+ mm.mappingMu.Unlock()
+ return 0, err
+ }
+
+ // TODO(jamieliu): In Linux, VM_LOCKONFAULT (which may be set on the new
+ // vma by mlockall(MCL_FUTURE|MCL_ONFAULT) => mm_struct::def_flags) appears
+ // to effectively disable MAP_POPULATE by unsetting FOLL_POPULATE in
+ // mm/util.c:vm_mmap_pgoff() => mm/gup.c:__mm_populate() =>
+ // populate_vma_page_range(). Confirm this behavior.
+ switch {
+ case opts.Precommit || opts.MLockMode == memmap.MLockEager:
+ // Get pmas and map with precommit as requested.
+ mm.populateVMAAndUnlock(ctx, vseg, ar, true)
+
+ case opts.Mappable == nil && length <= privateAllocUnit:
+ // NOTE(b/63077076, b/63360184): Get pmas and map eagerly in the hope
+ // that doing so will save on future page faults. We only do this for
+ // anonymous mappings, since otherwise the cost of
+ // memmap.Mappable.Translate is unknown; and only for small mappings,
+ // to avoid needing to allocate large amounts of memory that we may
+ // subsequently need to checkpoint.
+ mm.populateVMAAndUnlock(ctx, vseg, ar, false)
+
+ default:
+ mm.mappingMu.Unlock()
+ }
+
+ return ar.Start, nil
+}
+
+// populateVMA obtains pmas for addresses in ar in the given vma, and maps them
+// into mm.as if it is active.
+//
+// Preconditions: mm.mappingMu must be locked. vseg.Range().IsSupersetOf(ar).
+func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) {
+ if !vseg.ValuePtr().effectivePerms.Any() {
+ // Linux doesn't populate inaccessible pages. See
+ // mm/gup.c:populate_vma_page_range.
+ return
+ }
+
+ mm.activeMu.Lock()
+ // Can't defer mm.activeMu.Unlock(); see below.
+
+ // Even if we get new pmas, we can't actually map them if we don't have an
+ // AddressSpace.
+ if mm.as == nil {
+ mm.activeMu.Unlock()
+ return
+ }
+
+ // Ensure that we have usable pmas.
+ pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, usermem.NoAccess)
+ if err != nil {
+ // mm/util.c:vm_mmap_pgoff() ignores the error, if any, from
+ // mm/gup.c:mm_populate(). If it matters, we'll get it again when
+ // userspace actually tries to use the failing page.
+ mm.activeMu.Unlock()
+ return
+ }
+
+ // Downgrade to a read-lock on activeMu since we don't need to mutate pmas
+ // anymore.
+ mm.activeMu.DowngradeLock()
+
+ // As above, errors are silently ignored.
+ mm.mapASLocked(pseg, ar, precommit)
+ mm.activeMu.RUnlock()
+}
+
+// populateVMAAndUnlock is equivalent to populateVMA, but also unconditionally
+// unlocks mm.mappingMu. In cases where populateVMAAndUnlock is usable, it is
+// preferable to populateVMA since it unlocks mm.mappingMu before performing
+// expensive operations that don't require it to be locked.
+//
+// Preconditions: mm.mappingMu must be locked for writing.
+// vseg.Range().IsSupersetOf(ar).
+//
+// Postconditions: mm.mappingMu will be unlocked.
+func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) {
+ // See populateVMA above for commentary.
+ if !vseg.ValuePtr().effectivePerms.Any() {
+ mm.mappingMu.Unlock()
+ return
+ }
+
+ mm.activeMu.Lock()
+
+ if mm.as == nil {
+ mm.activeMu.Unlock()
+ mm.mappingMu.Unlock()
+ return
+ }
+
+ // mm.mappingMu doesn't need to be write-locked for getPMAsLocked, and it
+ // isn't needed at all for mapASLocked.
+ mm.mappingMu.DowngradeLock()
+ pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, usermem.NoAccess)
+ mm.mappingMu.RUnlock()
+ if err != nil {
+ mm.activeMu.Unlock()
+ return
+ }
+
+ mm.activeMu.DowngradeLock()
+ mm.mapASLocked(pseg, ar, precommit)
+ mm.activeMu.RUnlock()
+}
+
+// MapStack allocates the initial process stack.
+func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error) {
+ // maxStackSize is the maximum supported process stack size in bytes.
+ //
+ // This limit exists because stack growing isn't implemented, so the entire
+ // process stack must be mapped up-front.
+ const maxStackSize = 128 << 20
+
+ stackSize := limits.FromContext(ctx).Get(limits.Stack)
+ r, ok := usermem.Addr(stackSize.Cur).RoundUp()
+ sz := uint64(r)
+ if !ok {
+ // RLIM_INFINITY rounds up to 0.
+ sz = linux.DefaultStackSoftLimit
+ } else if sz > maxStackSize {
+ ctx.Warningf("Capping stack size from RLIMIT_STACK of %v down to %v.", sz, maxStackSize)
+ sz = maxStackSize
+ } else if sz == 0 {
+ return usermem.AddrRange{}, syserror.ENOMEM
+ }
+ szaddr := usermem.Addr(sz)
+ ctx.Debugf("Allocating stack with size of %v bytes", sz)
+
+ // Determine the stack's desired location. Unlike Linux, address
+ // randomization can't be disabled.
+ stackEnd := mm.layout.MaxAddr - usermem.Addr(mrand.Int63n(int64(mm.layout.MaxStackRand))).RoundDown()
+ if stackEnd < szaddr {
+ return usermem.AddrRange{}, syserror.ENOMEM
+ }
+ stackStart := stackEnd - szaddr
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ _, ar, err := mm.createVMALocked(ctx, memmap.MMapOpts{
+ Length: sz,
+ Addr: stackStart,
+ Perms: usermem.ReadWrite,
+ MaxPerms: usermem.AnyAccess,
+ Private: true,
+ GrowsDown: true,
+ MLockMode: mm.defMLockMode,
+ Hint: "[stack]",
+ })
+ return ar, err
+}
+
+// MUnmap implements the semantics of Linux's munmap(2).
+func (mm *MemoryManager) MUnmap(ctx context.Context, addr usermem.Addr, length uint64) error {
+ if addr != addr.RoundDown() {
+ return syserror.EINVAL
+ }
+ if length == 0 {
+ return syserror.EINVAL
+ }
+ la, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return syserror.EINVAL
+ }
+ ar, ok := addr.ToRange(uint64(la))
+ if !ok {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ mm.unmapLocked(ctx, ar)
+ return nil
+}
+
+// MRemapOpts specifies options to MRemap.
+type MRemapOpts struct {
+ // Move controls whether MRemap moves the remapped mapping to a new address.
+ Move MRemapMoveMode
+
+ // NewAddr is the new address for the remapping. NewAddr is ignored unless
+ // Move is MMRemapMustMove.
+ NewAddr usermem.Addr
+}
+
+// MRemapMoveMode controls MRemap's moving behavior.
+type MRemapMoveMode int
+
+const (
+ // MRemapNoMove prevents MRemap from moving the remapped mapping.
+ MRemapNoMove MRemapMoveMode = iota
+
+ // MRemapMayMove allows MRemap to move the remapped mapping.
+ MRemapMayMove
+
+ // MRemapMustMove requires MRemap to move the remapped mapping to
+ // MRemapOpts.NewAddr, replacing any existing mappings in the remapped
+ // range.
+ MRemapMustMove
+)
+
+// MRemap implements the semantics of Linux's mremap(2).
+func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSize uint64, newSize uint64, opts MRemapOpts) (usermem.Addr, error) {
+ // "Note that old_address has to be page aligned." - mremap(2)
+ if oldAddr.RoundDown() != oldAddr {
+ return 0, syserror.EINVAL
+ }
+
+ // Linux treats an old_size that rounds up to 0 as 0, which is otherwise a
+ // valid size. However, new_size can't be 0 after rounding.
+ oldSizeAddr, _ := usermem.Addr(oldSize).RoundUp()
+ oldSize = uint64(oldSizeAddr)
+ newSizeAddr, ok := usermem.Addr(newSize).RoundUp()
+ if !ok || newSizeAddr == 0 {
+ return 0, syserror.EINVAL
+ }
+ newSize = uint64(newSizeAddr)
+
+ oldEnd, ok := oldAddr.AddLength(oldSize)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+
+ // All cases require that a vma exists at oldAddr.
+ vseg := mm.vmas.FindSegment(oldAddr)
+ if !vseg.Ok() {
+ return 0, syserror.EFAULT
+ }
+
+ // Behavior matrix:
+ //
+ // Move | oldSize = 0 | oldSize < newSize | oldSize = newSize | oldSize > newSize
+ // ---------+-------------+-------------------+-------------------+------------------
+ // NoMove | ENOMEM [1] | Grow in-place | No-op | Shrink in-place
+ // MayMove | Copy [1] | Grow in-place or | No-op | Shrink in-place
+ // | | move | |
+ // MustMove | Copy | Move and grow | Move | Shrink and move
+ //
+ // [1] In-place growth is impossible because the vma at oldAddr already
+ // occupies at least part of the destination. Thus the NoMove case always
+ // fails and the MayMove case always falls back to copying.
+
+ if vma := vseg.ValuePtr(); newSize > oldSize && vma.mlockMode != memmap.MLockNone {
+ // Check against RLIMIT_MEMLOCK. Unlike mmap, mlock, and mlockall,
+ // mremap in Linux does not check mm/mlock.c:can_do_mlock() and
+ // therefore does not return EPERM if RLIMIT_MEMLOCK is 0 and
+ // !CAP_IPC_LOCK.
+ mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur
+ if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) {
+ if newLockedAS := mm.lockedAS - oldSize + newSize; newLockedAS > mlockLimit {
+ return 0, syserror.EAGAIN
+ }
+ }
+ }
+
+ if opts.Move != MRemapMustMove {
+ // Handle no-ops and in-place shrinking. These cases don't care if
+ // [oldAddr, oldEnd) maps to a single vma, or is even mapped at all
+ // (aside from oldAddr).
+ if newSize <= oldSize {
+ if newSize < oldSize {
+ // If oldAddr+oldSize didn't overflow, oldAddr+newSize can't
+ // either.
+ newEnd := oldAddr + usermem.Addr(newSize)
+ mm.unmapLocked(ctx, usermem.AddrRange{newEnd, oldEnd})
+ }
+ return oldAddr, nil
+ }
+
+ // Handle in-place growing.
+
+ // Check that oldEnd maps to the same vma as oldAddr.
+ if vseg.End() < oldEnd {
+ return 0, syserror.EFAULT
+ }
+ // "Grow" the existing vma by creating a new mergeable one.
+ vma := vseg.ValuePtr()
+ var newOffset uint64
+ if vma.mappable != nil {
+ newOffset = vseg.mappableRange().End
+ }
+ vseg, ar, err := mm.createVMALocked(ctx, memmap.MMapOpts{
+ Length: newSize - oldSize,
+ MappingIdentity: vma.id,
+ Mappable: vma.mappable,
+ Offset: newOffset,
+ Addr: oldEnd,
+ Fixed: true,
+ Perms: vma.realPerms,
+ MaxPerms: vma.maxPerms,
+ Private: vma.private,
+ GrowsDown: vma.growsDown,
+ MLockMode: vma.mlockMode,
+ Hint: vma.hint,
+ })
+ if err == nil {
+ if vma.mlockMode == memmap.MLockEager {
+ mm.populateVMA(ctx, vseg, ar, true)
+ }
+ return oldAddr, nil
+ }
+ // In-place growth failed. In the MRemapMayMove case, fall through to
+ // copying/moving below.
+ if opts.Move == MRemapNoMove {
+ return 0, err
+ }
+ }
+
+ // Find a location for the new mapping.
+ var newAR usermem.AddrRange
+ switch opts.Move {
+ case MRemapMayMove:
+ newAddr, err := mm.findAvailableLocked(newSize, findAvailableOpts{})
+ if err != nil {
+ return 0, err
+ }
+ newAR, _ = newAddr.ToRange(newSize)
+
+ case MRemapMustMove:
+ newAddr := opts.NewAddr
+ if newAddr.RoundDown() != newAddr {
+ return 0, syserror.EINVAL
+ }
+ var ok bool
+ newAR, ok = newAddr.ToRange(newSize)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+ if (usermem.AddrRange{oldAddr, oldEnd}).Overlaps(newAR) {
+ return 0, syserror.EINVAL
+ }
+
+ // Unmap any mappings at the destination.
+ mm.unmapLocked(ctx, newAR)
+
+ // If the sizes specify shrinking, unmap everything between the new and
+ // old sizes at the source. Unmapping before the following checks is
+ // correct: compare Linux's mm/mremap.c:mremap_to() => do_munmap(),
+ // vma_to_resize().
+ if newSize < oldSize {
+ oldNewEnd := oldAddr + usermem.Addr(newSize)
+ mm.unmapLocked(ctx, usermem.AddrRange{oldNewEnd, oldEnd})
+ oldEnd = oldNewEnd
+ }
+
+ // unmapLocked may have invalidated vseg; look it up again.
+ vseg = mm.vmas.FindSegment(oldAddr)
+ }
+
+ oldAR := usermem.AddrRange{oldAddr, oldEnd}
+
+ // Check that oldEnd maps to the same vma as oldAddr.
+ if vseg.End() < oldEnd {
+ return 0, syserror.EFAULT
+ }
+
+ // Check against RLIMIT_AS.
+ newUsageAS := mm.usageAS - uint64(oldAR.Length()) + uint64(newAR.Length())
+ if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS {
+ return 0, syserror.ENOMEM
+ }
+
+ if vma := vseg.ValuePtr(); vma.mappable != nil {
+ // Check that offset+length does not overflow.
+ if vma.off+uint64(newAR.Length()) < vma.off {
+ return 0, syserror.EINVAL
+ }
+ // Inform the Mappable, if any, of the new mapping.
+ if err := vma.mappable.CopyMapping(ctx, mm, oldAR, newAR, vseg.mappableOffsetAt(oldAR.Start), vma.canWriteMappableLocked()); err != nil {
+ return 0, err
+ }
+ }
+
+ if oldSize == 0 {
+ // Handle copying.
+ //
+ // We can't use createVMALocked because it calls Mappable.AddMapping,
+ // whereas we've already called Mappable.CopyMapping (which is
+ // consistent with Linux). Call vseg.Value() (rather than
+ // vseg.ValuePtr()) to make a copy of the vma.
+ vma := vseg.Value()
+ if vma.mappable != nil {
+ vma.off = vseg.mappableOffsetAt(oldAR.Start)
+ }
+ if vma.id != nil {
+ vma.id.IncRef()
+ }
+ vseg := mm.vmas.Insert(mm.vmas.FindGap(newAR.Start), newAR, vma)
+ mm.usageAS += uint64(newAR.Length())
+ if vma.isPrivateDataLocked() {
+ mm.dataAS += uint64(newAR.Length())
+ }
+ if vma.mlockMode != memmap.MLockNone {
+ mm.lockedAS += uint64(newAR.Length())
+ if vma.mlockMode == memmap.MLockEager {
+ mm.populateVMA(ctx, vseg, newAR, true)
+ }
+ }
+ return newAR.Start, nil
+ }
+
+ // Handle moving.
+ //
+ // Remove the existing vma before inserting the new one to minimize
+ // iterator invalidation. We do this directly (instead of calling
+ // removeVMAsLocked) because:
+ //
+ // 1. We can't drop the reference on vma.id, which will be transferred to
+ // the new vma.
+ //
+ // 2. We can't call vma.mappable.RemoveMapping, because pmas are still at
+ // oldAR, so calling RemoveMapping could cause us to miss an invalidation
+ // overlapping oldAR.
+ //
+ // Call vseg.Value() (rather than vseg.ValuePtr()) to make a copy of the
+ // vma.
+ vseg = mm.vmas.Isolate(vseg, oldAR)
+ vma := vseg.Value()
+ mm.vmas.Remove(vseg)
+ vseg = mm.vmas.Insert(mm.vmas.FindGap(newAR.Start), newAR, vma)
+ mm.usageAS = mm.usageAS - uint64(oldAR.Length()) + uint64(newAR.Length())
+ if vma.isPrivateDataLocked() {
+ mm.dataAS = mm.dataAS - uint64(oldAR.Length()) + uint64(newAR.Length())
+ }
+ if vma.mlockMode != memmap.MLockNone {
+ mm.lockedAS = mm.lockedAS - uint64(oldAR.Length()) + uint64(newAR.Length())
+ }
+
+ // Move pmas. This is technically optional for non-private pmas, which
+ // could just go through memmap.Mappable.Translate again, but it's required
+ // for private pmas.
+ mm.activeMu.Lock()
+ mm.movePMAsLocked(oldAR, newAR)
+ mm.activeMu.Unlock()
+
+ // Now that pmas have been moved to newAR, we can notify vma.mappable that
+ // oldAR is no longer mapped.
+ if vma.mappable != nil {
+ vma.mappable.RemoveMapping(ctx, mm, oldAR, vma.off, vma.canWriteMappableLocked())
+ }
+
+ if vma.mlockMode == memmap.MLockEager {
+ mm.populateVMA(ctx, vseg, newAR, true)
+ }
+
+ return newAR.Start, nil
+}
+
+// MProtect implements the semantics of Linux's mprotect(2).
+func (mm *MemoryManager) MProtect(addr usermem.Addr, length uint64, realPerms usermem.AccessType, growsDown bool) error {
+ if addr.RoundDown() != addr {
+ return syserror.EINVAL
+ }
+ if length == 0 {
+ return nil
+ }
+ rlength, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return syserror.ENOMEM
+ }
+ ar, ok := addr.ToRange(uint64(rlength))
+ if !ok {
+ return syserror.ENOMEM
+ }
+ effectivePerms := realPerms.Effective()
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ // Non-growsDown mprotect requires that all of ar is mapped, and stops at
+ // the first non-empty gap. growsDown mprotect requires that the first vma
+ // be growsDown, but does not require it to extend all the way to ar.Start;
+ // vmas after the first must be contiguous but need not be growsDown, like
+ // the non-growsDown case.
+ vseg := mm.vmas.LowerBoundSegment(ar.Start)
+ if !vseg.Ok() {
+ return syserror.ENOMEM
+ }
+ if growsDown {
+ if !vseg.ValuePtr().growsDown {
+ return syserror.EINVAL
+ }
+ if ar.End <= vseg.Start() {
+ return syserror.ENOMEM
+ }
+ ar.Start = vseg.Start()
+ } else {
+ if ar.Start < vseg.Start() {
+ return syserror.ENOMEM
+ }
+ }
+
+ mm.activeMu.Lock()
+ defer mm.activeMu.Unlock()
+ defer func() {
+ mm.vmas.MergeRange(ar)
+ mm.vmas.MergeAdjacent(ar)
+ mm.pmas.MergeRange(ar)
+ mm.pmas.MergeAdjacent(ar)
+ }()
+ pseg := mm.pmas.LowerBoundSegment(ar.Start)
+ var didUnmapAS bool
+ for {
+ // Check for permission validity before splitting vmas, for consistency
+ // with Linux.
+ if !vseg.ValuePtr().maxPerms.SupersetOf(effectivePerms) {
+ return syserror.EACCES
+ }
+ vseg = mm.vmas.Isolate(vseg, ar)
+
+ // Update vma permissions.
+ vma := vseg.ValuePtr()
+ vmaLength := vseg.Range().Length()
+ if vma.isPrivateDataLocked() {
+ mm.dataAS -= uint64(vmaLength)
+ }
+
+ vma.realPerms = realPerms
+ vma.effectivePerms = effectivePerms
+ if vma.isPrivateDataLocked() {
+ mm.dataAS += uint64(vmaLength)
+ }
+
+ // Propagate vma permission changes to pmas.
+ for pseg.Ok() && pseg.Start() < vseg.End() {
+ if pseg.Range().Overlaps(vseg.Range()) {
+ pseg = mm.pmas.Isolate(pseg, vseg.Range())
+ pma := pseg.ValuePtr()
+ if !effectivePerms.SupersetOf(pma.effectivePerms) && !didUnmapAS {
+ // Unmap all of ar, not just vseg.Range(), to minimize host
+ // syscalls.
+ mm.unmapASLocked(ar)
+ didUnmapAS = true
+ }
+ pma.effectivePerms = effectivePerms.Intersect(pma.translatePerms)
+ if pma.needCOW {
+ pma.effectivePerms.Write = false
+ }
+ }
+ pseg = pseg.NextSegment()
+ }
+
+ // Continue to the next vma.
+ if ar.End <= vseg.End() {
+ return nil
+ }
+ vseg, _ = vseg.NextNonEmpty()
+ if !vseg.Ok() {
+ return syserror.ENOMEM
+ }
+ }
+}
+
+// BrkSetup sets mm's brk address to addr and its brk size to 0.
+func (mm *MemoryManager) BrkSetup(ctx context.Context, addr usermem.Addr) {
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ // Unmap the existing brk.
+ if mm.brk.Length() != 0 {
+ mm.unmapLocked(ctx, mm.brk)
+ }
+ mm.brk = usermem.AddrRange{addr, addr}
+}
+
+// Brk implements the semantics of Linux's brk(2), except that it returns an
+// error on failure.
+func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Addr, error) {
+ mm.mappingMu.Lock()
+ // Can't defer mm.mappingMu.Unlock(); see below.
+
+ if addr < mm.brk.Start {
+ addr = mm.brk.End
+ mm.mappingMu.Unlock()
+ return addr, syserror.EINVAL
+ }
+
+ // TODO(gvisor.dev/issue/156): This enforces RLIMIT_DATA, but is
+ // slightly more permissive than the usual data limit. In particular,
+ // this only limits the size of the heap; a true RLIMIT_DATA limits the
+ // size of heap + data + bss. The segment sizes need to be plumbed from
+ // the loader package to fully enforce RLIMIT_DATA.
+ if uint64(addr-mm.brk.Start) > limits.FromContext(ctx).Get(limits.Data).Cur {
+ addr = mm.brk.End
+ mm.mappingMu.Unlock()
+ return addr, syserror.ENOMEM
+ }
+
+ oldbrkpg, _ := mm.brk.End.RoundUp()
+ newbrkpg, ok := addr.RoundUp()
+ if !ok {
+ addr = mm.brk.End
+ mm.mappingMu.Unlock()
+ return addr, syserror.EFAULT
+ }
+
+ switch {
+ case oldbrkpg < newbrkpg:
+ vseg, ar, err := mm.createVMALocked(ctx, memmap.MMapOpts{
+ Length: uint64(newbrkpg - oldbrkpg),
+ Addr: oldbrkpg,
+ Fixed: true,
+ // Compare Linux's
+ // arch/x86/include/asm/page_types.h:VM_DATA_DEFAULT_FLAGS.
+ Perms: usermem.ReadWrite,
+ MaxPerms: usermem.AnyAccess,
+ Private: true,
+ // Linux: mm/mmap.c:sys_brk() => do_brk_flags() includes
+ // mm->def_flags.
+ MLockMode: mm.defMLockMode,
+ Hint: "[heap]",
+ })
+ if err != nil {
+ addr = mm.brk.End
+ mm.mappingMu.Unlock()
+ return addr, err
+ }
+ mm.brk.End = addr
+ if mm.defMLockMode == memmap.MLockEager {
+ mm.populateVMAAndUnlock(ctx, vseg, ar, true)
+ } else {
+ mm.mappingMu.Unlock()
+ }
+
+ case newbrkpg < oldbrkpg:
+ mm.unmapLocked(ctx, usermem.AddrRange{newbrkpg, oldbrkpg})
+ fallthrough
+
+ default:
+ mm.brk.End = addr
+ mm.mappingMu.Unlock()
+ }
+
+ return addr, nil
+}
+
+// MLock implements the semantics of Linux's mlock()/mlock2()/munlock(),
+// depending on mode.
+func (mm *MemoryManager) MLock(ctx context.Context, addr usermem.Addr, length uint64, mode memmap.MLockMode) error {
+ // Linux allows this to overflow.
+ la, _ := usermem.Addr(length + addr.PageOffset()).RoundUp()
+ ar, ok := addr.RoundDown().ToRange(uint64(la))
+ if !ok {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ // Can't defer mm.mappingMu.Unlock(); see below.
+
+ if mode != memmap.MLockNone {
+ // Check against RLIMIT_MEMLOCK.
+ if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) {
+ mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur
+ if mlockLimit == 0 {
+ mm.mappingMu.Unlock()
+ return syserror.EPERM
+ }
+ if newLockedAS := mm.lockedAS + uint64(ar.Length()) - mm.mlockedBytesRangeLocked(ar); newLockedAS > mlockLimit {
+ mm.mappingMu.Unlock()
+ return syserror.ENOMEM
+ }
+ }
+ }
+
+ // Check this after RLIMIT_MEMLOCK for consistency with Linux.
+ if ar.Length() == 0 {
+ mm.mappingMu.Unlock()
+ return nil
+ }
+
+ // Apply the new mlock mode to vmas.
+ var unmapped bool
+ vseg := mm.vmas.FindSegment(ar.Start)
+ for {
+ if !vseg.Ok() {
+ unmapped = true
+ break
+ }
+ vseg = mm.vmas.Isolate(vseg, ar)
+ vma := vseg.ValuePtr()
+ prevMode := vma.mlockMode
+ vma.mlockMode = mode
+ if mode != memmap.MLockNone && prevMode == memmap.MLockNone {
+ mm.lockedAS += uint64(vseg.Range().Length())
+ } else if mode == memmap.MLockNone && prevMode != memmap.MLockNone {
+ mm.lockedAS -= uint64(vseg.Range().Length())
+ }
+ if ar.End <= vseg.End() {
+ break
+ }
+ vseg, _ = vseg.NextNonEmpty()
+ }
+ mm.vmas.MergeRange(ar)
+ mm.vmas.MergeAdjacent(ar)
+ if unmapped {
+ mm.mappingMu.Unlock()
+ return syserror.ENOMEM
+ }
+
+ if mode == memmap.MLockEager {
+ // Ensure that we have usable pmas. Since we didn't return ENOMEM
+ // above, ar must be fully covered by vmas, so we can just use
+ // NextSegment below.
+ mm.activeMu.Lock()
+ mm.mappingMu.DowngradeLock()
+ for vseg := mm.vmas.FindSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
+ if !vseg.ValuePtr().effectivePerms.Any() {
+ // Linux: mm/gup.c:__get_user_pages() returns EFAULT in this
+ // case, which is converted to ENOMEM by mlock.
+ mm.activeMu.Unlock()
+ mm.mappingMu.RUnlock()
+ return syserror.ENOMEM
+ }
+ _, _, err := mm.getPMAsLocked(ctx, vseg, vseg.Range().Intersect(ar), usermem.NoAccess)
+ if err != nil {
+ mm.activeMu.Unlock()
+ mm.mappingMu.RUnlock()
+ // Linux: mm/mlock.c:__mlock_posix_error_return()
+ if err == syserror.EFAULT {
+ return syserror.ENOMEM
+ }
+ if err == syserror.ENOMEM {
+ return syserror.EAGAIN
+ }
+ return err
+ }
+ }
+
+ // Map pmas into the active AddressSpace, if we have one.
+ mm.mappingMu.RUnlock()
+ if mm.as != nil {
+ mm.activeMu.DowngradeLock()
+ err := mm.mapASLocked(mm.pmas.LowerBoundSegment(ar.Start), ar, true /* precommit */)
+ mm.activeMu.RUnlock()
+ if err != nil {
+ return err
+ }
+ } else {
+ mm.activeMu.Unlock()
+ }
+ } else {
+ mm.mappingMu.Unlock()
+ }
+
+ return nil
+}
+
+// MLockAllOpts holds options to MLockAll.
+type MLockAllOpts struct {
+ // If Current is true, change the memory-locking behavior of all mappings
+ // to Mode. If Future is true, upgrade the memory-locking behavior of all
+ // future mappings to Mode. At least one of Current or Future must be true.
+ Current bool
+ Future bool
+ Mode memmap.MLockMode
+}
+
+// MLockAll implements the semantics of Linux's mlockall()/munlockall(),
+// depending on opts.
+func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error {
+ if !opts.Current && !opts.Future {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.Lock()
+ // Can't defer mm.mappingMu.Unlock(); see below.
+
+ if opts.Current {
+ if opts.Mode != memmap.MLockNone {
+ // Check against RLIMIT_MEMLOCK.
+ if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) {
+ mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur
+ if mlockLimit == 0 {
+ mm.mappingMu.Unlock()
+ return syserror.EPERM
+ }
+ if uint64(mm.vmas.Span()) > mlockLimit {
+ mm.mappingMu.Unlock()
+ return syserror.ENOMEM
+ }
+ }
+ }
+ for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() {
+ vma := vseg.ValuePtr()
+ prevMode := vma.mlockMode
+ vma.mlockMode = opts.Mode
+ if opts.Mode != memmap.MLockNone && prevMode == memmap.MLockNone {
+ mm.lockedAS += uint64(vseg.Range().Length())
+ } else if opts.Mode == memmap.MLockNone && prevMode != memmap.MLockNone {
+ mm.lockedAS -= uint64(vseg.Range().Length())
+ }
+ }
+ }
+
+ if opts.Future {
+ mm.defMLockMode = opts.Mode
+ }
+
+ if opts.Current && opts.Mode == memmap.MLockEager {
+ // Linux: mm/mlock.c:sys_mlockall() => include/linux/mm.h:mm_populate()
+ // ignores the return value of __mm_populate(), so all errors below are
+ // ignored.
+ //
+ // Try to get usable pmas.
+ mm.activeMu.Lock()
+ mm.mappingMu.DowngradeLock()
+ for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() {
+ if vseg.ValuePtr().effectivePerms.Any() {
+ mm.getPMAsLocked(ctx, vseg, vseg.Range(), usermem.NoAccess)
+ }
+ }
+
+ // Map all pmas into the active AddressSpace, if we have one.
+ mm.mappingMu.RUnlock()
+ if mm.as != nil {
+ mm.activeMu.DowngradeLock()
+ mm.mapASLocked(mm.pmas.FirstSegment(), mm.applicationAddrRange(), true /* precommit */)
+ mm.activeMu.RUnlock()
+ } else {
+ mm.activeMu.Unlock()
+ }
+ } else {
+ mm.mappingMu.Unlock()
+ }
+ return nil
+}
+
+// Decommit implements the semantics of Linux's madvise(MADV_DONTNEED).
+func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error {
+ ar, ok := addr.ToRange(length)
+ if !ok {
+ return syserror.EINVAL
+ }
+
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ mm.activeMu.Lock()
+ defer mm.activeMu.Unlock()
+
+ // Linux's mm/madvise.c:madvise_dontneed() => mm/memory.c:zap_page_range()
+ // is analogous to our mm.invalidateLocked(ar, true, true). We inline this
+ // here, with the special case that we synchronously decommit
+ // uniquely-owned (non-copy-on-write) pages for private anonymous vma,
+ // which is the common case for MADV_DONTNEED. Invalidating these pmas, and
+ // allowing them to be reallocated when touched again, increases pma
+ // fragmentation, which may significantly reduce performance for
+ // non-vectored I/O implementations. Also, decommitting synchronously
+ // ensures that Decommit immediately reduces host memory usage.
+ var didUnmapAS bool
+ pseg := mm.pmas.LowerBoundSegment(ar.Start)
+ mf := mm.mfp.MemoryFile()
+ for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
+ vma := vseg.ValuePtr()
+ if vma.mlockMode != memmap.MLockNone {
+ return syserror.EINVAL
+ }
+ vsegAR := vseg.Range().Intersect(ar)
+ // pseg should already correspond to either this vma or a later one,
+ // since there can't be a pma without a corresponding vma.
+ if checkInvariants {
+ if pseg.Ok() && pseg.End() <= vsegAR.Start {
+ panic(fmt.Sprintf("pma %v precedes vma %v", pseg.Range(), vsegAR))
+ }
+ }
+ for pseg.Ok() && pseg.Start() < vsegAR.End {
+ pma := pseg.ValuePtr()
+ if pma.private && !mm.isPMACopyOnWriteLocked(vseg, pseg) {
+ psegAR := pseg.Range().Intersect(ar)
+ if vsegAR.IsSupersetOf(psegAR) && vma.mappable == nil {
+ if err := mf.Decommit(pseg.fileRangeOf(psegAR)); err == nil {
+ pseg = pseg.NextSegment()
+ continue
+ }
+ // If an error occurs, fall through to the general
+ // invalidation case below.
+ }
+ }
+ pseg = mm.pmas.Isolate(pseg, vsegAR)
+ pma = pseg.ValuePtr()
+ if !didUnmapAS {
+ // Unmap all of ar, not just pseg.Range(), to minimize host
+ // syscalls. AddressSpace mappings must be removed before
+ // mm.decPrivateRef().
+ mm.unmapASLocked(ar)
+ didUnmapAS = true
+ }
+ if pma.private {
+ mm.decPrivateRef(pseg.fileRange())
+ }
+ pma.file.DecRef(pseg.fileRange())
+ mm.removeRSSLocked(pseg.Range())
+ pseg = mm.pmas.Remove(pseg).NextSegment()
+ }
+ }
+
+ // "If there are some parts of the specified address space that are not
+ // mapped, the Linux version of madvise() ignores them and applies the call
+ // to the rest (but returns ENOMEM from the system call, as it should)." -
+ // madvise(2)
+ if mm.vmas.SpanRange(ar) != ar.Length() {
+ return syserror.ENOMEM
+ }
+ return nil
+}
+
+// MSyncOpts holds options to MSync.
+type MSyncOpts struct {
+ // Sync has the semantics of MS_SYNC.
+ Sync bool
+
+ // Invalidate has the semantics of MS_INVALIDATE.
+ Invalidate bool
+}
+
+// MSync implements the semantics of Linux's msync().
+func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length uint64, opts MSyncOpts) error {
+ if addr != addr.RoundDown() {
+ return syserror.EINVAL
+ }
+ if length == 0 {
+ return nil
+ }
+ la, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return syserror.ENOMEM
+ }
+ ar, ok := addr.ToRange(uint64(la))
+ if !ok {
+ return syserror.ENOMEM
+ }
+
+ mm.mappingMu.RLock()
+ // Can't defer mm.mappingMu.RUnlock(); see below.
+ vseg := mm.vmas.LowerBoundSegment(ar.Start)
+ if !vseg.Ok() {
+ mm.mappingMu.RUnlock()
+ return syserror.ENOMEM
+ }
+ var unmapped bool
+ lastEnd := ar.Start
+ for {
+ if !vseg.Ok() {
+ mm.mappingMu.RUnlock()
+ unmapped = true
+ break
+ }
+ if lastEnd < vseg.Start() {
+ unmapped = true
+ }
+ lastEnd = vseg.End()
+ vma := vseg.ValuePtr()
+ if opts.Invalidate && vma.mlockMode != memmap.MLockNone {
+ mm.mappingMu.RUnlock()
+ return syserror.EBUSY
+ }
+ // It's only possible to have dirtied the Mappable through a shared
+ // mapping. Don't check if the mapping is writable, because mprotect
+ // may have changed this, and also because Linux doesn't.
+ if id := vma.id; opts.Sync && id != nil && vma.mappable != nil && !vma.private {
+ // We can't call memmap.MappingIdentity.Msync while holding
+ // mm.mappingMu since it may take fs locks that precede it in the
+ // lock order.
+ id.IncRef()
+ mr := vseg.mappableRangeOf(vseg.Range().Intersect(ar))
+ mm.mappingMu.RUnlock()
+ err := id.Msync(ctx, mr)
+ id.DecRef()
+ if err != nil {
+ return err
+ }
+ if lastEnd >= ar.End {
+ break
+ }
+ mm.mappingMu.RLock()
+ vseg = mm.vmas.LowerBoundSegment(lastEnd)
+ } else {
+ if lastEnd >= ar.End {
+ mm.mappingMu.RUnlock()
+ break
+ }
+ vseg = vseg.NextSegment()
+ }
+ }
+
+ if unmapped {
+ return syserror.ENOMEM
+ }
+ return nil
+}
+
+// GetSharedFutexKey is used by kernel.Task.GetSharedKey.
+func (mm *MemoryManager) GetSharedFutexKey(ctx context.Context, addr usermem.Addr) (futex.Key, error) {
+ ar, ok := addr.ToRange(4) // sizeof(int32).
+ if !ok {
+ return futex.Key{}, syserror.EFAULT
+ }
+
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ vseg, _, err := mm.getVMAsLocked(ctx, ar, usermem.Read, false)
+ if err != nil {
+ return futex.Key{}, err
+ }
+ vma := vseg.ValuePtr()
+
+ if vma.private {
+ return futex.Key{
+ Kind: futex.KindSharedPrivate,
+ Offset: uint64(addr),
+ }, nil
+ }
+
+ if vma.id != nil {
+ vma.id.IncRef()
+ }
+ return futex.Key{
+ Kind: futex.KindSharedMappable,
+ Mappable: vma.mappable,
+ MappingIdentity: vma.id,
+ Offset: vseg.mappableOffsetAt(addr),
+ }, nil
+}
+
+// VirtualMemorySize returns the combined length in bytes of all mappings in
+// mm.
+func (mm *MemoryManager) VirtualMemorySize() uint64 {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ return mm.usageAS
+}
+
+// VirtualMemorySizeRange returns the combined length in bytes of all mappings
+// in ar in mm.
+func (mm *MemoryManager) VirtualMemorySizeRange(ar usermem.AddrRange) uint64 {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ return uint64(mm.vmas.SpanRange(ar))
+}
+
+// ResidentSetSize returns the value advertised as mm's RSS in bytes.
+func (mm *MemoryManager) ResidentSetSize() uint64 {
+ mm.activeMu.RLock()
+ defer mm.activeMu.RUnlock()
+ return mm.curRSS
+}
+
+// MaxResidentSetSize returns the value advertised as mm's max RSS in bytes.
+func (mm *MemoryManager) MaxResidentSetSize() uint64 {
+ mm.activeMu.RLock()
+ defer mm.activeMu.RUnlock()
+ return mm.maxRSS
+}
+
+// VirtualDataSize returns the size of private data segments in mm.
+func (mm *MemoryManager) VirtualDataSize() uint64 {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ return mm.dataAS
+}
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
new file mode 100644
index 000000000..02203f79f
--- /dev/null
+++ b/pkg/sentry/mm/vma.go
@@ -0,0 +1,564 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mm
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Preconditions: mm.mappingMu must be locked for writing. opts must be valid
+// as defined by the checks in MMap.
+func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOpts) (vmaIterator, usermem.AddrRange, error) {
+ if opts.MaxPerms != opts.MaxPerms.Effective() {
+ panic(fmt.Sprintf("Non-effective MaxPerms %s cannot be enforced", opts.MaxPerms))
+ }
+
+ // Find a useable range.
+ addr, err := mm.findAvailableLocked(opts.Length, findAvailableOpts{
+ Addr: opts.Addr,
+ Fixed: opts.Fixed,
+ Unmap: opts.Unmap,
+ Map32Bit: opts.Map32Bit,
+ })
+ if err != nil {
+ return vmaIterator{}, usermem.AddrRange{}, err
+ }
+ ar, _ := addr.ToRange(opts.Length)
+
+ // Check against RLIMIT_AS.
+ newUsageAS := mm.usageAS + opts.Length
+ if opts.Unmap {
+ newUsageAS -= uint64(mm.vmas.SpanRange(ar))
+ }
+ if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS {
+ return vmaIterator{}, usermem.AddrRange{}, syserror.ENOMEM
+ }
+
+ if opts.MLockMode != memmap.MLockNone {
+ // Check against RLIMIT_MEMLOCK.
+ if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) {
+ mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur
+ if mlockLimit == 0 {
+ return vmaIterator{}, usermem.AddrRange{}, syserror.EPERM
+ }
+ newLockedAS := mm.lockedAS + opts.Length
+ if opts.Unmap {
+ newLockedAS -= mm.mlockedBytesRangeLocked(ar)
+ }
+ if newLockedAS > mlockLimit {
+ return vmaIterator{}, usermem.AddrRange{}, syserror.EAGAIN
+ }
+ }
+ }
+
+ // Remove overwritten mappings. This ordering is consistent with Linux:
+ // compare Linux's mm/mmap.c:mmap_region() => do_munmap(),
+ // file->f_op->mmap().
+ var vgap vmaGapIterator
+ if opts.Unmap {
+ vgap = mm.unmapLocked(ctx, ar)
+ } else {
+ vgap = mm.vmas.FindGap(ar.Start)
+ }
+
+ // Inform the Mappable, if any, of the new mapping.
+ if opts.Mappable != nil {
+ // The expression for writable is vma.canWriteMappableLocked(), but we
+ // don't yet have a vma.
+ if err := opts.Mappable.AddMapping(ctx, mm, ar, opts.Offset, !opts.Private && opts.MaxPerms.Write); err != nil {
+ return vmaIterator{}, usermem.AddrRange{}, err
+ }
+ }
+
+ // Take a reference on opts.MappingIdentity before inserting the vma since
+ // vma merging can drop the reference.
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.IncRef()
+ }
+
+ // Finally insert the vma.
+ v := vma{
+ mappable: opts.Mappable,
+ off: opts.Offset,
+ realPerms: opts.Perms,
+ effectivePerms: opts.Perms.Effective(),
+ maxPerms: opts.MaxPerms,
+ private: opts.Private,
+ growsDown: opts.GrowsDown,
+ mlockMode: opts.MLockMode,
+ id: opts.MappingIdentity,
+ hint: opts.Hint,
+ }
+
+ vseg := mm.vmas.Insert(vgap, ar, v)
+ mm.usageAS += opts.Length
+ if v.isPrivateDataLocked() {
+ mm.dataAS += opts.Length
+ }
+ if opts.MLockMode != memmap.MLockNone {
+ mm.lockedAS += opts.Length
+ }
+
+ return vseg, ar, nil
+}
+
+type findAvailableOpts struct {
+ // These fields are equivalent to those in memmap.MMapOpts, except that:
+ //
+ // - Addr must be page-aligned.
+ //
+ // - Unmap allows existing guard pages in the returned range.
+
+ Addr usermem.Addr
+ Fixed bool
+ Unmap bool
+ Map32Bit bool
+}
+
+// map32Start/End are the bounds to which MAP_32BIT mappings are constrained,
+// and are equivalent to Linux's MAP32_BASE and MAP32_MAX respectively.
+const (
+ map32Start = 0x40000000
+ map32End = 0x80000000
+)
+
+// findAvailableLocked finds an allocatable range.
+//
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOpts) (usermem.Addr, error) {
+ if opts.Fixed {
+ opts.Map32Bit = false
+ }
+ allowedAR := mm.applicationAddrRange()
+ if opts.Map32Bit {
+ allowedAR = allowedAR.Intersect(usermem.AddrRange{map32Start, map32End})
+ }
+
+ // Does the provided suggestion work?
+ if ar, ok := opts.Addr.ToRange(length); ok {
+ if allowedAR.IsSupersetOf(ar) {
+ if opts.Unmap {
+ return ar.Start, nil
+ }
+ // Check for the presence of an existing vma or guard page.
+ if vgap := mm.vmas.FindGap(ar.Start); vgap.Ok() && vgap.availableRange().IsSupersetOf(ar) {
+ return ar.Start, nil
+ }
+ }
+ }
+
+ // Fixed mappings accept only the requested address.
+ if opts.Fixed {
+ return 0, syserror.ENOMEM
+ }
+
+ // Prefer hugepage alignment if a hugepage or more is requested.
+ alignment := uint64(usermem.PageSize)
+ if length >= usermem.HugePageSize {
+ alignment = usermem.HugePageSize
+ }
+
+ if opts.Map32Bit {
+ return mm.findLowestAvailableLocked(length, alignment, allowedAR)
+ }
+ if mm.layout.DefaultDirection == arch.MmapBottomUp {
+ return mm.findLowestAvailableLocked(length, alignment, usermem.AddrRange{mm.layout.BottomUpBase, mm.layout.MaxAddr})
+ }
+ return mm.findHighestAvailableLocked(length, alignment, usermem.AddrRange{mm.layout.MinAddr, mm.layout.TopDownBase})
+}
+
+func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange {
+ return usermem.AddrRange{mm.layout.MinAddr, mm.layout.MaxAddr}
+}
+
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
+ for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextGap() {
+ if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
+ // Can we shift up to match the alignment?
+ if offset := uint64(gr.Start) % alignment; offset != 0 {
+ if uint64(gr.Length()) >= length+alignment-offset {
+ // Yes, we're aligned.
+ return gr.Start + usermem.Addr(alignment-offset), nil
+ }
+ }
+
+ // Either aligned perfectly, or can't align it.
+ return gr.Start, nil
+ }
+ }
+ return 0, syserror.ENOMEM
+}
+
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
+ for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevGap() {
+ if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
+ // Can we shift down to match the alignment?
+ start := gr.End - usermem.Addr(length)
+ if offset := uint64(start) % alignment; offset != 0 {
+ if gr.Start <= start-usermem.Addr(offset) {
+ // Yes, we're aligned.
+ return start - usermem.Addr(offset), nil
+ }
+ }
+
+ // Either aligned perfectly, or can't align it.
+ return start, nil
+ }
+ }
+ return 0, syserror.ENOMEM
+}
+
+// Preconditions: mm.mappingMu must be locked.
+func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 {
+ var total uint64
+ for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
+ if vseg.ValuePtr().mlockMode != memmap.MLockNone {
+ total += uint64(vseg.Range().Intersect(ar).Length())
+ }
+ }
+ return total
+}
+
+// getVMAsLocked ensures that vmas exist for all addresses in ar, and support
+// access of type (at, ignorePermissions). It returns:
+//
+// - An iterator to the vma containing ar.Start. If no vma contains ar.Start,
+// the iterator is unspecified.
+//
+// - An iterator to the gap after the last vma containing an address in ar. If
+// vmas exist for no addresses in ar, the iterator is to a gap that begins
+// before ar.Start.
+//
+// - An error that is non-nil if vmas exist for only a subset of ar.
+//
+// Preconditions: mm.mappingMu must be locked for reading; it may be
+// temporarily unlocked. ar.Length() != 0.
+func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ // Inline mm.vmas.LowerBoundSegment so that we have the preceding gap if
+ // !vbegin.Ok().
+ vbegin, vgap := mm.vmas.Find(ar.Start)
+ if !vbegin.Ok() {
+ vbegin = vgap.NextSegment()
+ // vseg.Ok() is checked before entering the following loop.
+ } else {
+ vgap = vbegin.PrevGap()
+ }
+
+ addr := ar.Start
+ vseg := vbegin
+ for vseg.Ok() {
+ // Loop invariants: vgap = vseg.PrevGap(); addr < vseg.End().
+ vma := vseg.ValuePtr()
+ if addr < vseg.Start() {
+ // TODO(jamieliu): Implement vma.growsDown here.
+ return vbegin, vgap, syserror.EFAULT
+ }
+
+ perms := vma.effectivePerms
+ if ignorePermissions {
+ perms = vma.maxPerms
+ }
+ if !perms.SupersetOf(at) {
+ return vbegin, vgap, syserror.EPERM
+ }
+
+ addr = vseg.End()
+ vgap = vseg.NextGap()
+ if addr >= ar.End {
+ return vbegin, vgap, nil
+ }
+ vseg = vgap.NextSegment()
+ }
+
+ // Ran out of vmas before ar.End.
+ return vbegin, vgap, syserror.EFAULT
+}
+
+// getVecVMAsLocked ensures that vmas exist for all addresses in ars, and
+// support access to type of (at, ignorePermissions). It returns the subset of
+// ars for which vmas exist. If this is not equal to ars, it returns a non-nil
+// error explaining why.
+//
+// Preconditions: mm.mappingMu must be locked for reading; it may be
+// temporarily unlocked.
+//
+// Postconditions: ars is not mutated.
+func (mm *MemoryManager) getVecVMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool) (usermem.AddrRangeSeq, error) {
+ for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() {
+ ar := arsit.Head()
+ if ar.Length() == 0 {
+ continue
+ }
+ if _, vend, err := mm.getVMAsLocked(ctx, ar, at, ignorePermissions); err != nil {
+ return truncatedAddrRangeSeq(ars, arsit, vend.Start()), err
+ }
+ }
+ return ars, nil
+}
+
+// vma extension will not shrink the number of unmapped bytes between the start
+// of a growsDown vma and the end of its predecessor non-growsDown vma below
+// guardBytes.
+//
+// guardBytes is equivalent to Linux's stack_guard_gap after upstream
+// 1be7107fbe18 "mm: larger stack guard gap, between vmas".
+const guardBytes = 256 * usermem.PageSize
+
+// unmapLocked unmaps all addresses in ar and returns the resulting gap in
+// mm.vmas.
+//
+// Preconditions: mm.mappingMu must be locked for writing. ar.Length() != 0.
+// ar must be page-aligned.
+func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ // AddressSpace mappings and pmas must be invalidated before
+ // mm.removeVMAsLocked() => memmap.Mappable.RemoveMapping().
+ mm.Invalidate(ar, memmap.InvalidateOpts{InvalidatePrivate: true})
+ return mm.removeVMAsLocked(ctx, ar)
+}
+
+// removeVMAsLocked removes vmas for addresses in ar and returns the resulting
+// gap in mm.vmas. It does not remove pmas or AddressSpace mappings; clients
+// must do so before calling removeVMAsLocked.
+//
+// Preconditions: mm.mappingMu must be locked for writing. ar.Length() != 0. ar
+// must be page-aligned.
+func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
+ if checkInvariants {
+ if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ }
+
+ vseg, vgap := mm.vmas.Find(ar.Start)
+ if vgap.Ok() {
+ vseg = vgap.NextSegment()
+ }
+ for vseg.Ok() && vseg.Start() < ar.End {
+ vseg = mm.vmas.Isolate(vseg, ar)
+ vmaAR := vseg.Range()
+ vma := vseg.ValuePtr()
+ if vma.mappable != nil {
+ vma.mappable.RemoveMapping(ctx, mm, vmaAR, vma.off, vma.canWriteMappableLocked())
+ }
+ if vma.id != nil {
+ vma.id.DecRef()
+ }
+ mm.usageAS -= uint64(vmaAR.Length())
+ if vma.isPrivateDataLocked() {
+ mm.dataAS -= uint64(vmaAR.Length())
+ }
+ if vma.mlockMode != memmap.MLockNone {
+ mm.lockedAS -= uint64(vmaAR.Length())
+ }
+ vgap = mm.vmas.Remove(vseg)
+ vseg = vgap.NextSegment()
+ }
+ return vgap
+}
+
+// canWriteMappableLocked returns true if it is possible for vma.mappable to be
+// written to via this vma, i.e. if it is possible that
+// vma.mappable.Translate(at.Write=true) may be called as a result of this vma.
+// This includes via I/O with usermem.IOOpts.IgnorePermissions = true, such as
+// PTRACE_POKEDATA.
+//
+// canWriteMappableLocked is equivalent to Linux's VM_SHARED.
+//
+// Preconditions: mm.mappingMu must be locked.
+func (vma *vma) canWriteMappableLocked() bool {
+ return !vma.private && vma.maxPerms.Write
+}
+
+// isPrivateDataLocked identify the data segments - private, writable, not stack
+//
+// Preconditions: mm.mappingMu must be locked.
+func (vma *vma) isPrivateDataLocked() bool {
+ return vma.realPerms.Write && vma.private && !vma.growsDown
+}
+
+// vmaSetFunctions implements segment.Functions for vmaSet.
+type vmaSetFunctions struct{}
+
+func (vmaSetFunctions) MinKey() usermem.Addr {
+ return 0
+}
+
+func (vmaSetFunctions) MaxKey() usermem.Addr {
+ return ^usermem.Addr(0)
+}
+
+func (vmaSetFunctions) ClearValue(vma *vma) {
+ vma.mappable = nil
+ vma.id = nil
+ vma.hint = ""
+}
+
+func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRange, vma2 vma) (vma, bool) {
+ if vma1.mappable != vma2.mappable ||
+ (vma1.mappable != nil && vma1.off+uint64(ar1.Length()) != vma2.off) ||
+ vma1.realPerms != vma2.realPerms ||
+ vma1.maxPerms != vma2.maxPerms ||
+ vma1.private != vma2.private ||
+ vma1.growsDown != vma2.growsDown ||
+ vma1.mlockMode != vma2.mlockMode ||
+ vma1.id != vma2.id ||
+ vma1.hint != vma2.hint {
+ return vma{}, false
+ }
+
+ if vma2.id != nil {
+ vma2.id.DecRef()
+ }
+ return vma1, true
+}
+
+func (vmaSetFunctions) Split(ar usermem.AddrRange, v vma, split usermem.Addr) (vma, vma) {
+ v2 := v
+ if v2.mappable != nil {
+ v2.off += uint64(split - ar.Start)
+ }
+ if v2.id != nil {
+ v2.id.IncRef()
+ }
+ return v, v2
+}
+
+// Preconditions: vseg.ValuePtr().mappable != nil. vseg.Range().Contains(addr).
+func (vseg vmaIterator) mappableOffsetAt(addr usermem.Addr) uint64 {
+ if checkInvariants {
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if vseg.ValuePtr().mappable == nil {
+ panic("Mappable offset is meaningless for anonymous vma")
+ }
+ if !vseg.Range().Contains(addr) {
+ panic(fmt.Sprintf("addr %v out of bounds %v", addr, vseg.Range()))
+ }
+ }
+
+ vma := vseg.ValuePtr()
+ vstart := vseg.Start()
+ return vma.off + uint64(addr-vstart)
+}
+
+// Preconditions: vseg.ValuePtr().mappable != nil.
+func (vseg vmaIterator) mappableRange() memmap.MappableRange {
+ return vseg.mappableRangeOf(vseg.Range())
+}
+
+// Preconditions: vseg.ValuePtr().mappable != nil.
+// vseg.Range().IsSupersetOf(ar). ar.Length() != 0.
+func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRange {
+ if checkInvariants {
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if vseg.ValuePtr().mappable == nil {
+ panic("MappableRange is meaningless for anonymous vma")
+ }
+ if !ar.WellFormed() || ar.Length() <= 0 {
+ panic(fmt.Sprintf("invalid ar: %v", ar))
+ }
+ if !vseg.Range().IsSupersetOf(ar) {
+ panic(fmt.Sprintf("ar %v out of bounds %v", ar, vseg.Range()))
+ }
+ }
+
+ vma := vseg.ValuePtr()
+ vstart := vseg.Start()
+ return memmap.MappableRange{vma.off + uint64(ar.Start-vstart), vma.off + uint64(ar.End-vstart)}
+}
+
+// Preconditions: vseg.ValuePtr().mappable != nil.
+// vseg.mappableRange().IsSupersetOf(mr). mr.Length() != 0.
+func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange {
+ if checkInvariants {
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if vseg.ValuePtr().mappable == nil {
+ panic("MappableRange is meaningless for anonymous vma")
+ }
+ if !mr.WellFormed() || mr.Length() <= 0 {
+ panic(fmt.Sprintf("invalid mr: %v", mr))
+ }
+ if !vseg.mappableRange().IsSupersetOf(mr) {
+ panic(fmt.Sprintf("mr %v out of bounds %v", mr, vseg.mappableRange()))
+ }
+ }
+
+ vma := vseg.ValuePtr()
+ vstart := vseg.Start()
+ return usermem.AddrRange{vstart + usermem.Addr(mr.Start-vma.off), vstart + usermem.Addr(mr.End-vma.off)}
+}
+
+// seekNextLowerBound returns mm.vmas.LowerBoundSegment(addr), but does so by
+// scanning linearly forward from vseg.
+//
+// Preconditions: mm.mappingMu must be locked. addr >= vseg.Start().
+func (vseg vmaIterator) seekNextLowerBound(addr usermem.Addr) vmaIterator {
+ if checkInvariants {
+ if !vseg.Ok() {
+ panic("terminal vma iterator")
+ }
+ if addr < vseg.Start() {
+ panic(fmt.Sprintf("can't seek forward to %#x from %#x", addr, vseg.Start()))
+ }
+ }
+ for vseg.Ok() && addr >= vseg.End() {
+ vseg = vseg.NextSegment()
+ }
+ return vseg
+}
+
+// availableRange returns the subset of vgap.Range() in which new vmas may be
+// created without MMapOpts.Unmap == true.
+func (vgap vmaGapIterator) availableRange() usermem.AddrRange {
+ ar := vgap.Range()
+ next := vgap.NextSegment()
+ if !next.Ok() || !next.ValuePtr().growsDown {
+ return ar
+ }
+ // Exclude guard pages.
+ if ar.Length() < guardBytes {
+ return usermem.AddrRange{ar.Start, ar.Start}
+ }
+ ar.End -= guardBytes
+ return ar
+}
diff --git a/pkg/sentry/mm/vma_set.go b/pkg/sentry/mm/vma_set.go
new file mode 100755
index 000000000..c042fe606
--- /dev/null
+++ b/pkg/sentry/mm/vma_set.go
@@ -0,0 +1,1274 @@
+package mm
+
+import (
+ __generics_imported0 "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ vmaminDegree = 8
+
+ vmamaxDegree = 2 * vmaminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type vmaSet struct {
+ root vmanode `state:".(*vmaSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *vmaSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *vmaSet) IsEmptyRange(r __generics_imported0.AddrRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *vmaSet) Span() __generics_imported0.Addr {
+ var sz __generics_imported0.Addr
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *vmaSet) SpanRange(r __generics_imported0.AddrRange) __generics_imported0.Addr {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz __generics_imported0.Addr
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *vmaSet) FirstSegment() vmaIterator {
+ if s.root.nrSegments == 0 {
+ return vmaIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *vmaSet) LastSegment() vmaIterator {
+ if s.root.nrSegments == 0 {
+ return vmaIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *vmaSet) FirstGap() vmaGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return vmaGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *vmaSet) LastGap() vmaGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return vmaGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *vmaSet) Find(key __generics_imported0.Addr) (vmaIterator, vmaGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return vmaIterator{n, i}, vmaGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return vmaIterator{}, vmaGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *vmaSet) FindSegment(key __generics_imported0.Addr) vmaIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *vmaSet) LowerBoundSegment(min __generics_imported0.Addr) vmaIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *vmaSet) UpperBoundSegment(max __generics_imported0.Addr) vmaIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *vmaSet) FindGap(key __generics_imported0.Addr) vmaGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *vmaSet) LowerBoundGap(min __generics_imported0.Addr) vmaGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *vmaSet) UpperBoundGap(max __generics_imported0.Addr) vmaGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *vmaSet) Add(r __generics_imported0.AddrRange, val vma) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *vmaSet) AddWithoutMerging(r __generics_imported0.AddrRange, val vma) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *vmaSet) Insert(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (vmaSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (vmaSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (vmaSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *vmaSet) InsertWithoutMerging(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *vmaSet) InsertWithoutMergingUnchecked(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return vmaIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *vmaSet) Remove(seg vmaIterator) vmaGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ vmaSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(vmaGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *vmaSet) RemoveAll() {
+ s.root = vmanode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *vmaSet) RemoveRange(r __generics_imported0.AddrRange) vmaGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *vmaSet) Merge(first, second vmaIterator) vmaIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *vmaSet) MergeUnchecked(first, second vmaIterator) vmaIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (vmaSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return vmaIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *vmaSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *vmaSet) MergeRange(r __generics_imported0.AddrRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *vmaSet) MergeAdjacent(r __generics_imported0.AddrRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *vmaSet) Split(seg vmaIterator, split __generics_imported0.Addr) (vmaIterator, vmaIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *vmaSet) SplitUnchecked(seg vmaIterator, split __generics_imported0.Addr) (vmaIterator, vmaIterator) {
+ val1, val2 := (vmaSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.AddrRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *vmaSet) SplitAt(split __generics_imported0.Addr) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *vmaSet) Isolate(seg vmaIterator, r __generics_imported0.AddrRange) vmaIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *vmaSet) ApplyContiguous(r __generics_imported0.AddrRange, fn func(seg vmaIterator)) vmaGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return vmaGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return vmaGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type vmanode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *vmanode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [vmamaxDegree - 1]__generics_imported0.AddrRange
+ values [vmamaxDegree - 1]vma
+ children [vmamaxDegree]*vmanode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *vmanode) firstSegment() vmaIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return vmaIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *vmanode) lastSegment() vmaIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return vmaIterator{n, n.nrSegments - 1}
+}
+
+func (n *vmanode) prevSibling() *vmanode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *vmanode) nextSibling() *vmanode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *vmanode) rebalanceBeforeInsert(gap vmaGapIterator) vmaGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < vmamaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &vmanode{
+ nrSegments: vmaminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &vmanode{
+ nrSegments: vmaminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:vmaminDegree-1], n.keys[:vmaminDegree-1])
+ copy(left.values[:vmaminDegree-1], n.values[:vmaminDegree-1])
+ copy(right.keys[:vmaminDegree-1], n.keys[vmaminDegree:])
+ copy(right.values[:vmaminDegree-1], n.values[vmaminDegree:])
+ n.keys[0], n.values[0] = n.keys[vmaminDegree-1], n.values[vmaminDegree-1]
+ vmazeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:vmaminDegree], n.children[:vmaminDegree])
+ copy(right.children[:vmaminDegree], n.children[vmaminDegree:])
+ vmazeroNodeSlice(n.children[2:])
+ for i := 0; i < vmaminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < vmaminDegree {
+ return vmaGapIterator{left, gap.index}
+ }
+ return vmaGapIterator{right, gap.index - vmaminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[vmaminDegree-1], n.values[vmaminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &vmanode{
+ nrSegments: vmaminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:vmaminDegree-1], n.keys[vmaminDegree:])
+ copy(sibling.values[:vmaminDegree-1], n.values[vmaminDegree:])
+ vmazeroValueSlice(n.values[vmaminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:vmaminDegree], n.children[vmaminDegree:])
+ vmazeroNodeSlice(n.children[vmaminDegree:])
+ for i := 0; i < vmaminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = vmaminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < vmaminDegree {
+ return gap
+ }
+ return vmaGapIterator{sibling, gap.index - vmaminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *vmanode) rebalanceAfterRemove(gap vmaGapIterator) vmaGapIterator {
+ for {
+ if n.nrSegments >= vmaminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= vmaminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ vmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return vmaGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return vmaGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= vmaminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ vmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return vmaGapIterator{n, n.nrSegments}
+ }
+ return vmaGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return vmaGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return vmaGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *vmanode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = vmaGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ vmaSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type vmaIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *vmanode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg vmaIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg vmaIterator) Range() __generics_imported0.AddrRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg vmaIterator) Start() __generics_imported0.Addr {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg vmaIterator) End() __generics_imported0.Addr {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg vmaIterator) SetRangeUnchecked(r __generics_imported0.AddrRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg vmaIterator) SetRange(r __generics_imported0.AddrRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg vmaIterator) SetStartUnchecked(start __generics_imported0.Addr) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg vmaIterator) SetStart(start __generics_imported0.Addr) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg vmaIterator) SetEndUnchecked(end __generics_imported0.Addr) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg vmaIterator) SetEnd(end __generics_imported0.Addr) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg vmaIterator) Value() vma {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg vmaIterator) ValuePtr() *vma {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg vmaIterator) SetValue(val vma) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg vmaIterator) PrevSegment() vmaIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return vmaIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return vmaIterator{}
+ }
+ return vmasegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg vmaIterator) NextSegment() vmaIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return vmaIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return vmaIterator{}
+ }
+ return vmasegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg vmaIterator) PrevGap() vmaGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return vmaGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg vmaIterator) NextGap() vmaGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return vmaGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg vmaIterator) PrevNonEmpty() (vmaIterator, vmaGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return vmaIterator{}, gap
+ }
+ return gap.PrevSegment(), vmaGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg vmaIterator) NextNonEmpty() (vmaIterator, vmaGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return vmaIterator{}, gap
+ }
+ return gap.NextSegment(), vmaGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type vmaGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *vmanode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap vmaGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap vmaGapIterator) Range() __generics_imported0.AddrRange {
+ return __generics_imported0.AddrRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap vmaGapIterator) Start() __generics_imported0.Addr {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return vmaSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap vmaGapIterator) End() __generics_imported0.Addr {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return vmaSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap vmaGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap vmaGapIterator) PrevSegment() vmaIterator {
+ return vmasegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap vmaGapIterator) NextSegment() vmaIterator {
+ return vmasegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap vmaGapIterator) PrevGap() vmaGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return vmaGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap vmaGapIterator) NextGap() vmaGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return vmaGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func vmasegmentBeforePosition(n *vmanode, i int) vmaIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return vmaIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return vmaIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func vmasegmentAfterPosition(n *vmanode, i int) vmaIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return vmaIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return vmaIterator{n, i}
+}
+
+func vmazeroValueSlice(slice []vma) {
+
+ for i := range slice {
+ vmaSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func vmazeroNodeSlice(slice []*vmanode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *vmaSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *vmanode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *vmanode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type vmaSegmentDataSlices struct {
+ Start []__generics_imported0.Addr
+ End []__generics_imported0.Addr
+ Values []vma
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *vmaSet) ExportSortedSlices() *vmaSegmentDataSlices {
+ var sds vmaSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *vmaSet) ImportSortedSlices(sds *vmaSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := __generics_imported0.AddrRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *vmaSet) saveRoot() *vmaSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *vmaSet) loadRoot(sds *vmaSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/pgalloc/context.go b/pkg/sentry/pgalloc/context.go
new file mode 100644
index 000000000..cb9809b1f
--- /dev/null
+++ b/pkg/sentry/pgalloc/context.go
@@ -0,0 +1,48 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pgalloc
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// contextID is this package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxMemoryFile is a Context.Value key for a MemoryFile.
+ CtxMemoryFile contextID = iota
+
+ // CtxMemoryFileProvider is a Context.Value key for a MemoryFileProvider.
+ CtxMemoryFileProvider
+)
+
+// MemoryFileFromContext returns the MemoryFile used by ctx, or nil if no such
+// MemoryFile exists.
+func MemoryFileFromContext(ctx context.Context) *MemoryFile {
+ if v := ctx.Value(CtxMemoryFile); v != nil {
+ return v.(*MemoryFile)
+ }
+ return nil
+}
+
+// MemoryFileProviderFromContext returns the MemoryFileProvider used by ctx, or nil if no such
+// MemoryFileProvider exists.
+func MemoryFileProviderFromContext(ctx context.Context) MemoryFileProvider {
+ if v := ctx.Value(CtxMemoryFileProvider); v != nil {
+ return v.(MemoryFileProvider)
+ }
+ return nil
+}
diff --git a/pkg/sentry/pgalloc/evictable_range.go b/pkg/sentry/pgalloc/evictable_range.go
new file mode 100755
index 000000000..10ce2ff44
--- /dev/null
+++ b/pkg/sentry/pgalloc/evictable_range.go
@@ -0,0 +1,62 @@
+package pgalloc
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type EvictableRange struct {
+ // Start is the inclusive start of the range.
+ Start uint64
+
+ // End is the exclusive end of the range.
+ End uint64
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r EvictableRange) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r EvictableRange) Length() uint64 {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r EvictableRange) Contains(x uint64) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r EvictableRange) Overlaps(r2 EvictableRange) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r EvictableRange) IsSupersetOf(r2 EvictableRange) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+func (r EvictableRange) Intersect(r2 EvictableRange) EvictableRange {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+func (r EvictableRange) CanSplitAt(x uint64) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/sentry/pgalloc/evictable_range_set.go b/pkg/sentry/pgalloc/evictable_range_set.go
new file mode 100755
index 000000000..a4dcb1663
--- /dev/null
+++ b/pkg/sentry/pgalloc/evictable_range_set.go
@@ -0,0 +1,1270 @@
+package pgalloc
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ evictableRangeminDegree = 3
+
+ evictableRangemaxDegree = 2 * evictableRangeminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type evictableRangeSet struct {
+ root evictableRangenode `state:".(*evictableRangeSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *evictableRangeSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *evictableRangeSet) IsEmptyRange(r EvictableRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *evictableRangeSet) Span() uint64 {
+ var sz uint64
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *evictableRangeSet) SpanRange(r EvictableRange) uint64 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint64
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *evictableRangeSet) FirstSegment() evictableRangeIterator {
+ if s.root.nrSegments == 0 {
+ return evictableRangeIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *evictableRangeSet) LastSegment() evictableRangeIterator {
+ if s.root.nrSegments == 0 {
+ return evictableRangeIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *evictableRangeSet) FirstGap() evictableRangeGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return evictableRangeGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *evictableRangeSet) LastGap() evictableRangeGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return evictableRangeGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *evictableRangeSet) Find(key uint64) (evictableRangeIterator, evictableRangeGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return evictableRangeIterator{n, i}, evictableRangeGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return evictableRangeIterator{}, evictableRangeGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *evictableRangeSet) FindSegment(key uint64) evictableRangeIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *evictableRangeSet) LowerBoundSegment(min uint64) evictableRangeIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *evictableRangeSet) UpperBoundSegment(max uint64) evictableRangeIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *evictableRangeSet) FindGap(key uint64) evictableRangeGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *evictableRangeSet) LowerBoundGap(min uint64) evictableRangeGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *evictableRangeSet) UpperBoundGap(max uint64) evictableRangeGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *evictableRangeSet) Add(r EvictableRange, val evictableRangeSetValue) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *evictableRangeSet) AddWithoutMerging(r EvictableRange, val evictableRangeSetValue) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *evictableRangeSet) Insert(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (evictableRangeSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (evictableRangeSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (evictableRangeSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *evictableRangeSet) InsertWithoutMerging(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *evictableRangeSet) InsertWithoutMergingUnchecked(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return evictableRangeIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *evictableRangeSet) Remove(seg evictableRangeIterator) evictableRangeGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ evictableRangeSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(evictableRangeGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *evictableRangeSet) RemoveAll() {
+ s.root = evictableRangenode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *evictableRangeSet) RemoveRange(r EvictableRange) evictableRangeGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *evictableRangeSet) Merge(first, second evictableRangeIterator) evictableRangeIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *evictableRangeSet) MergeUnchecked(first, second evictableRangeIterator) evictableRangeIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (evictableRangeSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return evictableRangeIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *evictableRangeSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *evictableRangeSet) MergeRange(r EvictableRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *evictableRangeSet) MergeAdjacent(r EvictableRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *evictableRangeSet) Split(seg evictableRangeIterator, split uint64) (evictableRangeIterator, evictableRangeIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *evictableRangeSet) SplitUnchecked(seg evictableRangeIterator, split uint64) (evictableRangeIterator, evictableRangeIterator) {
+ val1, val2 := (evictableRangeSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), EvictableRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *evictableRangeSet) SplitAt(split uint64) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *evictableRangeSet) Isolate(seg evictableRangeIterator, r EvictableRange) evictableRangeIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *evictableRangeSet) ApplyContiguous(r EvictableRange, fn func(seg evictableRangeIterator)) evictableRangeGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return evictableRangeGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return evictableRangeGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type evictableRangenode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *evictableRangenode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [evictableRangemaxDegree - 1]EvictableRange
+ values [evictableRangemaxDegree - 1]evictableRangeSetValue
+ children [evictableRangemaxDegree]*evictableRangenode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *evictableRangenode) firstSegment() evictableRangeIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return evictableRangeIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *evictableRangenode) lastSegment() evictableRangeIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return evictableRangeIterator{n, n.nrSegments - 1}
+}
+
+func (n *evictableRangenode) prevSibling() *evictableRangenode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *evictableRangenode) nextSibling() *evictableRangenode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *evictableRangenode) rebalanceBeforeInsert(gap evictableRangeGapIterator) evictableRangeGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < evictableRangemaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &evictableRangenode{
+ nrSegments: evictableRangeminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &evictableRangenode{
+ nrSegments: evictableRangeminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:evictableRangeminDegree-1], n.keys[:evictableRangeminDegree-1])
+ copy(left.values[:evictableRangeminDegree-1], n.values[:evictableRangeminDegree-1])
+ copy(right.keys[:evictableRangeminDegree-1], n.keys[evictableRangeminDegree:])
+ copy(right.values[:evictableRangeminDegree-1], n.values[evictableRangeminDegree:])
+ n.keys[0], n.values[0] = n.keys[evictableRangeminDegree-1], n.values[evictableRangeminDegree-1]
+ evictableRangezeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:evictableRangeminDegree], n.children[:evictableRangeminDegree])
+ copy(right.children[:evictableRangeminDegree], n.children[evictableRangeminDegree:])
+ evictableRangezeroNodeSlice(n.children[2:])
+ for i := 0; i < evictableRangeminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < evictableRangeminDegree {
+ return evictableRangeGapIterator{left, gap.index}
+ }
+ return evictableRangeGapIterator{right, gap.index - evictableRangeminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[evictableRangeminDegree-1], n.values[evictableRangeminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &evictableRangenode{
+ nrSegments: evictableRangeminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:evictableRangeminDegree-1], n.keys[evictableRangeminDegree:])
+ copy(sibling.values[:evictableRangeminDegree-1], n.values[evictableRangeminDegree:])
+ evictableRangezeroValueSlice(n.values[evictableRangeminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:evictableRangeminDegree], n.children[evictableRangeminDegree:])
+ evictableRangezeroNodeSlice(n.children[evictableRangeminDegree:])
+ for i := 0; i < evictableRangeminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = evictableRangeminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < evictableRangeminDegree {
+ return gap
+ }
+ return evictableRangeGapIterator{sibling, gap.index - evictableRangeminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *evictableRangenode) rebalanceAfterRemove(gap evictableRangeGapIterator) evictableRangeGapIterator {
+ for {
+ if n.nrSegments >= evictableRangeminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= evictableRangeminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ evictableRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return evictableRangeGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return evictableRangeGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= evictableRangeminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ evictableRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return evictableRangeGapIterator{n, n.nrSegments}
+ }
+ return evictableRangeGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return evictableRangeGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return evictableRangeGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *evictableRangenode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = evictableRangeGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ evictableRangeSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type evictableRangeIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *evictableRangenode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg evictableRangeIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg evictableRangeIterator) Range() EvictableRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg evictableRangeIterator) Start() uint64 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg evictableRangeIterator) End() uint64 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg evictableRangeIterator) SetRangeUnchecked(r EvictableRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg evictableRangeIterator) SetRange(r EvictableRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg evictableRangeIterator) SetStartUnchecked(start uint64) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg evictableRangeIterator) SetStart(start uint64) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg evictableRangeIterator) SetEndUnchecked(end uint64) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg evictableRangeIterator) SetEnd(end uint64) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg evictableRangeIterator) Value() evictableRangeSetValue {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg evictableRangeIterator) ValuePtr() *evictableRangeSetValue {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg evictableRangeIterator) SetValue(val evictableRangeSetValue) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg evictableRangeIterator) PrevSegment() evictableRangeIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return evictableRangeIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return evictableRangeIterator{}
+ }
+ return evictableRangesegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg evictableRangeIterator) NextSegment() evictableRangeIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return evictableRangeIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return evictableRangeIterator{}
+ }
+ return evictableRangesegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg evictableRangeIterator) PrevGap() evictableRangeGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return evictableRangeGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg evictableRangeIterator) NextGap() evictableRangeGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return evictableRangeGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg evictableRangeIterator) PrevNonEmpty() (evictableRangeIterator, evictableRangeGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return evictableRangeIterator{}, gap
+ }
+ return gap.PrevSegment(), evictableRangeGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg evictableRangeIterator) NextNonEmpty() (evictableRangeIterator, evictableRangeGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return evictableRangeIterator{}, gap
+ }
+ return gap.NextSegment(), evictableRangeGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type evictableRangeGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *evictableRangenode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap evictableRangeGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap evictableRangeGapIterator) Range() EvictableRange {
+ return EvictableRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap evictableRangeGapIterator) Start() uint64 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return evictableRangeSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap evictableRangeGapIterator) End() uint64 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return evictableRangeSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap evictableRangeGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap evictableRangeGapIterator) PrevSegment() evictableRangeIterator {
+ return evictableRangesegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap evictableRangeGapIterator) NextSegment() evictableRangeIterator {
+ return evictableRangesegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap evictableRangeGapIterator) PrevGap() evictableRangeGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return evictableRangeGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap evictableRangeGapIterator) NextGap() evictableRangeGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return evictableRangeGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func evictableRangesegmentBeforePosition(n *evictableRangenode, i int) evictableRangeIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return evictableRangeIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return evictableRangeIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func evictableRangesegmentAfterPosition(n *evictableRangenode, i int) evictableRangeIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return evictableRangeIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return evictableRangeIterator{n, i}
+}
+
+func evictableRangezeroValueSlice(slice []evictableRangeSetValue) {
+
+ for i := range slice {
+ evictableRangeSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func evictableRangezeroNodeSlice(slice []*evictableRangenode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *evictableRangeSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *evictableRangenode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *evictableRangenode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type evictableRangeSegmentDataSlices struct {
+ Start []uint64
+ End []uint64
+ Values []evictableRangeSetValue
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *evictableRangeSet) ExportSortedSlices() *evictableRangeSegmentDataSlices {
+ var sds evictableRangeSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *evictableRangeSet) ImportSortedSlices(sds *evictableRangeSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := EvictableRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *evictableRangeSet) saveRoot() *evictableRangeSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *evictableRangeSet) loadRoot(sds *evictableRangeSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
new file mode 100644
index 000000000..2b9924ad7
--- /dev/null
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -0,0 +1,1187 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pgalloc contains the page allocator subsystem, which manages memory
+// that may be mapped into application address spaces.
+//
+// Lock order:
+//
+// pgalloc.MemoryFile.mu
+// pgalloc.MemoryFile.mappingsMu
+package pgalloc
+
+import (
+ "fmt"
+ "math"
+ "os"
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// MemoryFile is a platform.File whose pages may be allocated to arbitrary
+// users.
+type MemoryFile struct {
+ // opts holds options passed to NewMemoryFile. opts is immutable.
+ opts MemoryFileOpts
+
+ // MemoryFile owns a single backing file, which is modeled as follows:
+ //
+ // Each page in the file can be committed or uncommitted. A page is
+ // committed if the host kernel is spending resources to store its contents
+ // and uncommitted otherwise. This definition includes pages that the host
+ // kernel has swapped; this is intentional, to ensure that accounting does
+ // not change even if host kernel swapping behavior changes, and that
+ // memory used by pseudo-swap mechanisms like zswap is still accounted.
+ //
+ // The initial contents of uncommitted pages are implicitly zero bytes. A
+ // read or write to the contents of an uncommitted page causes it to be
+ // committed. This is the only event that can cause a uncommitted page to
+ // be committed.
+ //
+ // fallocate(FALLOC_FL_PUNCH_HOLE) (MemoryFile.Decommit) causes committed
+ // pages to be uncommitted. This is the only event that can cause a
+ // committed page to be uncommitted.
+ //
+ // Memory accounting is based on identifying the set of committed pages.
+ // Since we do not have direct access to the MMU, tracking reads and writes
+ // to uncommitted pages to detect commitment would introduce additional
+ // page faults, which would be prohibitively expensive. Instead, we query
+ // the host kernel to determine which pages are committed.
+
+ // file is the backing file. The file pointer is immutable.
+ file *os.File
+
+ mu sync.Mutex
+
+ // usage maps each page in the file to metadata for that page. Pages for
+ // which no segment exists in usage are both unallocated (not in use) and
+ // uncommitted.
+ //
+ // Since usage stores usageInfo objects by value, clients should usually
+ // use usageIterator.ValuePtr() instead of usageIterator.Value() to get a
+ // pointer to the usageInfo rather than a copy.
+ //
+ // usage must be kept maximally merged (that is, there should never be two
+ // adjacent segments with the same values). At least markReclaimed depends
+ // on this property.
+ //
+ // usage is protected by mu.
+ usage usageSet
+
+ // The UpdateUsage function scans all segments with knownCommitted set
+ // to false, sees which pages are committed and creates corresponding
+ // segments with knownCommitted set to true.
+ //
+ // In order to avoid unnecessary scans, usageExpected tracks the total
+ // file blocks expected. This is used to elide the scan when this
+ // matches the underlying file blocks.
+ //
+ // To track swapped pages, usageSwapped tracks the discrepency between
+ // what is observed in core and what is reported by the file. When
+ // usageSwapped is non-zero, a sweep will be performed at least every
+ // second. The start of the last sweep is recorded in usageLast.
+ //
+ // All usage attributes are all protected by mu.
+ usageExpected uint64
+ usageSwapped uint64
+ usageLast time.Time
+
+ // minUnallocatedPage is the minimum page that may be unallocated.
+ // i.e., there are no unallocated pages below minUnallocatedPage.
+ //
+ // minUnallocatedPage is protected by mu.
+ minUnallocatedPage uint64
+
+ // fileSize is the size of the backing memory file in bytes. fileSize is
+ // always a power-of-two multiple of chunkSize.
+ //
+ // fileSize is protected by mu.
+ fileSize int64
+
+ // Pages from the backing file are mapped into the local address space on
+ // the granularity of large pieces called chunks. mappings is a []uintptr
+ // that stores, for each chunk, the start address of a mapping of that
+ // chunk in the current process' address space, or 0 if no such mapping
+ // exists. Once a chunk is mapped, it is never remapped or unmapped until
+ // the MemoryFile is destroyed.
+ //
+ // Mutating the mappings slice or its contents requires both holding
+ // mappingsMu and using atomic memory operations. (The slice is mutated
+ // whenever the file is expanded. Per the above, the only permitted
+ // mutation of the slice's contents is the assignment of a mapping to a
+ // chunk that was previously unmapped.) Reading the slice or its contents
+ // only requires *either* holding mappingsMu or using atomic memory
+ // operations. This allows MemoryFile.MapInternal to avoid locking in the
+ // common case where chunk mappings already exist.
+ mappingsMu sync.Mutex
+ mappings atomic.Value
+
+ // destroyed is set by Destroy to instruct the reclaimer goroutine to
+ // release resources and exit. destroyed is protected by mu.
+ destroyed bool
+
+ // reclaimable is true if usage may contain reclaimable pages. reclaimable
+ // is protected by mu.
+ reclaimable bool
+
+ // minReclaimablePage is the minimum page that may be reclaimable.
+ // i.e., all reclaimable pages are >= minReclaimablePage.
+ //
+ // minReclaimablePage is protected by mu.
+ minReclaimablePage uint64
+
+ // reclaimCond is signaled (with mu locked) when reclaimable or destroyed
+ // transitions from false to true.
+ reclaimCond sync.Cond
+
+ // evictable maps EvictableMemoryUsers to eviction state.
+ //
+ // evictable is protected by mu.
+ evictable map[EvictableMemoryUser]*evictableMemoryUserInfo
+
+ // evictionWG counts the number of goroutines currently performing evictions.
+ evictionWG sync.WaitGroup
+}
+
+// MemoryFileOpts provides options to NewMemoryFile.
+type MemoryFileOpts struct {
+ // DelayedEviction controls the extent to which the MemoryFile may delay
+ // eviction of evictable allocations.
+ DelayedEviction DelayedEvictionType
+}
+
+// DelayedEvictionType is the type of MemoryFileOpts.DelayedEviction.
+type DelayedEvictionType int
+
+const (
+ // DelayedEvictionDefault has unspecified behavior.
+ DelayedEvictionDefault DelayedEvictionType = iota
+
+ // DelayedEvictionDisabled requires that evictable allocations are evicted
+ // as soon as possible.
+ DelayedEvictionDisabled
+
+ // DelayedEvictionEnabled requests that the MemoryFile delay eviction of
+ // evictable allocations until doing so is considered necessary to avoid
+ // performance degradation due to host memory pressure, or OOM kills.
+ //
+ // As of this writing, DelayedEvictionEnabled delays evictions until the
+ // reclaimer goroutine is out of work (pages to reclaim), then evicts all
+ // pending evictable allocations immediately.
+ DelayedEvictionEnabled
+
+ // DelayedEvictionManual requires that evictable allocations are only
+ // evicted when MemoryFile.StartEvictions() is called. This is extremely
+ // dangerous outside of tests.
+ DelayedEvictionManual
+)
+
+// usageInfo tracks usage information.
+//
+// +stateify savable
+type usageInfo struct {
+ // kind is the usage kind.
+ kind usage.MemoryKind
+
+ // knownCommitted is true if the tracked region is definitely committed.
+ // (If it is false, the tracked region may or may not be committed.)
+ knownCommitted bool
+
+ refs uint64
+}
+
+// An EvictableMemoryUser represents a user of MemoryFile-allocated memory that
+// may be asked to deallocate that memory in the presence of memory pressure.
+type EvictableMemoryUser interface {
+ // Evict requests that the EvictableMemoryUser deallocate memory used by
+ // er, which was registered as evictable by a previous call to
+ // MemoryFile.MarkEvictable.
+ //
+ // Evict is not required to deallocate memory. In particular, since pgalloc
+ // must call Evict without holding locks to avoid circular lock ordering,
+ // it is possible that the passed range has already been marked as
+ // unevictable by a racing call to MemoryFile.MarkUnevictable.
+ // Implementations of EvictableMemoryUser must detect such races and handle
+ // them by making Evict have no effect on unevictable ranges.
+ //
+ // After a call to Evict, the MemoryFile will consider the evicted range
+ // unevictable (i.e. it will not call Evict on the same range again) until
+ // informed otherwise by a subsequent call to MarkEvictable.
+ Evict(ctx context.Context, er EvictableRange)
+}
+
+// An EvictableRange represents a range of uint64 offsets in an
+// EvictableMemoryUser.
+//
+// In practice, most EvictableMemoryUsers will probably be implementations of
+// memmap.Mappable, and EvictableRange therefore corresponds to
+// memmap.MappableRange. However, this package cannot depend on the memmap
+// package, since doing so would create a circular dependency.
+//
+// type EvictableRange <generated using go_generics>
+
+// evictableMemoryUserInfo is the value type of MemoryFile.evictable.
+type evictableMemoryUserInfo struct {
+ // ranges tracks all evictable ranges for the given user.
+ ranges evictableRangeSet
+
+ // If evicting is true, there is a goroutine currently evicting all
+ // evictable ranges for this user.
+ evicting bool
+}
+
+const (
+ chunkShift = 24
+ chunkSize = 1 << chunkShift // 16 MB
+ chunkMask = chunkSize - 1
+
+ initialSize = chunkSize
+
+ // maxPage is the highest 64-bit page.
+ maxPage = math.MaxUint64 &^ (usermem.PageSize - 1)
+)
+
+// NewMemoryFile creates a MemoryFile backed by the given file. If
+// NewMemoryFile succeeds, ownership of file is transferred to the returned
+// MemoryFile.
+func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
+ switch opts.DelayedEviction {
+ case DelayedEvictionDefault:
+ opts.DelayedEviction = DelayedEvictionEnabled
+ case DelayedEvictionDisabled, DelayedEvictionEnabled, DelayedEvictionManual:
+ default:
+ return nil, fmt.Errorf("invalid MemoryFileOpts.DelayedEviction: %v", opts.DelayedEviction)
+ }
+
+ // Truncate the file to 0 bytes first to ensure that it's empty.
+ if err := file.Truncate(0); err != nil {
+ return nil, err
+ }
+ if err := file.Truncate(initialSize); err != nil {
+ return nil, err
+ }
+ f := &MemoryFile{
+ opts: opts,
+ fileSize: initialSize,
+ file: file,
+ // No pages are reclaimable. DecRef will always be able to
+ // decrease minReclaimablePage from this point.
+ minReclaimablePage: maxPage,
+ evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
+ }
+ f.mappings.Store(make([]uintptr, initialSize/chunkSize))
+ f.reclaimCond.L = &f.mu
+ go f.runReclaim() // S/R-SAFE: f.mu
+
+ // The Linux kernel contains an optional feature called "Integrity
+ // Measurement Architecture" (IMA). If IMA is enabled, it will checksum
+ // binaries the first time they are mapped PROT_EXEC. This is bad news for
+ // executable pages mapped from our backing file, which can grow to
+ // terabytes in (sparse) size. If IMA attempts to checksum a file that
+ // large, it will allocate all of the sparse pages and quickly exhaust all
+ // memory.
+ //
+ // Work around IMA by immediately creating a temporary PROT_EXEC mapping,
+ // while the backing file is still small. IMA will ignore any future
+ // mappings.
+ m, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ 0,
+ usermem.PageSize,
+ syscall.PROT_EXEC,
+ syscall.MAP_SHARED,
+ file.Fd(),
+ 0)
+ if errno != 0 {
+ // This isn't fatal (IMA may not even be in use). Log the error, but
+ // don't return it.
+ log.Warningf("Failed to pre-map MemoryFile PROT_EXEC: %v", errno)
+ } else {
+ if _, _, errno := syscall.Syscall(
+ syscall.SYS_MUNMAP,
+ m,
+ usermem.PageSize,
+ 0); errno != 0 {
+ panic(fmt.Sprintf("failed to unmap PROT_EXEC MemoryFile mapping: %v", errno))
+ }
+ }
+
+ return f, nil
+}
+
+// Destroy releases all resources used by f.
+//
+// Preconditions: All pages allocated by f have been freed.
+//
+// Postconditions: None of f's methods may be called after Destroy.
+func (f *MemoryFile) Destroy() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.destroyed = true
+ f.reclaimCond.Signal()
+}
+
+// Allocate returns a range of initially-zeroed pages of the given length with
+// the given accounting kind and a single reference held by the caller. When
+// the last reference on an allocated page is released, ownership of the page
+// is returned to the MemoryFile, allowing it to be returned by a future call
+// to Allocate.
+//
+// Preconditions: length must be page-aligned and non-zero.
+func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) {
+ if length == 0 || length%usermem.PageSize != 0 {
+ panic(fmt.Sprintf("invalid allocation length: %#x", length))
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // Align hugepage-and-larger allocations on hugepage boundaries to try
+ // to take advantage of hugetmpfs.
+ alignment := uint64(usermem.PageSize)
+ if length >= usermem.HugePageSize {
+ alignment = usermem.HugePageSize
+ }
+
+ start, minUnallocatedPage := findUnallocatedRange(&f.usage, f.minUnallocatedPage, length, alignment)
+ end := start + length
+ // File offsets are int64s. Since length must be strictly positive, end
+ // cannot legitimately be 0.
+ if end < start || int64(end) <= 0 {
+ return platform.FileRange{}, syserror.ENOMEM
+ }
+
+ // Expand the file if needed. Double the file size on each expansion;
+ // uncommitted pages have effectively no cost.
+ fileSize := f.fileSize
+ for int64(end) > fileSize {
+ if fileSize >= 2*fileSize {
+ // fileSize overflow.
+ return platform.FileRange{}, syserror.ENOMEM
+ }
+ fileSize *= 2
+ }
+ if fileSize > f.fileSize {
+ if err := f.file.Truncate(fileSize); err != nil {
+ return platform.FileRange{}, err
+ }
+ f.fileSize = fileSize
+ f.mappingsMu.Lock()
+ oldMappings := f.mappings.Load().([]uintptr)
+ newMappings := make([]uintptr, fileSize>>chunkShift)
+ copy(newMappings, oldMappings)
+ f.mappings.Store(newMappings)
+ f.mappingsMu.Unlock()
+ }
+
+ // Mark selected pages as in use.
+ fr := platform.FileRange{start, end}
+ if !f.usage.Add(fr, usageInfo{
+ kind: kind,
+ refs: 1,
+ }) {
+ panic(fmt.Sprintf("allocating %v: failed to insert into usage set:\n%v", fr, &f.usage))
+ }
+
+ if minUnallocatedPage < start {
+ f.minUnallocatedPage = minUnallocatedPage
+ } else {
+ // start was the first unallocated page. The next must be
+ // somewhere beyond end.
+ f.minUnallocatedPage = end
+ }
+
+ return fr, nil
+}
+
+// findUnallocatedRange returns the first unallocated page in usage of the
+// specified length and alignment beginning at page start and the first single
+// unallocated page.
+func findUnallocatedRange(usage *usageSet, start, length, alignment uint64) (uint64, uint64) {
+ // Only searched until the first page is found.
+ firstPage := start
+ foundFirstPage := false
+ alignMask := alignment - 1
+ for seg := usage.LowerBoundSegment(start); seg.Ok(); seg = seg.NextSegment() {
+ r := seg.Range()
+
+ if !foundFirstPage && r.Start > firstPage {
+ foundFirstPage = true
+ }
+
+ if start >= r.End {
+ // start was rounded up to an alignment boundary from the end
+ // of a previous segment and is now beyond r.End.
+ continue
+ }
+ // This segment represents allocated or reclaimable pages; only the
+ // range from start to the segment's beginning is allocatable, and the
+ // next allocatable range begins after the segment.
+ if r.Start > start && r.Start-start >= length {
+ break
+ }
+ start = (r.End + alignMask) &^ alignMask
+ if !foundFirstPage {
+ firstPage = r.End
+ }
+ }
+ return start, firstPage
+}
+
+// AllocateAndFill allocates memory of the given kind and fills it by calling
+// r.ReadToBlocks() repeatedly until either length bytes are read or a non-nil
+// error is returned. It returns the memory filled by r, truncated down to the
+// nearest page. If this is shorter than length bytes due to an error returned
+// by r.ReadToBlocks(), it returns that error.
+//
+// Preconditions: length > 0. length must be page-aligned.
+func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) {
+ fr, err := f.Allocate(length, kind)
+ if err != nil {
+ return platform.FileRange{}, err
+ }
+ dsts, err := f.MapInternal(fr, usermem.Write)
+ if err != nil {
+ f.DecRef(fr)
+ return platform.FileRange{}, err
+ }
+ n, err := safemem.ReadFullToBlocks(r, dsts)
+ un := uint64(usermem.Addr(n).RoundDown())
+ if un < length {
+ // Free unused memory and update fr to contain only the memory that is
+ // still allocated.
+ f.DecRef(platform.FileRange{fr.Start + un, fr.End})
+ fr.End = fr.Start + un
+ }
+ return fr, err
+}
+
+// fallocate(2) modes, defined in Linux's include/uapi/linux/falloc.h.
+const (
+ _FALLOC_FL_KEEP_SIZE = 1
+ _FALLOC_FL_PUNCH_HOLE = 2
+)
+
+// Decommit releases resources associated with maintaining the contents of the
+// given pages. If Decommit succeeds, future accesses of the decommitted pages
+// will read zeroes.
+//
+// Preconditions: fr.Length() > 0.
+func (f *MemoryFile) Decommit(fr platform.FileRange) error {
+ if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
+ panic(fmt.Sprintf("invalid range: %v", fr))
+ }
+
+ // "After a successful call, subsequent reads from this range will
+ // return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with
+ // FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2)
+ err := syscall.Fallocate(
+ int(f.file.Fd()),
+ _FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE,
+ int64(fr.Start),
+ int64(fr.Length()))
+ if err != nil {
+ return err
+ }
+ f.markDecommitted(fr)
+ return nil
+}
+
+func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ // Since we're changing the knownCommitted attribute, we need to merge
+ // across the entire range to ensure that the usage tree is minimal.
+ gap := f.usage.ApplyContiguous(fr, func(seg usageIterator) {
+ val := seg.ValuePtr()
+ if val.knownCommitted {
+ // Drop the usageExpected appropriately.
+ amount := seg.Range().Length()
+ usage.MemoryAccounting.Dec(amount, val.kind)
+ f.usageExpected -= amount
+ val.knownCommitted = false
+ }
+ })
+ if gap.Ok() {
+ panic(fmt.Sprintf("Decommit(%v): attempted to decommit unallocated pages %v:\n%v", fr, gap.Range(), &f.usage))
+ }
+ f.usage.MergeRange(fr)
+}
+
+// IncRef implements platform.File.IncRef.
+func (f *MemoryFile) IncRef(fr platform.FileRange) {
+ if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
+ panic(fmt.Sprintf("invalid range: %v", fr))
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ gap := f.usage.ApplyContiguous(fr, func(seg usageIterator) {
+ seg.ValuePtr().refs++
+ })
+ if gap.Ok() {
+ panic(fmt.Sprintf("IncRef(%v): attempted to IncRef on unallocated pages %v:\n%v", fr, gap.Range(), &f.usage))
+ }
+
+ f.usage.MergeAdjacent(fr)
+}
+
+// DecRef implements platform.File.DecRef.
+func (f *MemoryFile) DecRef(fr platform.FileRange) {
+ if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
+ panic(fmt.Sprintf("invalid range: %v", fr))
+ }
+
+ var freed bool
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ for seg := f.usage.FindSegment(fr.Start); seg.Ok() && seg.Start() < fr.End; seg = seg.NextSegment() {
+ seg = f.usage.Isolate(seg, fr)
+ val := seg.ValuePtr()
+ if val.refs == 0 {
+ panic(fmt.Sprintf("DecRef(%v): 0 existing references on %v:\n%v", fr, seg.Range(), &f.usage))
+ }
+ val.refs--
+ if val.refs == 0 {
+ freed = true
+ // Reclassify memory as System, until it's freed by the reclaim
+ // goroutine.
+ if val.knownCommitted {
+ usage.MemoryAccounting.Move(seg.Range().Length(), usage.System, val.kind)
+ }
+ val.kind = usage.System
+ }
+ }
+ f.usage.MergeAdjacent(fr)
+
+ if freed {
+ if fr.Start < f.minReclaimablePage {
+ // We've freed at least one lower page.
+ f.minReclaimablePage = fr.Start
+ }
+ f.reclaimable = true
+ f.reclaimCond.Signal()
+ }
+}
+
+// MapInternal implements platform.File.MapInternal.
+func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ if !fr.WellFormed() || fr.Length() == 0 {
+ panic(fmt.Sprintf("invalid range: %v", fr))
+ }
+ if at.Execute {
+ return safemem.BlockSeq{}, syserror.EACCES
+ }
+
+ chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift)
+ if chunks == 1 {
+ // Avoid an unnecessary slice allocation.
+ var seq safemem.BlockSeq
+ err := f.forEachMappingSlice(fr, func(bs []byte) {
+ seq = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(bs))
+ })
+ return seq, err
+ }
+ blocks := make([]safemem.Block, 0, chunks)
+ err := f.forEachMappingSlice(fr, func(bs []byte) {
+ blocks = append(blocks, safemem.BlockFromSafeSlice(bs))
+ })
+ return safemem.BlockSeqFromSlice(blocks), err
+}
+
+// forEachMappingSlice invokes fn on a sequence of byte slices that
+// collectively map all bytes in fr.
+func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error {
+ mappings := f.mappings.Load().([]uintptr)
+ for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
+ chunk := int(chunkStart >> chunkShift)
+ m := atomic.LoadUintptr(&mappings[chunk])
+ if m == 0 {
+ var err error
+ mappings, m, err = f.getChunkMapping(chunk)
+ if err != nil {
+ return err
+ }
+ }
+ startOff := uint64(0)
+ if chunkStart < fr.Start {
+ startOff = fr.Start - chunkStart
+ }
+ endOff := uint64(chunkSize)
+ if chunkStart+chunkSize > fr.End {
+ endOff = fr.End - chunkStart
+ }
+ fn(unsafeSlice(m, chunkSize)[startOff:endOff])
+ }
+ return nil
+}
+
+func (f *MemoryFile) getChunkMapping(chunk int) ([]uintptr, uintptr, error) {
+ f.mappingsMu.Lock()
+ defer f.mappingsMu.Unlock()
+ // Another thread may have replaced f.mappings altogether due to file
+ // expansion.
+ mappings := f.mappings.Load().([]uintptr)
+ // Another thread may have already mapped the chunk.
+ if m := mappings[chunk]; m != 0 {
+ return mappings, m, nil
+ }
+ m, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ 0,
+ chunkSize,
+ syscall.PROT_READ|syscall.PROT_WRITE,
+ syscall.MAP_SHARED,
+ f.file.Fd(),
+ uintptr(chunk<<chunkShift))
+ if errno != 0 {
+ return nil, 0, errno
+ }
+ atomic.StoreUintptr(&mappings[chunk], m)
+ return mappings, m, nil
+}
+
+// MarkEvictable allows f to request memory deallocation by calling
+// user.Evict(er) in the future.
+//
+// Redundantly marking an already-evictable range as evictable has no effect.
+func (f *MemoryFile) MarkEvictable(user EvictableMemoryUser, er EvictableRange) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ info, ok := f.evictable[user]
+ if !ok {
+ info = &evictableMemoryUserInfo{}
+ f.evictable[user] = info
+ }
+ gap := info.ranges.LowerBoundGap(er.Start)
+ for gap.Ok() && gap.Start() < er.End {
+ gapER := gap.Range().Intersect(er)
+ if gapER.Length() == 0 {
+ gap = gap.NextGap()
+ continue
+ }
+ gap = info.ranges.Insert(gap, gapER, evictableRangeSetValue{}).NextGap()
+ }
+ if !info.evicting {
+ switch f.opts.DelayedEviction {
+ case DelayedEvictionDisabled:
+ // Kick off eviction immediately.
+ f.startEvictionGoroutineLocked(user, info)
+ case DelayedEvictionEnabled:
+ // Ensure that the reclaimer goroutine is running, so that it can
+ // start eviction when necessary.
+ f.reclaimCond.Signal()
+ }
+ }
+}
+
+// MarkUnevictable informs f that user no longer considers er to be evictable,
+// so the MemoryFile should no longer call user.Evict(er). Note that, per
+// EvictableMemoryUser.Evict's documentation, user.Evict(er) may still be
+// called even after MarkUnevictable returns due to race conditions, and
+// implementations of EvictableMemoryUser must handle this possibility.
+//
+// Redundantly marking an already-unevictable range as unevictable has no
+// effect.
+func (f *MemoryFile) MarkUnevictable(user EvictableMemoryUser, er EvictableRange) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ info, ok := f.evictable[user]
+ if !ok {
+ return
+ }
+ seg := info.ranges.LowerBoundSegment(er.Start)
+ for seg.Ok() && seg.Start() < er.End {
+ seg = info.ranges.Isolate(seg, er)
+ seg = info.ranges.Remove(seg).NextSegment()
+ }
+ // We can only remove info if there's no eviction goroutine running on its
+ // behalf.
+ if !info.evicting && info.ranges.IsEmpty() {
+ delete(f.evictable, user)
+ }
+}
+
+// MarkAllUnevictable informs f that user no longer considers any offsets to be
+// evictable. It otherwise has the same semantics as MarkUnevictable.
+func (f *MemoryFile) MarkAllUnevictable(user EvictableMemoryUser) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ info, ok := f.evictable[user]
+ if !ok {
+ return
+ }
+ info.ranges.RemoveAll()
+ // We can only remove info if there's no eviction goroutine running on its
+ // behalf.
+ if !info.evicting {
+ delete(f.evictable, user)
+ }
+}
+
+// UpdateUsage ensures that the memory usage statistics in
+// usage.MemoryAccounting are up to date.
+func (f *MemoryFile) UpdateUsage() error {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // If the underlying usage matches where the usage tree already
+ // represents, then we can just avoid the entire scan (we know it's
+ // accurate).
+ currentUsage, err := f.TotalUsage()
+ if err != nil {
+ return err
+ }
+ if currentUsage == f.usageExpected && f.usageSwapped == 0 {
+ log.Debugf("UpdateUsage: skipped with usageSwapped=0.")
+ return nil
+ }
+ // If the current usage matches the expected but there's swap
+ // accounting, then ensure a scan takes place at least every second
+ // (when requested).
+ if currentUsage == f.usageExpected+f.usageSwapped && time.Now().Before(f.usageLast.Add(time.Second)) {
+ log.Debugf("UpdateUsage: skipped with usageSwapped!=0.")
+ return nil
+ }
+
+ f.usageLast = time.Now()
+ err = f.updateUsageLocked(currentUsage, mincore)
+ log.Debugf("UpdateUsage: currentUsage=%d, usageExpected=%d, usageSwapped=%d.",
+ currentUsage, f.usageExpected, f.usageSwapped)
+ log.Debugf("UpdateUsage: took %v.", time.Since(f.usageLast))
+ return err
+}
+
+// updateUsageLocked attempts to detect commitment of previous-uncommitted
+// pages by invoking checkCommitted, which is a function that, for each page i
+// in bs, sets committed[i] to 1 if the page is committed and 0 otherwise.
+//
+// Precondition: f.mu must be held.
+func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(bs []byte, committed []byte) error) error {
+ // Track if anything changed to elide the merge. In the common case, we
+ // expect all segments to be committed and no merge to occur.
+ changedAny := false
+ defer func() {
+ if changedAny {
+ f.usage.MergeAll()
+ }
+
+ // Adjust the swap usage to reflect reality.
+ if f.usageExpected < currentUsage {
+ // Since no pages may be marked decommitted while we hold mu, we
+ // know that usage may have only increased since we got the last
+ // current usage. Therefore, if usageExpected is still short of
+ // currentUsage, we must assume that the difference is in pages
+ // that have been swapped.
+ newUsageSwapped := currentUsage - f.usageExpected
+ if f.usageSwapped < newUsageSwapped {
+ usage.MemoryAccounting.Inc(newUsageSwapped-f.usageSwapped, usage.System)
+ } else {
+ usage.MemoryAccounting.Dec(f.usageSwapped-newUsageSwapped, usage.System)
+ }
+ f.usageSwapped = newUsageSwapped
+ } else if f.usageSwapped != 0 {
+ // We have more usage accounted for than the file itself.
+ // That's fine, we probably caught a race where pages were
+ // being committed while the above loop was running. Just
+ // report the higher number that we found and ignore swap.
+ usage.MemoryAccounting.Dec(f.usageSwapped, usage.System)
+ f.usageSwapped = 0
+ }
+ }()
+
+ // Reused mincore buffer, will generally be <= 4096 bytes.
+ var buf []byte
+
+ // Iterate over all usage data. There will only be usage segments
+ // present when there is an associated reference.
+ for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ val := seg.Value()
+
+ // Already known to be committed; ignore.
+ if val.knownCommitted {
+ continue
+ }
+
+ // Assume that reclaimable pages (that aren't already known to be
+ // committed) are not committed. This isn't necessarily true, even
+ // after the reclaimer does Decommit(), because the kernel may
+ // subsequently back the hugepage-sized region containing the
+ // decommitted page with a hugepage. However, it's consistent with our
+ // treatment of unallocated pages, which have the same property.
+ if val.refs == 0 {
+ continue
+ }
+
+ // Get the range for this segment. As we touch slices, the
+ // Start value will be walked along.
+ r := seg.Range()
+
+ var checkErr error
+ err := f.forEachMappingSlice(r, func(s []byte) {
+ if checkErr != nil {
+ return
+ }
+
+ // Ensure that we have sufficient buffer for the call
+ // (one byte per page). The length of each slice must
+ // be page-aligned.
+ bufLen := len(s) / usermem.PageSize
+ if len(buf) < bufLen {
+ buf = make([]byte, bufLen)
+ }
+
+ // Query for new pages in core.
+ if err := checkCommitted(s, buf); err != nil {
+ checkErr = err
+ return
+ }
+
+ // Scan each page and switch out segments.
+ populatedRun := false
+ populatedRunStart := 0
+ for i := 0; i <= bufLen; i++ {
+ // We run past the end of the slice here to
+ // simplify the logic and only set populated if
+ // we're still looking at elements.
+ populated := false
+ if i < bufLen {
+ populated = buf[i]&0x1 != 0
+ }
+
+ switch {
+ case populated == populatedRun:
+ // Keep the run going.
+ continue
+ case populated && !populatedRun:
+ // Begin the run.
+ populatedRun = true
+ populatedRunStart = i
+ // Keep going.
+ continue
+ case !populated && populatedRun:
+ // Finish the run by changing this segment.
+ runRange := platform.FileRange{
+ Start: r.Start + uint64(populatedRunStart*usermem.PageSize),
+ End: r.Start + uint64(i*usermem.PageSize),
+ }
+ seg = f.usage.Isolate(seg, runRange)
+ seg.ValuePtr().knownCommitted = true
+ // Advance the segment only if we still
+ // have work to do in the context of
+ // the original segment from the for
+ // loop. Otherwise, the for loop itself
+ // will advance the segment
+ // appropriately.
+ if runRange.End != r.End {
+ seg = seg.NextSegment()
+ }
+ amount := runRange.Length()
+ usage.MemoryAccounting.Inc(amount, val.kind)
+ f.usageExpected += amount
+ changedAny = true
+ populatedRun = false
+ }
+ }
+
+ // Advance r.Start.
+ r.Start += uint64(len(s))
+ })
+ if checkErr != nil {
+ return checkErr
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// TotalUsage returns an aggregate usage for all memory statistics except
+// Mapped (which is external to MemoryFile). This is generally much cheaper
+// than UpdateUsage, but will not provide a fine-grained breakdown.
+func (f *MemoryFile) TotalUsage() (uint64, error) {
+ // Stat the underlying file to discover the underlying usage. stat(2)
+ // always reports the allocated block count in units of 512 bytes. This
+ // includes pages in the page cache and swapped pages.
+ var stat syscall.Stat_t
+ if err := syscall.Fstat(int(f.file.Fd()), &stat); err != nil {
+ return 0, err
+ }
+ return uint64(stat.Blocks * 512), nil
+}
+
+// TotalSize returns the current size of the backing file in bytes, which is an
+// upper bound on the amount of memory that can currently be allocated from the
+// MemoryFile. The value returned by TotalSize is permitted to change.
+func (f *MemoryFile) TotalSize() uint64 {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ return uint64(f.fileSize)
+}
+
+// File returns the backing file.
+func (f *MemoryFile) File() *os.File {
+ return f.file
+}
+
+// FD implements platform.File.FD.
+func (f *MemoryFile) FD() int {
+ return int(f.file.Fd())
+}
+
+// String implements fmt.Stringer.String.
+//
+// Note that because f.String locks f.mu, calling f.String internally
+// (including indirectly through the fmt package) risks recursive locking.
+// Within the pgalloc package, use f.usage directly instead.
+func (f *MemoryFile) String() string {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ return f.usage.String()
+}
+
+// runReclaim implements the reclaimer goroutine, which continuously decommits
+// reclaimable pages in order to reduce memory usage and make them available
+// for allocation.
+func (f *MemoryFile) runReclaim() {
+ for {
+ fr, ok := f.findReclaimable()
+ if !ok {
+ break
+ }
+
+ if err := f.Decommit(fr); err != nil {
+ log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
+ // Zero the pages manually. This won't reduce memory usage, but at
+ // least ensures that the pages will be zero when reallocated.
+ f.forEachMappingSlice(fr, func(bs []byte) {
+ for i := range bs {
+ bs[i] = 0
+ }
+ })
+ // Pretend the pages were decommitted even though they weren't,
+ // since the memory accounting implementation has no idea how to
+ // deal with this.
+ f.markDecommitted(fr)
+ }
+ f.markReclaimed(fr)
+ }
+ // We only get here if findReclaimable finds f.destroyed set and returns
+ // false.
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ if !f.destroyed {
+ panic("findReclaimable broke out of reclaim loop, but destroyed is no longer set")
+ }
+ f.file.Close()
+ // Ensure that any attempts to use f.file.Fd() fail instead of getting a fd
+ // that has possibly been reassigned.
+ f.file = nil
+ f.mappingsMu.Lock()
+ defer f.mappingsMu.Unlock()
+ mappings := f.mappings.Load().([]uintptr)
+ for i, m := range mappings {
+ if m != 0 {
+ _, _, errno := syscall.Syscall(syscall.SYS_MUNMAP, m, chunkSize, 0)
+ if errno != 0 {
+ log.Warningf("Failed to unmap mapping %#x for MemoryFile chunk %d: %v", m, i, errno)
+ }
+ }
+ }
+ // Similarly, invalidate f.mappings. (atomic.Value.Store(nil) panics.)
+ f.mappings.Store([]uintptr{})
+}
+
+func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ for {
+ for {
+ if f.destroyed {
+ return platform.FileRange{}, false
+ }
+ if f.reclaimable {
+ break
+ }
+ if f.opts.DelayedEviction == DelayedEvictionEnabled {
+ // No work to do. Evict any pending evictable allocations to
+ // get more reclaimable pages before going to sleep.
+ f.startEvictionsLocked()
+ }
+ f.reclaimCond.Wait()
+ }
+ // Allocate returns the first usable range in offset order and is
+ // currently a linear scan, so reclaiming from the beginning of the
+ // file minimizes the expected latency of Allocate.
+ for seg := f.usage.LowerBoundSegment(f.minReclaimablePage); seg.Ok(); seg = seg.NextSegment() {
+ if seg.ValuePtr().refs == 0 {
+ f.minReclaimablePage = seg.End()
+ return seg.Range(), true
+ }
+ }
+ // No pages are reclaimable.
+ f.reclaimable = false
+ f.minReclaimablePage = maxPage
+ }
+}
+
+func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ seg := f.usage.FindSegment(fr.Start)
+ // All of fr should be mapped to a single uncommitted reclaimable segment
+ // accounted to System.
+ if !seg.Ok() {
+ panic(fmt.Sprintf("reclaimed pages %v include unreferenced pages:\n%v", fr, &f.usage))
+ }
+ if !seg.Range().IsSupersetOf(fr) {
+ panic(fmt.Sprintf("reclaimed pages %v are not entirely contained in segment %v with state %v:\n%v", fr, seg.Range(), seg.Value(), &f.usage))
+ }
+ if got, want := seg.Value(), (usageInfo{
+ kind: usage.System,
+ knownCommitted: false,
+ refs: 0,
+ }); got != want {
+ panic(fmt.Sprintf("reclaimed pages %v in segment %v has incorrect state %v, wanted %v:\n%v", fr, seg.Range(), got, want, &f.usage))
+ }
+ // Deallocate reclaimed pages. Even though all of seg is reclaimable, the
+ // caller of markReclaimed may not have decommitted it, so we can only mark
+ // fr as reclaimed.
+ f.usage.Remove(f.usage.Isolate(seg, fr))
+ if fr.Start < f.minUnallocatedPage {
+ // We've deallocated at least one lower page.
+ f.minUnallocatedPage = fr.Start
+ }
+}
+
+// StartEvictions requests that f evict all evictable allocations. It does not
+// wait for eviction to complete; for this, see MemoryFile.WaitForEvictions.
+func (f *MemoryFile) StartEvictions() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.startEvictionsLocked()
+}
+
+// Preconditions: f.mu must be locked.
+func (f *MemoryFile) startEvictionsLocked() {
+ for user, info := range f.evictable {
+ // Don't start multiple goroutines to evict the same user's
+ // allocations.
+ if !info.evicting {
+ f.startEvictionGoroutineLocked(user, info)
+ }
+ }
+}
+
+// Preconditions: info == f.evictable[user]. !info.evicting. f.mu must be
+// locked.
+func (f *MemoryFile) startEvictionGoroutineLocked(user EvictableMemoryUser, info *evictableMemoryUserInfo) {
+ info.evicting = true
+ f.evictionWG.Add(1)
+ go func() { // S/R-SAFE: f.evictionWG
+ defer f.evictionWG.Done()
+ for {
+ f.mu.Lock()
+ info, ok := f.evictable[user]
+ if !ok {
+ // This shouldn't happen: only this goroutine is permitted
+ // to delete this entry.
+ f.mu.Unlock()
+ panic(fmt.Sprintf("evictableMemoryUserInfo for EvictableMemoryUser %v deleted while eviction goroutine running", user))
+ }
+ if info.ranges.IsEmpty() {
+ delete(f.evictable, user)
+ f.mu.Unlock()
+ return
+ }
+ // Evict from the end of info.ranges, under the assumption that
+ // if ranges in user start being used again (and are
+ // consequently marked unevictable), such uses are more likely
+ // to start from the beginning of user.
+ seg := info.ranges.LastSegment()
+ er := seg.Range()
+ info.ranges.Remove(seg)
+ // user.Evict() must be called without holding f.mu to avoid
+ // circular lock ordering.
+ f.mu.Unlock()
+ user.Evict(context.Background(), er)
+ }
+ }()
+}
+
+// WaitForEvictions blocks until f is no longer evicting any evictable
+// allocations.
+func (f *MemoryFile) WaitForEvictions() {
+ f.evictionWG.Wait()
+}
+
+type usageSetFunctions struct{}
+
+func (usageSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (usageSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (usageSetFunctions) ClearValue(val *usageInfo) {
+}
+
+func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) {
+ return val1, val1 == val2
+}
+
+func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
+ return val, val
+}
+
+// evictableRangeSetValue is the value type of evictableRangeSet.
+type evictableRangeSetValue struct{}
+
+type evictableRangeSetFunctions struct{}
+
+func (evictableRangeSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (evictableRangeSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (evictableRangeSetFunctions) ClearValue(val *evictableRangeSetValue) {
+}
+
+func (evictableRangeSetFunctions) Merge(_ EvictableRange, _ evictableRangeSetValue, _ EvictableRange, _ evictableRangeSetValue) (evictableRangeSetValue, bool) {
+ return evictableRangeSetValue{}, true
+}
+
+func (evictableRangeSetFunctions) Split(_ EvictableRange, _ evictableRangeSetValue, _ uint64) (evictableRangeSetValue, evictableRangeSetValue) {
+ return evictableRangeSetValue{}, evictableRangeSetValue{}
+}
diff --git a/pkg/sentry/pgalloc/pgalloc_state_autogen.go b/pkg/sentry/pgalloc/pgalloc_state_autogen.go
new file mode 100755
index 000000000..36a5aafa1
--- /dev/null
+++ b/pkg/sentry/pgalloc/pgalloc_state_autogen.go
@@ -0,0 +1,146 @@
+// automatically generated by stateify.
+
+package pgalloc
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *EvictableRange) beforeSave() {}
+func (x *EvictableRange) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *EvictableRange) afterLoad() {}
+func (x *EvictableRange) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func (x *evictableRangeSet) beforeSave() {}
+func (x *evictableRangeSet) save(m state.Map) {
+ x.beforeSave()
+ var root *evictableRangeSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *evictableRangeSet) afterLoad() {}
+func (x *evictableRangeSet) load(m state.Map) {
+ m.LoadValue("root", new(*evictableRangeSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*evictableRangeSegmentDataSlices)) })
+}
+
+func (x *evictableRangenode) beforeSave() {}
+func (x *evictableRangenode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *evictableRangenode) afterLoad() {}
+func (x *evictableRangenode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *evictableRangeSegmentDataSlices) beforeSave() {}
+func (x *evictableRangeSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *evictableRangeSegmentDataSlices) afterLoad() {}
+func (x *evictableRangeSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func (x *usageInfo) beforeSave() {}
+func (x *usageInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("kind", &x.kind)
+ m.Save("knownCommitted", &x.knownCommitted)
+ m.Save("refs", &x.refs)
+}
+
+func (x *usageInfo) afterLoad() {}
+func (x *usageInfo) load(m state.Map) {
+ m.Load("kind", &x.kind)
+ m.Load("knownCommitted", &x.knownCommitted)
+ m.Load("refs", &x.refs)
+}
+
+func (x *usageSet) beforeSave() {}
+func (x *usageSet) save(m state.Map) {
+ x.beforeSave()
+ var root *usageSegmentDataSlices = x.saveRoot()
+ m.SaveValue("root", root)
+}
+
+func (x *usageSet) afterLoad() {}
+func (x *usageSet) load(m state.Map) {
+ m.LoadValue("root", new(*usageSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*usageSegmentDataSlices)) })
+}
+
+func (x *usagenode) beforeSave() {}
+func (x *usagenode) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nrSegments", &x.nrSegments)
+ m.Save("parent", &x.parent)
+ m.Save("parentIndex", &x.parentIndex)
+ m.Save("hasChildren", &x.hasChildren)
+ m.Save("keys", &x.keys)
+ m.Save("values", &x.values)
+ m.Save("children", &x.children)
+}
+
+func (x *usagenode) afterLoad() {}
+func (x *usagenode) load(m state.Map) {
+ m.Load("nrSegments", &x.nrSegments)
+ m.Load("parent", &x.parent)
+ m.Load("parentIndex", &x.parentIndex)
+ m.Load("hasChildren", &x.hasChildren)
+ m.Load("keys", &x.keys)
+ m.Load("values", &x.values)
+ m.Load("children", &x.children)
+}
+
+func (x *usageSegmentDataSlices) beforeSave() {}
+func (x *usageSegmentDataSlices) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+ m.Save("Values", &x.Values)
+}
+
+func (x *usageSegmentDataSlices) afterLoad() {}
+func (x *usageSegmentDataSlices) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+ m.Load("Values", &x.Values)
+}
+
+func init() {
+ state.Register("pgalloc.EvictableRange", (*EvictableRange)(nil), state.Fns{Save: (*EvictableRange).save, Load: (*EvictableRange).load})
+ state.Register("pgalloc.evictableRangeSet", (*evictableRangeSet)(nil), state.Fns{Save: (*evictableRangeSet).save, Load: (*evictableRangeSet).load})
+ state.Register("pgalloc.evictableRangenode", (*evictableRangenode)(nil), state.Fns{Save: (*evictableRangenode).save, Load: (*evictableRangenode).load})
+ state.Register("pgalloc.evictableRangeSegmentDataSlices", (*evictableRangeSegmentDataSlices)(nil), state.Fns{Save: (*evictableRangeSegmentDataSlices).save, Load: (*evictableRangeSegmentDataSlices).load})
+ state.Register("pgalloc.usageInfo", (*usageInfo)(nil), state.Fns{Save: (*usageInfo).save, Load: (*usageInfo).load})
+ state.Register("pgalloc.usageSet", (*usageSet)(nil), state.Fns{Save: (*usageSet).save, Load: (*usageSet).load})
+ state.Register("pgalloc.usagenode", (*usagenode)(nil), state.Fns{Save: (*usagenode).save, Load: (*usagenode).load})
+ state.Register("pgalloc.usageSegmentDataSlices", (*usageSegmentDataSlices)(nil), state.Fns{Save: (*usageSegmentDataSlices).save, Load: (*usageSegmentDataSlices).load})
+}
diff --git a/pkg/sentry/pgalloc/pgalloc_unsafe.go b/pkg/sentry/pgalloc/pgalloc_unsafe.go
new file mode 100644
index 000000000..a4b5d581c
--- /dev/null
+++ b/pkg/sentry/pgalloc/pgalloc_unsafe.go
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pgalloc
+
+import (
+ "reflect"
+ "syscall"
+ "unsafe"
+)
+
+func unsafeSlice(addr uintptr, length int) (slice []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ sh.Data = addr
+ sh.Len = length
+ sh.Cap = length
+ return
+}
+
+func mincore(s []byte, buf []byte) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_MINCORE,
+ uintptr(unsafe.Pointer(&s[0])),
+ uintptr(len(s)),
+ uintptr(unsafe.Pointer(&buf[0]))); errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go
new file mode 100644
index 000000000..d4ba384b1
--- /dev/null
+++ b/pkg/sentry/pgalloc/save_restore.go
@@ -0,0 +1,210 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pgalloc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+// SaveTo writes f's state to the given stream.
+func (f *MemoryFile) SaveTo(w io.Writer) error {
+ // Wait for reclaim.
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ for f.reclaimable {
+ f.reclaimCond.Signal()
+ f.mu.Unlock()
+ runtime.Gosched()
+ f.mu.Lock()
+ }
+
+ // Ensure that there are no pending evictions.
+ if len(f.evictable) != 0 {
+ panic(fmt.Sprintf("evictions still pending for %d users; call StartEvictions and WaitForEvictions before SaveTo", len(f.evictable)))
+ }
+
+ // Ensure that all pages that contain data have knownCommitted set, since
+ // we only store knownCommitted pages below.
+ zeroPage := make([]byte, usermem.PageSize)
+ err := f.updateUsageLocked(0, func(bs []byte, committed []byte) error {
+ for pgoff := 0; pgoff < len(bs); pgoff += usermem.PageSize {
+ i := pgoff / usermem.PageSize
+ pg := bs[pgoff : pgoff+usermem.PageSize]
+ if !bytes.Equal(pg, zeroPage) {
+ committed[i] = 1
+ continue
+ }
+ committed[i] = 0
+ // Reading the page caused it to be committed; decommit it to
+ // reduce memory usage.
+ //
+ // "MADV_REMOVE [...] Free up a given range of pages and its
+ // associated backing store. This is equivalent to punching a hole
+ // in the corresponding byte range of the backing store (see
+ // fallocate(2))." - madvise(2)
+ if err := syscall.Madvise(pg, syscall.MADV_REMOVE); err != nil {
+ // This doesn't impact the correctness of saved memory, it
+ // just means that we're incrementally more likely to OOM.
+ // Complain, but don't abort saving.
+ log.Warningf("Decommitting page %p while saving failed: %v", pg, err)
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ // Save metadata.
+ if err := state.Save(w, &f.fileSize, nil); err != nil {
+ return err
+ }
+ if err := state.Save(w, &f.usage, nil); err != nil {
+ return err
+ }
+
+ // Dump out committed pages.
+ for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ if !seg.Value().knownCommitted {
+ continue
+ }
+ // Write a header to distinguish from objects.
+ if err := state.WriteHeader(w, uint64(seg.Range().Length()), false); err != nil {
+ return err
+ }
+ // Write out data.
+ var ioErr error
+ err := f.forEachMappingSlice(seg.Range(), func(s []byte) {
+ if ioErr != nil {
+ return
+ }
+ _, ioErr = w.Write(s)
+ })
+ if ioErr != nil {
+ return ioErr
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// LoadFrom loads MemoryFile state from the given stream.
+func (f *MemoryFile) LoadFrom(r io.Reader) error {
+ // Load metadata.
+ if err := state.Load(r, &f.fileSize, nil); err != nil {
+ return err
+ }
+ if err := f.file.Truncate(f.fileSize); err != nil {
+ return err
+ }
+ newMappings := make([]uintptr, f.fileSize>>chunkShift)
+ f.mappings.Store(newMappings)
+ if err := state.Load(r, &f.usage, nil); err != nil {
+ return err
+ }
+
+ // Try to map committed chunks concurrently: For any given chunk, either
+ // this loop or the following one will mmap the chunk first and cache it in
+ // f.mappings for the other, but this loop is likely to run ahead of the
+ // other since it doesn't do any work between mmaps. The rest of this
+ // function doesn't mutate f.usage, so it's safe to iterate concurrently.
+ mapperDone := make(chan struct{})
+ mapperCanceled := int32(0)
+ go func() { // S/R-SAFE: see comment
+ defer func() { close(mapperDone) }()
+ for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ if atomic.LoadInt32(&mapperCanceled) != 0 {
+ return
+ }
+ if seg.Value().knownCommitted {
+ f.forEachMappingSlice(seg.Range(), func(s []byte) {})
+ }
+ }
+ }()
+ defer func() {
+ atomic.StoreInt32(&mapperCanceled, 1)
+ <-mapperDone
+ }()
+
+ // Load committed pages.
+ for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ if !seg.Value().knownCommitted {
+ continue
+ }
+ // Verify header.
+ length, object, err := state.ReadHeader(r)
+ if err != nil {
+ return err
+ }
+ if object {
+ // Not expected.
+ return fmt.Errorf("unexpected object")
+ }
+ if expected := uint64(seg.Range().Length()); length != expected {
+ // Size mismatch.
+ return fmt.Errorf("mismatched segment: expected %d, got %d", expected, length)
+ }
+ // Read data.
+ var ioErr error
+ err = f.forEachMappingSlice(seg.Range(), func(s []byte) {
+ if ioErr != nil {
+ return
+ }
+ _, ioErr = io.ReadFull(r, s)
+ })
+ if ioErr != nil {
+ return ioErr
+ }
+ if err != nil {
+ return err
+ }
+
+ // Update accounting for restored pages. We need to do this here since
+ // these segments are marked as "known committed", and will be skipped
+ // over on accounting scans.
+ usage.MemoryAccounting.Inc(seg.End()-seg.Start(), seg.Value().kind)
+ }
+
+ return nil
+}
+
+// MemoryFileProvider provides the MemoryFile method.
+//
+// This type exists to work around a save/restore defect. The only object in a
+// saved object graph that S/R allows to be replaced at time of restore is the
+// starting point of the restore, kernel.Kernel. However, the MemoryFile
+// changes between save and restore as well, so objects that need persistent
+// access to the MemoryFile must instead store a pointer to the Kernel and call
+// Kernel.MemoryFile() as required. In most cases, depending on the kernel
+// package directly would create a package dependency loop, so the stored
+// pointer must instead be a MemoryProvider interface object. Correspondingly,
+// kernel.Kernel is the only implementation of this interface.
+type MemoryFileProvider interface {
+ // MemoryFile returns the Kernel MemoryFile.
+ MemoryFile() *MemoryFile
+}
diff --git a/pkg/sentry/pgalloc/usage_set.go b/pkg/sentry/pgalloc/usage_set.go
new file mode 100755
index 000000000..8ef4952eb
--- /dev/null
+++ b/pkg/sentry/pgalloc/usage_set.go
@@ -0,0 +1,1274 @@
+package pgalloc
+
+import (
+ __generics_imported0 "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+)
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ usageminDegree = 10
+
+ usagemaxDegree = 2 * usageminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type usageSet struct {
+ root usagenode `state:".(*usageSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *usageSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *usageSet) IsEmptyRange(r __generics_imported0.FileRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *usageSet) Span() uint64 {
+ var sz uint64
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *usageSet) SpanRange(r __generics_imported0.FileRange) uint64 {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uint64
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *usageSet) FirstSegment() usageIterator {
+ if s.root.nrSegments == 0 {
+ return usageIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *usageSet) LastSegment() usageIterator {
+ if s.root.nrSegments == 0 {
+ return usageIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *usageSet) FirstGap() usageGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return usageGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *usageSet) LastGap() usageGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return usageGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *usageSet) Find(key uint64) (usageIterator, usageGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return usageIterator{n, i}, usageGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return usageIterator{}, usageGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *usageSet) FindSegment(key uint64) usageIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *usageSet) LowerBoundSegment(min uint64) usageIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *usageSet) UpperBoundSegment(max uint64) usageIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *usageSet) FindGap(key uint64) usageGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *usageSet) LowerBoundGap(min uint64) usageGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *usageSet) UpperBoundGap(max uint64) usageGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *usageSet) Add(r __generics_imported0.FileRange, val usageInfo) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *usageSet) AddWithoutMerging(r __generics_imported0.FileRange, val usageInfo) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *usageSet) Insert(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (usageSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (usageSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (usageSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *usageSet) InsertWithoutMerging(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *usageSet) InsertWithoutMergingUnchecked(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return usageIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *usageSet) Remove(seg usageIterator) usageGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ usageSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(usageGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *usageSet) RemoveAll() {
+ s.root = usagenode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *usageSet) RemoveRange(r __generics_imported0.FileRange) usageGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *usageSet) Merge(first, second usageIterator) usageIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *usageSet) MergeUnchecked(first, second usageIterator) usageIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (usageSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return usageIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *usageSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *usageSet) MergeRange(r __generics_imported0.FileRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *usageSet) MergeAdjacent(r __generics_imported0.FileRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *usageSet) Split(seg usageIterator, split uint64) (usageIterator, usageIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *usageSet) SplitUnchecked(seg usageIterator, split uint64) (usageIterator, usageIterator) {
+ val1, val2 := (usageSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *usageSet) SplitAt(split uint64) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *usageSet) Isolate(seg usageIterator, r __generics_imported0.FileRange) usageIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *usageSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg usageIterator)) usageGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return usageGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return usageGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type usagenode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *usagenode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [usagemaxDegree - 1]__generics_imported0.FileRange
+ values [usagemaxDegree - 1]usageInfo
+ children [usagemaxDegree]*usagenode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *usagenode) firstSegment() usageIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return usageIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *usagenode) lastSegment() usageIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return usageIterator{n, n.nrSegments - 1}
+}
+
+func (n *usagenode) prevSibling() *usagenode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *usagenode) nextSibling() *usagenode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *usagenode) rebalanceBeforeInsert(gap usageGapIterator) usageGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < usagemaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &usagenode{
+ nrSegments: usageminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &usagenode{
+ nrSegments: usageminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:usageminDegree-1], n.keys[:usageminDegree-1])
+ copy(left.values[:usageminDegree-1], n.values[:usageminDegree-1])
+ copy(right.keys[:usageminDegree-1], n.keys[usageminDegree:])
+ copy(right.values[:usageminDegree-1], n.values[usageminDegree:])
+ n.keys[0], n.values[0] = n.keys[usageminDegree-1], n.values[usageminDegree-1]
+ usagezeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:usageminDegree], n.children[:usageminDegree])
+ copy(right.children[:usageminDegree], n.children[usageminDegree:])
+ usagezeroNodeSlice(n.children[2:])
+ for i := 0; i < usageminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < usageminDegree {
+ return usageGapIterator{left, gap.index}
+ }
+ return usageGapIterator{right, gap.index - usageminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[usageminDegree-1], n.values[usageminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &usagenode{
+ nrSegments: usageminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:usageminDegree-1], n.keys[usageminDegree:])
+ copy(sibling.values[:usageminDegree-1], n.values[usageminDegree:])
+ usagezeroValueSlice(n.values[usageminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:usageminDegree], n.children[usageminDegree:])
+ usagezeroNodeSlice(n.children[usageminDegree:])
+ for i := 0; i < usageminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = usageminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < usageminDegree {
+ return gap
+ }
+ return usageGapIterator{sibling, gap.index - usageminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *usagenode) rebalanceAfterRemove(gap usageGapIterator) usageGapIterator {
+ for {
+ if n.nrSegments >= usageminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= usageminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ usageSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return usageGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return usageGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= usageminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ usageSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return usageGapIterator{n, n.nrSegments}
+ }
+ return usageGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return usageGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return usageGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *usagenode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = usageGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ usageSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type usageIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *usagenode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg usageIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg usageIterator) Range() __generics_imported0.FileRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg usageIterator) Start() uint64 {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg usageIterator) End() uint64 {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg usageIterator) SetRangeUnchecked(r __generics_imported0.FileRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg usageIterator) SetRange(r __generics_imported0.FileRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg usageIterator) SetStartUnchecked(start uint64) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg usageIterator) SetStart(start uint64) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg usageIterator) SetEndUnchecked(end uint64) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg usageIterator) SetEnd(end uint64) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg usageIterator) Value() usageInfo {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg usageIterator) ValuePtr() *usageInfo {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg usageIterator) SetValue(val usageInfo) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg usageIterator) PrevSegment() usageIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return usageIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return usageIterator{}
+ }
+ return usagesegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg usageIterator) NextSegment() usageIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return usageIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return usageIterator{}
+ }
+ return usagesegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg usageIterator) PrevGap() usageGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return usageGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg usageIterator) NextGap() usageGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return usageGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg usageIterator) PrevNonEmpty() (usageIterator, usageGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return usageIterator{}, gap
+ }
+ return gap.PrevSegment(), usageGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg usageIterator) NextNonEmpty() (usageIterator, usageGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return usageIterator{}, gap
+ }
+ return gap.NextSegment(), usageGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type usageGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *usagenode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap usageGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap usageGapIterator) Range() __generics_imported0.FileRange {
+ return __generics_imported0.FileRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap usageGapIterator) Start() uint64 {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return usageSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap usageGapIterator) End() uint64 {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return usageSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap usageGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap usageGapIterator) PrevSegment() usageIterator {
+ return usagesegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap usageGapIterator) NextSegment() usageIterator {
+ return usagesegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap usageGapIterator) PrevGap() usageGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return usageGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap usageGapIterator) NextGap() usageGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return usageGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func usagesegmentBeforePosition(n *usagenode, i int) usageIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return usageIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return usageIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func usagesegmentAfterPosition(n *usagenode, i int) usageIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return usageIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return usageIterator{n, i}
+}
+
+func usagezeroValueSlice(slice []usageInfo) {
+
+ for i := range slice {
+ usageSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func usagezeroNodeSlice(slice []*usagenode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *usageSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *usagenode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *usagenode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type usageSegmentDataSlices struct {
+ Start []uint64
+ End []uint64
+ Values []usageInfo
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *usageSet) ExportSortedSlices() *usageSegmentDataSlices {
+ var sds usageSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *usageSet) ImportSortedSlices(sds *usageSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *usageSet) saveRoot() *usageSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *usageSet) loadRoot(sds *usageSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/sentry/platform/context.go b/pkg/sentry/platform/context.go
new file mode 100644
index 000000000..793f57fd7
--- /dev/null
+++ b/pkg/sentry/platform/context.go
@@ -0,0 +1,36 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package platform
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// contextID is the auth package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxPlatform is a Context.Value key for a Platform.
+ CtxPlatform contextID = iota
+)
+
+// FromContext returns the Platform that is used to execute ctx's application
+// code, or nil if no such Platform exists.
+func FromContext(ctx context.Context) Platform {
+ if v := ctx.Value(CtxPlatform); v != nil {
+ return v.(Platform)
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/file_range.go b/pkg/sentry/platform/file_range.go
new file mode 100755
index 000000000..685d360e3
--- /dev/null
+++ b/pkg/sentry/platform/file_range.go
@@ -0,0 +1,62 @@
+package platform
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type FileRange struct {
+ // Start is the inclusive start of the range.
+ Start uint64
+
+ // End is the exclusive end of the range.
+ End uint64
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r FileRange) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r FileRange) Length() uint64 {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r FileRange) Contains(x uint64) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r FileRange) Overlaps(r2 FileRange) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r FileRange) IsSupersetOf(r2 FileRange) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+func (r FileRange) Intersect(r2 FileRange) FileRange {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+func (r FileRange) CanSplitAt(x uint64) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go
new file mode 100644
index 000000000..a4651f500
--- /dev/null
+++ b/pkg/sentry/platform/interrupt/interrupt.go
@@ -0,0 +1,96 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package interrupt provides an interrupt helper.
+package interrupt
+
+import (
+ "fmt"
+ "sync"
+)
+
+// Receiver receives interrupt notifications from a Forwarder.
+type Receiver interface {
+ // NotifyInterrupt is called when the Receiver receives an interrupt.
+ NotifyInterrupt()
+}
+
+// Forwarder is a helper for delivering delayed signal interruptions.
+//
+// This helps platform implementations with Interrupt semantics.
+type Forwarder struct {
+ // mu protects the below.
+ mu sync.Mutex
+
+ // dst is the function to be called when NotifyInterrupt() is called. If
+ // dst is nil, pending will be set instead, causing the next call to
+ // Enable() to return false.
+ dst Receiver
+ pending bool
+}
+
+// Enable attempts to enable interrupt forwarding to r. If f has already
+// received an interrupt, Enable does nothing and returns false. Otherwise,
+// future calls to f.NotifyInterrupt() cause r.NotifyInterrupt() to be called,
+// and Enable returns true.
+//
+// Usage:
+//
+// if !f.Enable(r) {
+// // There was an interrupt.
+// return
+// }
+// defer f.Disable()
+//
+// Preconditions: r must not be nil. f must not already be forwarding
+// interrupts to a Receiver.
+func (f *Forwarder) Enable(r Receiver) bool {
+ if r == nil {
+ panic("nil Receiver")
+ }
+ f.mu.Lock()
+ if f.dst != nil {
+ f.mu.Unlock()
+ panic(fmt.Sprintf("already forwarding interrupts to %+v", f.dst))
+ }
+ if f.pending {
+ f.pending = false
+ f.mu.Unlock()
+ return false
+ }
+ f.dst = r
+ f.mu.Unlock()
+ return true
+}
+
+// Disable stops interrupt forwarding. If interrupt forwarding is already
+// disabled, Disable is a no-op.
+func (f *Forwarder) Disable() {
+ f.mu.Lock()
+ f.dst = nil
+ f.mu.Unlock()
+}
+
+// NotifyInterrupt implements Receiver.NotifyInterrupt. If interrupt forwarding
+// is enabled, the configured Receiver will be notified. Otherwise the
+// interrupt will be delivered to the next call to Enable.
+func (f *Forwarder) NotifyInterrupt() {
+ f.mu.Lock()
+ if f.dst != nil {
+ f.dst.NotifyInterrupt()
+ } else {
+ f.pending = true
+ }
+ f.mu.Unlock()
+}
diff --git a/pkg/sentry/platform/interrupt/interrupt_state_autogen.go b/pkg/sentry/platform/interrupt/interrupt_state_autogen.go
new file mode 100755
index 000000000..15e8bacdf
--- /dev/null
+++ b/pkg/sentry/platform/interrupt/interrupt_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package interrupt
+
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
new file mode 100644
index 000000000..689122175
--- /dev/null
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/atomicbitops"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// dirtySet tracks vCPUs for invalidation.
+type dirtySet struct {
+ vCPUs []uint64
+}
+
+// forEach iterates over all CPUs in the dirty set.
+func (ds *dirtySet) forEach(m *machine, fn func(c *vCPU)) {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ for index := range ds.vCPUs {
+ mask := atomic.SwapUint64(&ds.vCPUs[index], 0)
+ if mask != 0 {
+ for bit := 0; bit < 64; bit++ {
+ if mask&(1<<uint64(bit)) == 0 {
+ continue
+ }
+ id := 64*index + bit
+ fn(m.vCPUsByID[id])
+ }
+ }
+ }
+}
+
+// mark marks the given vCPU as dirty and returns whether it was previously
+// clean. Being previously clean implies that a flush is needed on entry.
+func (ds *dirtySet) mark(c *vCPU) bool {
+ index := uint64(c.id) / 64
+ bit := uint64(1) << uint(c.id%64)
+
+ oldValue := atomic.LoadUint64(&ds.vCPUs[index])
+ if oldValue&bit != 0 {
+ return false // Not clean.
+ }
+
+ // Set the bit unilaterally, and ensure that a flush takes place. Note
+ // that it's possible for races to occur here, but since the flush is
+ // taking place long after these lines there's no race in practice.
+ atomicbitops.OrUint64(&ds.vCPUs[index], bit)
+ return true // Previously clean.
+}
+
+// addressSpace is a wrapper for PageTables.
+type addressSpace struct {
+ platform.NoAddressSpaceIO
+
+ // mu is the lock for modifications to the address space.
+ //
+ // Note that the page tables themselves are not locked.
+ mu sync.Mutex
+
+ // machine is the underlying machine.
+ machine *machine
+
+ // pageTables are for this particular address space.
+ pageTables *pagetables.PageTables
+
+ // dirtySet is the set of dirty vCPUs.
+ dirtySet *dirtySet
+}
+
+// invalidate is the implementation for Invalidate.
+func (as *addressSpace) invalidate() {
+ as.dirtySet.forEach(as.machine, func(c *vCPU) {
+ if c.active.get() == as { // If this happens to be active,
+ c.BounceToKernel() // ... force a kernel transition.
+ }
+ })
+}
+
+// Invalidate interrupts all dirty contexts.
+func (as *addressSpace) Invalidate() {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+ as.invalidate()
+}
+
+// Touch adds the given vCPU to the dirty list.
+//
+// The return value indicates whether a flush is required.
+func (as *addressSpace) Touch(c *vCPU) bool {
+ return as.dirtySet.mark(c)
+}
+
+type hostMapEntry struct {
+ addr uintptr
+ length uintptr
+}
+
+func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
+ for m.length > 0 {
+ physical, length, ok := translateToPhysical(m.addr)
+ if !ok {
+ panic("unable to translate segment")
+ }
+ if length > m.length {
+ length = m.length
+ }
+
+ // Ensure that this map has physical mappings. If the page does
+ // not have physical mappings, the KVM module may inject
+ // spurious exceptions when emulation fails (i.e. it tries to
+ // emulate because the RIP is pointed at those pages).
+ as.machine.mapPhysical(physical, length)
+
+ // Install the page table mappings. Note that the ordering is
+ // important; if the pagetable mappings were installed before
+ // ensuring the physical pages were available, then some other
+ // thread could theoretically access them.
+ //
+ // Due to the way KVM's shadow paging implementation works,
+ // modifications to the page tables while in host mode may not
+ // be trapped, leading to the shadow pages being out of sync.
+ // Therefore, we need to ensure that we are in guest mode for
+ // page table modifications. See the call to bluepill, below.
+ as.machine.retryInGuest(func() {
+ inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
+ AccessType: at,
+ User: true,
+ }, physical) || inv
+ })
+ m.addr += length
+ m.length -= length
+ addr += usermem.Addr(length)
+ }
+
+ return inv
+}
+
+// MapFile implements platform.AddressSpace.MapFile.
+func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+
+ // Get mappings in the sentry's address space, which are guaranteed to be
+ // valid as long as a reference is held on the mapped pages (which is in
+ // turn required by AddressSpace.MapFile precondition).
+ //
+ // If precommit is true, we will touch mappings to commit them, so ensure
+ // that mappings are readable from sentry context.
+ //
+ // We don't execute from application file-mapped memory, and guest page
+ // tables don't care if we have execute permission (but they do need pages
+ // to be readable).
+ bs, err := f.MapInternal(fr, usermem.AccessType{
+ Read: at.Read || at.Execute || precommit,
+ Write: at.Write,
+ })
+ if err != nil {
+ return err
+ }
+
+ // Map the mappings in the sentry's address space (guest physical memory)
+ // into the application's address space (guest virtual memory).
+ inv := false
+ for !bs.IsEmpty() {
+ b := bs.Head()
+ bs = bs.Tail()
+ // Since fr was page-aligned, b should also be page-aligned. We do the
+ // lookup in our host page tables for this translation.
+ if precommit {
+ s := b.ToSlice()
+ for i := 0; i < len(s); i += usermem.PageSize {
+ _ = s[i] // Touch to commit.
+ }
+ }
+ prev := as.mapHost(addr, hostMapEntry{
+ addr: b.Addr(),
+ length: uintptr(b.Len()),
+ }, at)
+ inv = inv || prev
+ addr += usermem.Addr(b.Len())
+ }
+ if inv {
+ as.invalidate()
+ }
+
+ return nil
+}
+
+// Unmap unmaps the given range by calling pagetables.PageTables.Unmap.
+func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) {
+ as.mu.Lock()
+ defer as.mu.Unlock()
+
+ // See above re: retryInGuest.
+ var prev bool
+ as.machine.retryInGuest(func() {
+ prev = as.pageTables.Unmap(addr, uintptr(length)) || prev
+ })
+ if prev {
+ as.invalidate()
+
+ // Recycle any freed intermediate pages.
+ as.pageTables.Allocator.Recycle()
+ }
+}
+
+// Release releases the page tables.
+func (as *addressSpace) Release() {
+ as.Unmap(0, ^uint64(0))
+
+ // Free all pages from the allocator.
+ as.pageTables.Allocator.(allocator).base.Drain()
+
+ // Drop all cached machine references.
+ as.machine.dropPageTables(as.pageTables)
+}
diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/allocator.go
new file mode 100644
index 000000000..42bcc9733
--- /dev/null
+++ b/pkg/sentry/platform/kvm/allocator.go
@@ -0,0 +1,76 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables"
+)
+
+type allocator struct {
+ base *pagetables.RuntimeAllocator
+}
+
+// newAllocator is used to define the allocator.
+func newAllocator() allocator {
+ return allocator{
+ base: pagetables.NewRuntimeAllocator(),
+ }
+}
+
+// NewPTEs implements pagetables.Allocator.NewPTEs.
+//
+//go:nosplit
+func (a allocator) NewPTEs() *pagetables.PTEs {
+ return a.base.NewPTEs()
+}
+
+// PhysicalFor returns the physical address for a set of PTEs.
+//
+//go:nosplit
+func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
+ virtual := a.base.PhysicalFor(ptes)
+ physical, _, ok := translateToPhysical(virtual)
+ if !ok {
+ panic(fmt.Sprintf("PhysicalFor failed for %p", ptes))
+ }
+ return physical
+}
+
+// LookupPTEs implements pagetables.Allocator.LookupPTEs.
+//
+//go:nosplit
+func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
+ virtualStart, physicalStart, _, ok := calculateBluepillFault(physical)
+ if !ok {
+ panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical))
+ }
+ return a.base.LookupPTEs(virtualStart + (physical - physicalStart))
+}
+
+// FreePTEs implements pagetables.Allocator.FreePTEs.
+//
+//go:nosplit
+func (a allocator) FreePTEs(ptes *pagetables.PTEs) {
+ a.base.FreePTEs(ptes)
+}
+
+// Recycle implements pagetables.Allocator.Recycle.
+//
+//go:nosplit
+func (a allocator) Recycle() {
+ a.base.Recycle()
+}
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
new file mode 100644
index 000000000..a926e6f8b
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "fmt"
+ "reflect"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy"
+)
+
+// bluepill enters guest mode.
+func bluepill(*vCPU)
+
+// sighandler is the signal entry point.
+func sighandler()
+
+// dieTrampoline is the assembly trampoline. This calls dieHandler.
+//
+// This uses an architecture-specific calling convention, documented in
+// dieArchSetup and the assembly implementation for dieTrampoline.
+func dieTrampoline()
+
+var (
+ // savedHandler is a pointer to the previous handler.
+ //
+ // This is called by bluepillHandler.
+ savedHandler uintptr
+
+ // dieTrampolineAddr is the address of dieTrampoline.
+ dieTrampolineAddr uintptr
+)
+
+// dieHandler is called by dieTrampoline.
+//
+//go:nosplit
+func dieHandler(c *vCPU) {
+ throw(c.dieState.message)
+}
+
+// die is called to set the vCPU up to panic.
+//
+// This loads vCPU state, and sets up a call for the trampoline.
+//
+//go:nosplit
+func (c *vCPU) die(context *arch.SignalContext64, msg string) {
+ // Save the death message, which will be thrown.
+ c.dieState.message = msg
+
+ // Reload all registers to have an accurate stack trace when we return
+ // to host mode. This means that the stack should be unwound correctly.
+ if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
+ throw(msg)
+ }
+
+ // Setup the trampoline.
+ dieArchSetup(c, context, &c.dieState.guestRegs)
+}
+
+func init() {
+ // Install the handler.
+ if err := safecopy.ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for signal %d: %v", syscall.SIGSEGV, err))
+ }
+
+ // Extract the address for the trampoline.
+ dieTrampolineAddr = reflect.ValueOf(dieTrampoline).Pointer()
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
new file mode 100644
index 000000000..c258408f9
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -0,0 +1,141 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
+)
+
+var (
+ // bounceSignal is the signal used for bouncing KVM.
+ //
+ // We use SIGCHLD because it is not masked by the runtime, and
+ // it will be ignored properly by other parts of the kernel.
+ bounceSignal = syscall.SIGCHLD
+
+ // bounceSignalMask has only bounceSignal set.
+ bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1))
+
+ // bounce is the interrupt vector used to return to the kernel.
+ bounce = uint32(ring0.VirtualizationException)
+)
+
+// redpill on amd64 invokes a syscall with -1.
+//
+//go:nosplit
+func redpill() {
+ syscall.RawSyscall(^uintptr(0), 0, 0, 0)
+}
+
+// bluepillArchEnter is called during bluepillEnter.
+//
+//go:nosplit
+func bluepillArchEnter(context *arch.SignalContext64) *vCPU {
+ c := vCPUPtr(uintptr(context.Rax))
+ regs := c.CPU.Registers()
+ regs.R8 = context.R8
+ regs.R9 = context.R9
+ regs.R10 = context.R10
+ regs.R11 = context.R11
+ regs.R12 = context.R12
+ regs.R13 = context.R13
+ regs.R14 = context.R14
+ regs.R15 = context.R15
+ regs.Rdi = context.Rdi
+ regs.Rsi = context.Rsi
+ regs.Rbp = context.Rbp
+ regs.Rbx = context.Rbx
+ regs.Rdx = context.Rdx
+ regs.Rax = context.Rax
+ regs.Rcx = context.Rcx
+ regs.Rsp = context.Rsp
+ regs.Rip = context.Rip
+ regs.Eflags = context.Eflags
+ regs.Eflags &^= uint64(ring0.KernelFlagsClear)
+ regs.Eflags |= ring0.KernelFlagsSet
+ regs.Cs = uint64(ring0.Kcode)
+ regs.Ds = uint64(ring0.Udata)
+ regs.Es = uint64(ring0.Udata)
+ regs.Ss = uint64(ring0.Kdata)
+ return c
+}
+
+// KernelSyscall handles kernel syscalls.
+//
+//go:nosplit
+func (c *vCPU) KernelSyscall() {
+ regs := c.Registers()
+ if regs.Rax != ^uint64(0) {
+ regs.Rip -= 2 // Rewind.
+ }
+ // We only trigger a bluepill entry in the bluepill function, and can
+ // therefore be guaranteed that there is no floating point state to be
+ // loaded on resuming from halt. We only worry about saving on exit.
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.Halt()
+ ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+}
+
+// KernelException handles kernel exceptions.
+//
+//go:nosplit
+func (c *vCPU) KernelException(vector ring0.Vector) {
+ regs := c.Registers()
+ if vector == ring0.Vector(bounce) {
+ // These should not interrupt kernel execution; point the Rip
+ // to zero to ensure that we get a reasonable panic when we
+ // attempt to return and a full stack trace.
+ regs.Rip = 0
+ }
+ // See above.
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.Halt()
+ ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+}
+
+// bluepillArchExit is called during bluepillEnter.
+//
+//go:nosplit
+func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
+ regs := c.CPU.Registers()
+ context.R8 = regs.R8
+ context.R9 = regs.R9
+ context.R10 = regs.R10
+ context.R11 = regs.R11
+ context.R12 = regs.R12
+ context.R13 = regs.R13
+ context.R14 = regs.R14
+ context.R15 = regs.R15
+ context.Rdi = regs.Rdi
+ context.Rsi = regs.Rsi
+ context.Rbp = regs.Rbp
+ context.Rbx = regs.Rbx
+ context.Rdx = regs.Rdx
+ context.Rax = regs.Rax
+ context.Rcx = regs.Rcx
+ context.Rsp = regs.Rsp
+ context.Rip = regs.Rip
+ context.Eflags = regs.Eflags
+
+ // Set the context pointer to the saved floating point state. This is
+ // where the guest data has been serialized, the kernel will restore
+ // from this new pointer value.
+ context.Fpstate = uint64(uintptrValue((*byte)(c.floatingPointState)))
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
new file mode 100644
index 000000000..2bc34a435
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// VCPU_CPU is the location of the CPU in the vCPU struct.
+//
+// This is guaranteed to be zero.
+#define VCPU_CPU 0x0
+
+// CPU_SELF is the self reference in ring0's percpu.
+//
+// This is guaranteed to be zero.
+#define CPU_SELF 0x0
+
+// Context offsets.
+//
+// Only limited use of the context is done in the assembly stub below, most is
+// done in the Go handlers. However, the RIP must be examined.
+#define CONTEXT_RAX 0x90
+#define CONTEXT_RIP 0xa8
+#define CONTEXT_FP 0xe0
+
+// CLI is the literal byte for the disable interrupts instruction.
+//
+// This is checked as the source of the fault.
+#define CLI $0xfa
+
+// See bluepill.go.
+TEXT ·bluepill(SB),NOSPLIT,$0
+begin:
+ MOVQ vcpu+0(FP), AX
+ LEAQ VCPU_CPU(AX), BX
+ BYTE CLI;
+check_vcpu:
+ MOVQ CPU_SELF(GS), CX
+ CMPQ BX, CX
+ JE right_vCPU
+wrong_vcpu:
+ CALL ·redpill(SB)
+ JMP begin
+right_vCPU:
+ RET
+
+// sighandler: see bluepill.go for documentation.
+//
+// The arguments are the following:
+//
+// DI - The signal number.
+// SI - Pointer to siginfo_t structure.
+// DX - Pointer to ucontext structure.
+//
+TEXT ·sighandler(SB),NOSPLIT,$0
+ // Check if the signal is from the kernel.
+ MOVQ $0x80, CX
+ CMPL CX, 0x8(SI)
+ JNE fallback
+
+ // Check if RIP is disable interrupts.
+ MOVQ CONTEXT_RIP(DX), CX
+ CMPQ CX, $0x0
+ JE fallback
+ CMPB 0(CX), CLI
+ JNE fallback
+
+ // Call the bluepillHandler.
+ PUSHQ DX // First argument (context).
+ CALL ·bluepillHandler(SB) // Call the handler.
+ POPQ DX // Discard the argument.
+ RET
+
+fallback:
+ // Jump to the previous signal handler.
+ XORQ CX, CX
+ MOVQ ·savedHandler(SB), AX
+ JMP AX
+
+// dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation.
+TEXT ·dieTrampoline(SB),NOSPLIT,$0
+ PUSHQ BX // First argument (vCPU).
+ PUSHQ AX // Fake the old RIP as caller.
+ JMP ·dieHandler(SB)
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
new file mode 100644
index 000000000..92fde7ee0
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
+)
+
+// bluepillArchContext returns the arch-specific context.
+//
+//go:nosplit
+func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
+ return &((*arch.UContext64)(context).MContext)
+}
+
+// dieArchSetup initialies the state for dieTrampoline.
+//
+// The amd64 dieTrampoline requires the vCPU to be set in BX, and the last RIP
+// to be in AX. The trampoline then simulates a call to dieHandler from the
+// provided RIP.
+//
+//go:nosplit
+func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+ // If the vCPU is in user mode, we set the stack to the stored stack
+ // value in the vCPU itself. We don't want to unwind the user stack.
+ if guestRegs.RFLAGS&ring0.UserFlagsSet == ring0.UserFlagsSet {
+ regs := c.CPU.Registers()
+ context.Rax = regs.Rax
+ context.Rsp = regs.Rsp
+ context.Rbp = regs.Rbp
+ } else {
+ context.Rax = guestRegs.RIP
+ context.Rsp = guestRegs.RSP
+ context.Rbp = guestRegs.RBP
+ context.Eflags = guestRegs.RFLAGS
+ }
+ context.Rbx = uint64(uintptr(unsafe.Pointer(c)))
+ context.Rip = uint64(dieTrampolineAddr)
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
new file mode 100644
index 000000000..3c452f5ba
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -0,0 +1,127 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+const (
+ // faultBlockSize is the size used for servicing memory faults.
+ //
+ // This should be large enough to avoid frequent faults and avoid using
+ // all available KVM slots (~512), but small enough that KVM does not
+ // complain about slot sizes (~4GB). See handleBluepillFault for how
+ // this block is used.
+ faultBlockSize = 2 << 30
+
+ // faultBlockMask is the mask for the fault blocks.
+ //
+ // This must be typed to avoid overflow complaints (ugh).
+ faultBlockMask = ^uintptr(faultBlockSize - 1)
+)
+
+// yield yields the CPU.
+//
+//go:nosplit
+func yield() {
+ syscall.RawSyscall(syscall.SYS_SCHED_YIELD, 0, 0, 0)
+}
+
+// calculateBluepillFault calculates the fault address range.
+//
+//go:nosplit
+func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) {
+ alignedPhysical := physical &^ uintptr(usermem.PageSize-1)
+ for _, pr := range physicalRegions {
+ end := pr.physical + pr.length
+ if physical < pr.physical || physical >= end {
+ continue
+ }
+
+ // Adjust the block to match our size.
+ physicalStart = alignedPhysical & faultBlockMask
+ if physicalStart < pr.physical {
+ // Bound the starting point to the start of the region.
+ physicalStart = pr.physical
+ }
+ virtualStart = pr.virtual + (physicalStart - pr.physical)
+ physicalEnd := physicalStart + faultBlockSize
+ if physicalEnd > end {
+ physicalEnd = end
+ }
+ length = physicalEnd - physicalStart
+ return virtualStart, physicalStart, length, true
+ }
+
+ return 0, 0, 0, false
+}
+
+// handleBluepillFault handles a physical fault.
+//
+// The corresponding virtual address is returned. This may throw on error.
+//
+//go:nosplit
+func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) {
+ // Paging fault: we need to map the underlying physical pages for this
+ // fault. This all has to be done in this function because we're in a
+ // signal handler context. (We can't call any functions that might
+ // split the stack.)
+ virtualStart, physicalStart, length, ok := calculateBluepillFault(physical)
+ if !ok {
+ return 0, false
+ }
+
+ // Set the KVM slot.
+ //
+ // First, we need to acquire the exclusive right to set a slot. See
+ // machine.nextSlot for information about the protocol.
+ slot := atomic.SwapUint32(&m.nextSlot, ^uint32(0))
+ for slot == ^uint32(0) {
+ yield() // Race with another call.
+ slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0))
+ }
+ errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart)
+ if errno == 0 {
+ // Successfully added region; we can increment nextSlot and
+ // allow another set to proceed here.
+ atomic.StoreUint32(&m.nextSlot, slot+1)
+ return virtualStart + (physical - physicalStart), true
+ }
+
+ // Release our slot (still available).
+ atomic.StoreUint32(&m.nextSlot, slot)
+
+ switch errno {
+ case syscall.EEXIST:
+ // The region already exists. It's possible that we raced with
+ // another vCPU here. We just revert nextSlot and return true,
+ // because this must have been satisfied by some other vCPU.
+ return virtualStart + (physical - physicalStart), true
+ case syscall.EINVAL:
+ throw("set memory region failed; out of slots")
+ case syscall.ENOMEM:
+ throw("set memory region failed: out of memory")
+ case syscall.EFAULT:
+ throw("set memory region failed: invalid physical range")
+ default:
+ throw("set memory region failed: unknown reason")
+ }
+
+ panic("unreachable")
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
new file mode 100644
index 000000000..7e8e9f42a
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -0,0 +1,213 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.12
+// +build !go1.14
+
+// Check go:linkname function signatures when updating Go version.
+
+package kvm
+
+import (
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+)
+
+//go:linkname throw runtime.throw
+func throw(string)
+
+// vCPUPtr returns a CPU for the given address.
+//
+//go:nosplit
+func vCPUPtr(addr uintptr) *vCPU {
+ return (*vCPU)(unsafe.Pointer(addr))
+}
+
+// bytePtr returns a bytePtr for the given address.
+//
+//go:nosplit
+func bytePtr(addr uintptr) *byte {
+ return (*byte)(unsafe.Pointer(addr))
+}
+
+// uintptrValue returns a uintptr for the given address.
+//
+//go:nosplit
+func uintptrValue(addr *byte) uintptr {
+ return (uintptr)(unsafe.Pointer(addr))
+}
+
+// bluepillHandler is called from the signal stub.
+//
+// The world may be stopped while this is executing, and it executes on the
+// signal stack. It should only execute raw system calls and functions that are
+// explicitly marked go:nosplit.
+//
+//go:nosplit
+func bluepillHandler(context unsafe.Pointer) {
+ // Sanitize the registers; interrupts must always be disabled.
+ c := bluepillArchEnter(bluepillArchContext(context))
+
+ // Increment the number of switches.
+ atomic.AddUint32(&c.switches, 1)
+
+ // Mark this as guest mode.
+ switch atomic.SwapUint32(&c.state, vCPUGuest|vCPUUser) {
+ case vCPUUser: // Expected case.
+ case vCPUUser | vCPUWaiter:
+ c.notify()
+ default:
+ throw("invalid state")
+ }
+
+ for {
+ switch _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0); errno {
+ case 0: // Expected case.
+ case syscall.EINTR:
+ // First, we process whatever pending signal
+ // interrupted KVM. Since we're in a signal handler
+ // currently, all signals are masked and the signal
+ // must have been delivered directly to this thread.
+ sig, _, errno := syscall.RawSyscall6(
+ syscall.SYS_RT_SIGTIMEDWAIT,
+ uintptr(unsafe.Pointer(&bounceSignalMask)),
+ 0, // siginfo.
+ 0, // timeout.
+ 8, // sigset size.
+ 0, 0)
+ if errno != 0 {
+ throw("error waiting for pending signal")
+ }
+ if sig != uintptr(bounceSignal) {
+ throw("unexpected signal")
+ }
+
+ // Check whether the current state of the vCPU is ready
+ // for interrupt injection. Because we don't have a
+ // PIC, we can't inject an interrupt while they are
+ // masked. We need to request a window if it's not
+ // ready.
+ if c.runData.readyForInterruptInjection == 0 {
+ c.runData.requestInterruptWindow = 1
+ continue // Rerun vCPU.
+ } else {
+ // Force injection below; the vCPU is ready.
+ c.runData.exitReason = _KVM_EXIT_IRQ_WINDOW_OPEN
+ }
+ case syscall.EFAULT:
+ // If a fault is not serviceable due to the host
+ // backing pages having page permissions, instead of an
+ // MMIO exit we receive EFAULT from the run ioctl. We
+ // always inject an NMI here since we may be in kernel
+ // mode and have interrupts disabled.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_NMI, 0); errno != 0 {
+ throw("NMI injection failed")
+ }
+ continue // Rerun vCPU.
+ default:
+ throw("run failed")
+ }
+
+ switch c.runData.exitReason {
+ case _KVM_EXIT_EXCEPTION:
+ c.die(bluepillArchContext(context), "exception")
+ return
+ case _KVM_EXIT_IO:
+ c.die(bluepillArchContext(context), "I/O")
+ return
+ case _KVM_EXIT_INTERNAL_ERROR:
+ // An internal error is typically thrown when emulation
+ // fails. This can occur via the MMIO path below (and
+ // it might fail because we have multiple regions that
+ // are not mapped). We would actually prefer that no
+ // emulation occur, and don't mind at all if it fails.
+ case _KVM_EXIT_HYPERCALL:
+ c.die(bluepillArchContext(context), "hypercall")
+ return
+ case _KVM_EXIT_DEBUG:
+ c.die(bluepillArchContext(context), "debug")
+ return
+ case _KVM_EXIT_HLT:
+ // Copy out registers.
+ bluepillArchExit(c, bluepillArchContext(context))
+
+ // Return to the vCPUReady state; notify any waiters.
+ user := atomic.LoadUint32(&c.state) & vCPUUser
+ switch atomic.SwapUint32(&c.state, user) {
+ case user | vCPUGuest: // Expected case.
+ case user | vCPUGuest | vCPUWaiter:
+ c.notify()
+ default:
+ throw("invalid state")
+ }
+ return
+ case _KVM_EXIT_MMIO:
+ // Increment the fault count.
+ atomic.AddUint32(&c.faults, 1)
+
+ // For MMIO, the physical address is the first data item.
+ physical := uintptr(c.runData.data[0])
+ virtual, ok := handleBluepillFault(c.machine, physical)
+ if !ok {
+ c.die(bluepillArchContext(context), "invalid physical address")
+ return
+ }
+
+ // We now need to fill in the data appropriately. KVM
+ // expects us to provide the result of the given MMIO
+ // operation in the runData struct. This is safe
+ // because, if a fault occurs here, the same fault
+ // would have occurred in guest mode. The kernel should
+ // not create invalid page table mappings.
+ data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1]))
+ length := (uintptr)((uint32)(c.runData.data[2]))
+ write := (uint8)(((c.runData.data[2] >> 32) & 0xff)) != 0
+ for i := uintptr(0); i < length; i++ {
+ b := bytePtr(uintptr(virtual) + i)
+ if write {
+ // Write to the given address.
+ *b = data[i]
+ } else {
+ // Read from the given address.
+ data[i] = *b
+ }
+ }
+ case _KVM_EXIT_IRQ_WINDOW_OPEN:
+ // Interrupt: we must have requested an interrupt
+ // window; set the interrupt line.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_INTERRUPT,
+ uintptr(unsafe.Pointer(&bounce))); errno != 0 {
+ throw("interrupt injection failed")
+ }
+ // Clear previous injection request.
+ c.runData.requestInterruptWindow = 0
+ case _KVM_EXIT_SHUTDOWN:
+ c.die(bluepillArchContext(context), "shutdown")
+ return
+ case _KVM_EXIT_FAIL_ENTRY:
+ c.die(bluepillArchContext(context), "entry failed")
+ return
+ default:
+ c.die(bluepillArchContext(context), "unknown")
+ return
+ }
+ }
+}
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
new file mode 100644
index 000000000..0eb0020f7
--- /dev/null
+++ b/pkg/sentry/platform/kvm/context.go
@@ -0,0 +1,87 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/interrupt"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// context is an implementation of the platform context.
+//
+// This is a thin wrapper around the machine.
+type context struct {
+ // machine is the parent machine, and is immutable.
+ machine *machine
+
+ // info is the arch.SignalInfo cached for this context.
+ info arch.SignalInfo
+
+ // interrupt is the interrupt context.
+ interrupt interrupt.Forwarder
+}
+
+// Switch runs the provided context in the given address space.
+func (c *context) Switch(as platform.AddressSpace, ac arch.Context, _ int32) (*arch.SignalInfo, usermem.AccessType, error) {
+ localAS := as.(*addressSpace)
+
+ // Grab a vCPU.
+ cpu := c.machine.Get()
+
+ // Enable interrupts (i.e. calls to vCPU.Notify).
+ if !c.interrupt.Enable(cpu) {
+ c.machine.Put(cpu) // Already preempted.
+ return nil, usermem.NoAccess, platform.ErrContextInterrupt
+ }
+
+ // Set the active address space.
+ //
+ // This must be done prior to the call to Touch below. If the address
+ // space is invalidated between this line and the call below, we will
+ // flag on entry anyways. When the active address space below is
+ // cleared, it indicates that we don't need an explicit interrupt and
+ // that the flush can occur naturally on the next user entry.
+ cpu.active.set(localAS)
+
+ // Prepare switch options.
+ switchOpts := ring0.SwitchOpts{
+ Registers: &ac.StateData().Regs,
+ FloatingPointState: (*byte)(ac.FloatingPointData()),
+ PageTables: localAS.pageTables,
+ Flush: localAS.Touch(cpu),
+ FullRestore: ac.FullRestore(),
+ }
+
+ // Take the blue pill.
+ at, err := cpu.SwitchToUser(switchOpts, &c.info)
+
+ // Clear the address space.
+ cpu.active.set(nil)
+
+ // Release resources.
+ c.machine.Put(cpu)
+
+ // All done.
+ c.interrupt.Disable()
+ return &c.info, at, err
+}
+
+// Interrupt interrupts the running context.
+func (c *context) Interrupt() {
+ c.interrupt.NotifyInterrupt()
+}
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
new file mode 100644
index 000000000..ed0521c3f
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -0,0 +1,143 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package kvm provides a kvm-based implementation of the platform interface.
+package kvm
+
+import (
+ "fmt"
+ "os"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// KVM represents a lightweight VM context.
+type KVM struct {
+ platform.NoCPUPreemptionDetection
+
+ // machine is the backing VM.
+ machine *machine
+}
+
+var (
+ globalOnce sync.Once
+ globalErr error
+)
+
+// OpenDevice opens the KVM device at /dev/kvm and returns the File.
+func OpenDevice() (*os.File, error) {
+ f, err := os.OpenFile("/dev/kvm", syscall.O_RDWR, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error opening /dev/kvm: %v", err)
+ }
+ return f, nil
+}
+
+// New returns a new KVM-based implementation of the platform interface.
+func New(deviceFile *os.File) (*KVM, error) {
+ fd := deviceFile.Fd()
+
+ // Ensure global initialization is done.
+ globalOnce.Do(func() {
+ physicalInit()
+ globalErr = updateSystemValues(int(fd))
+ ring0.Init(cpuid.HostFeatureSet())
+ })
+ if globalErr != nil {
+ return nil, globalErr
+ }
+
+ // Create a new VM fd.
+ vm, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, fd, _KVM_CREATE_VM, 0)
+ if errno != 0 {
+ return nil, fmt.Errorf("creating VM: %v", errno)
+ }
+ // We are done with the device file.
+ deviceFile.Close()
+
+ // Create a VM context.
+ machine, err := newMachine(int(vm))
+ if err != nil {
+ return nil, err
+ }
+
+ // All set.
+ return &KVM{
+ machine: machine,
+ }, nil
+}
+
+// SupportsAddressSpaceIO implements platform.Platform.SupportsAddressSpaceIO.
+func (*KVM) SupportsAddressSpaceIO() bool {
+ return false
+}
+
+// CooperativelySchedulesAddressSpace implements platform.Platform.CooperativelySchedulesAddressSpace.
+func (*KVM) CooperativelySchedulesAddressSpace() bool {
+ return false
+}
+
+// MapUnit implements platform.Platform.MapUnit.
+func (*KVM) MapUnit() uint64 {
+ // We greedily creates PTEs in MapFile, so extremely large mappings can
+ // be expensive. Not _that_ expensive since we allow super pages, but
+ // even though can get out of hand if you're creating multi-terabyte
+ // mappings. For this reason, we limit mappings to an arbitrary 16MB.
+ return 16 << 20
+}
+
+// MinUserAddress returns the lowest available address.
+func (*KVM) MinUserAddress() usermem.Addr {
+ return usermem.PageSize
+}
+
+// MaxUserAddress returns the first address that may not be used.
+func (*KVM) MaxUserAddress() usermem.Addr {
+ return usermem.Addr(ring0.MaximumUserAddress)
+}
+
+// NewAddressSpace returns a new pagetable root.
+func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
+ // Allocate page tables and install system mappings.
+ pageTables := pagetables.New(newAllocator())
+ applyPhysicalRegions(func(pr physicalRegion) bool {
+ // Map the kernel in the upper half.
+ pageTables.Map(
+ usermem.Addr(ring0.KernelStartAddress|pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+ return true // Keep iterating.
+ })
+
+ // Return the new address space.
+ return &addressSpace{
+ machine: k.machine,
+ pageTables: pageTables,
+ dirtySet: k.machine.newDirtySet(),
+ }, nil, nil
+}
+
+// NewContext returns an interruptible context.
+func (k *KVM) NewContext() platform.Context {
+ return &context{
+ machine: k.machine,
+ }
+}
diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go
new file mode 100644
index 000000000..61493ccaf
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_amd64.go
@@ -0,0 +1,213 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
+)
+
+// userMemoryRegion is a region of physical memory.
+//
+// This mirrors kvm_memory_region.
+type userMemoryRegion struct {
+ slot uint32
+ flags uint32
+ guestPhysAddr uint64
+ memorySize uint64
+ userspaceAddr uint64
+}
+
+// userRegs represents KVM user registers.
+//
+// This mirrors kvm_regs.
+type userRegs struct {
+ RAX uint64
+ RBX uint64
+ RCX uint64
+ RDX uint64
+ RSI uint64
+ RDI uint64
+ RSP uint64
+ RBP uint64
+ R8 uint64
+ R9 uint64
+ R10 uint64
+ R11 uint64
+ R12 uint64
+ R13 uint64
+ R14 uint64
+ R15 uint64
+ RIP uint64
+ RFLAGS uint64
+}
+
+// systemRegs represents KVM system registers.
+//
+// This mirrors kvm_sregs.
+type systemRegs struct {
+ CS segment
+ DS segment
+ ES segment
+ FS segment
+ GS segment
+ SS segment
+ TR segment
+ LDT segment
+ GDT descriptor
+ IDT descriptor
+ CR0 uint64
+ CR2 uint64
+ CR3 uint64
+ CR4 uint64
+ CR8 uint64
+ EFER uint64
+ apicBase uint64
+ interruptBitmap [(_KVM_NR_INTERRUPTS + 63) / 64]uint64
+}
+
+// segment is the expanded form of a segment register.
+//
+// This mirrors kvm_segment.
+type segment struct {
+ base uint64
+ limit uint32
+ selector uint16
+ typ uint8
+ present uint8
+ DPL uint8
+ DB uint8
+ S uint8
+ L uint8
+ G uint8
+ AVL uint8
+ unusable uint8
+ _ uint8
+}
+
+// Clear clears the segment and marks it unusable.
+func (s *segment) Clear() {
+ *s = segment{unusable: 1}
+}
+
+// selector is a segment selector.
+type selector uint16
+
+// tobool is a simple helper.
+func tobool(x ring0.SegmentDescriptorFlags) uint8 {
+ if x != 0 {
+ return 1
+ }
+ return 0
+}
+
+// Load loads the segment described by d into the segment s.
+//
+// The argument sel is recorded as the segment selector index.
+func (s *segment) Load(d *ring0.SegmentDescriptor, sel ring0.Selector) {
+ flag := d.Flags()
+ if flag&ring0.SegmentDescriptorPresent == 0 {
+ s.Clear()
+ return
+ }
+ s.base = uint64(d.Base())
+ s.limit = d.Limit()
+ s.typ = uint8((flag>>8)&0xF) | 1
+ s.S = tobool(flag & ring0.SegmentDescriptorSystem)
+ s.DPL = uint8(d.DPL())
+ s.present = tobool(flag & ring0.SegmentDescriptorPresent)
+ s.AVL = tobool(flag & ring0.SegmentDescriptorAVL)
+ s.L = tobool(flag & ring0.SegmentDescriptorLong)
+ s.DB = tobool(flag & ring0.SegmentDescriptorDB)
+ s.G = tobool(flag & ring0.SegmentDescriptorG)
+ if s.L != 0 {
+ s.limit = 0xffffffff
+ }
+ s.unusable = 0
+ s.selector = uint16(sel)
+}
+
+// descriptor describes a region of physical memory.
+//
+// It corresponds to the pseudo-descriptor used in the x86 LGDT and LIDT
+// instructions, and mirrors kvm_dtable.
+type descriptor struct {
+ base uint64
+ limit uint16
+ _ [3]uint16
+}
+
+// modelControlRegister is an MSR entry.
+//
+// This mirrors kvm_msr_entry.
+type modelControlRegister struct {
+ index uint32
+ _ uint32
+ data uint64
+}
+
+// modelControlRegisers is a collection of MSRs.
+//
+// This mirrors kvm_msrs.
+type modelControlRegisters struct {
+ nmsrs uint32
+ _ uint32
+ entries [16]modelControlRegister
+}
+
+// runData is the run structure. This may be mapped for synchronous register
+// access (although that doesn't appear to be supported by my kernel at least).
+//
+// This mirrors kvm_run.
+type runData struct {
+ requestInterruptWindow uint8
+ _ [7]uint8
+
+ exitReason uint32
+ readyForInterruptInjection uint8
+ ifFlag uint8
+ _ [2]uint8
+
+ cr8 uint64
+ apicBase uint64
+
+ // This is the union data for exits. Interpretation depends entirely on
+ // the exitReason above (see vCPU code for more information).
+ data [32]uint64
+}
+
+// cpuidEntry is a single CPUID entry.
+//
+// This mirrors kvm_cpuid_entry2.
+type cpuidEntry struct {
+ function uint32
+ index uint32
+ flags uint32
+ eax uint32
+ ebx uint32
+ ecx uint32
+ edx uint32
+ _ [3]uint32
+}
+
+// cpuidEntries is a collection of CPUID entries.
+//
+// This mirrors kvm_cpuid2.
+type cpuidEntries struct {
+ nr uint32
+ _ uint32
+ entries [_KVM_NR_CPUID_ENTRIES]cpuidEntry
+}
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_unsafe.go b/pkg/sentry/platform/kvm/kvm_amd64_unsafe.go
new file mode 100644
index 000000000..46c4b9113
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_amd64_unsafe.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "fmt"
+ "syscall"
+ "unsafe"
+)
+
+var (
+ runDataSize int
+ hasGuestPCID bool
+ cpuidSupported = cpuidEntries{nr: _KVM_NR_CPUID_ENTRIES}
+)
+
+func updateSystemValues(fd int) error {
+ // Extract the mmap size.
+ sz, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_GET_VCPU_MMAP_SIZE, 0)
+ if errno != 0 {
+ return fmt.Errorf("getting VCPU mmap size: %v", errno)
+ }
+
+ // Save the data.
+ runDataSize = int(sz)
+
+ // Must do the dance to figure out the number of entries.
+ _, _, errno = syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(fd),
+ _KVM_GET_SUPPORTED_CPUID,
+ uintptr(unsafe.Pointer(&cpuidSupported)))
+ if errno != 0 && errno != syscall.ENOMEM {
+ // Some other error occurred.
+ return fmt.Errorf("getting supported CPUID: %v", errno)
+ }
+
+ // The number should now be correct.
+ _, _, errno = syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(fd),
+ _KVM_GET_SUPPORTED_CPUID,
+ uintptr(unsafe.Pointer(&cpuidSupported)))
+ if errno != 0 {
+ // Didn't work with the right number.
+ return fmt.Errorf("getting supported CPUID (2nd attempt): %v", errno)
+ }
+
+ // Calculate whether guestPCID is supported.
+ //
+ // FIXME(ascannell): These should go through the much more pleasant
+ // cpuid package interfaces, once a way to accept raw kvm CPUID entries
+ // is plumbed (or some rough equivalent).
+ for i := 0; i < int(cpuidSupported.nr); i++ {
+ entry := cpuidSupported.entries[i]
+ if entry.function == 1 && entry.index == 0 && entry.ecx&(1<<17) != 0 {
+ hasGuestPCID = true // Found matching PCID in guest feature set.
+ }
+ }
+
+ // Success.
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
new file mode 100644
index 000000000..d05f05c29
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -0,0 +1,64 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+// KVM ioctls.
+//
+// Only the ioctls we need in Go appear here; some additional ioctls are used
+// within the assembly stubs (KVM_INTERRUPT, etc.).
+const (
+ _KVM_CREATE_VM = 0xae01
+ _KVM_GET_VCPU_MMAP_SIZE = 0xae04
+ _KVM_CREATE_VCPU = 0xae41
+ _KVM_SET_TSS_ADDR = 0xae47
+ _KVM_RUN = 0xae80
+ _KVM_NMI = 0xae9a
+ _KVM_CHECK_EXTENSION = 0xae03
+ _KVM_INTERRUPT = 0x4004ae86
+ _KVM_SET_MSRS = 0x4008ae89
+ _KVM_SET_USER_MEMORY_REGION = 0x4020ae46
+ _KVM_SET_REGS = 0x4090ae82
+ _KVM_SET_SREGS = 0x4138ae84
+ _KVM_GET_REGS = 0x8090ae81
+ _KVM_GET_SUPPORTED_CPUID = 0xc008ae05
+ _KVM_SET_CPUID2 = 0x4008ae90
+ _KVM_SET_SIGNAL_MASK = 0x4004ae8b
+)
+
+// KVM exit reasons.
+const (
+ _KVM_EXIT_EXCEPTION = 0x1
+ _KVM_EXIT_IO = 0x2
+ _KVM_EXIT_HYPERCALL = 0x3
+ _KVM_EXIT_DEBUG = 0x4
+ _KVM_EXIT_HLT = 0x5
+ _KVM_EXIT_MMIO = 0x6
+ _KVM_EXIT_IRQ_WINDOW_OPEN = 0x7
+ _KVM_EXIT_SHUTDOWN = 0x8
+ _KVM_EXIT_FAIL_ENTRY = 0x9
+ _KVM_EXIT_INTERNAL_ERROR = 0x11
+)
+
+// KVM capability options.
+const (
+ _KVM_CAP_MAX_VCPUS = 0x42
+)
+
+// KVM limits.
+const (
+ _KVM_NR_VCPUS = 0xff
+ _KVM_NR_INTERRUPTS = 0x100
+ _KVM_NR_CPUID_ENTRIES = 0x100
+)
diff --git a/pkg/sentry/platform/kvm/kvm_state_autogen.go b/pkg/sentry/platform/kvm/kvm_state_autogen.go
new file mode 100755
index 000000000..5ab0e0735
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package kvm
+
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
new file mode 100644
index 000000000..f5953b96e
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -0,0 +1,525 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "fmt"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/atomicbitops"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// machine contains state associated with the VM as a whole.
+type machine struct {
+ // fd is the vm fd.
+ fd int
+
+ // nextSlot is the next slot for setMemoryRegion.
+ //
+ // This must be accessed atomically. If nextSlot is ^uint32(0), then
+ // slots are currently being updated, and the caller should retry.
+ nextSlot uint32
+
+ // kernel is the set of global structures.
+ kernel ring0.Kernel
+
+ // mappingCache is used for mapPhysical.
+ mappingCache sync.Map
+
+ // mu protects vCPUs.
+ mu sync.RWMutex
+
+ // available is notified when vCPUs are available.
+ available sync.Cond
+
+ // vCPUs are the machine vCPUs.
+ //
+ // These are populated dynamically.
+ vCPUs map[uint64]*vCPU
+
+ // vCPUsByID are the machine vCPUs, can be indexed by the vCPU's ID.
+ vCPUsByID map[int]*vCPU
+
+ // maxVCPUs is the maximum number of vCPUs supported by the machine.
+ maxVCPUs int
+}
+
+const (
+ // vCPUReady is an alias for all the below clear.
+ vCPUReady uint32 = 0
+
+ // vCPUser indicates that the vCPU is in or about to enter user mode.
+ vCPUUser uint32 = 1 << 0
+
+ // vCPUGuest indicates the vCPU is in guest mode.
+ vCPUGuest uint32 = 1 << 1
+
+ // vCPUWaiter indicates that there is a waiter.
+ //
+ // If this is set, then notify must be called on any state transitions.
+ vCPUWaiter uint32 = 1 << 2
+)
+
+// vCPU is a single KVM vCPU.
+type vCPU struct {
+ // CPU is the kernel CPU data.
+ //
+ // This must be the first element of this structure, it is referenced
+ // by the bluepill code (see bluepill_amd64.s).
+ ring0.CPU
+
+ // id is the vCPU id.
+ id int
+
+ // fd is the vCPU fd.
+ fd int
+
+ // tid is the last set tid.
+ tid uint64
+
+ // switches is a count of world switches (informational only).
+ switches uint32
+
+ // faults is a count of world faults (informational only).
+ faults uint32
+
+ // state is the vCPU state.
+ //
+ // This is a bitmask of the three fields (vCPU*) described above.
+ state uint32
+
+ // runData for this vCPU.
+ runData *runData
+
+ // machine associated with this vCPU.
+ machine *machine
+
+ // active is the current addressSpace: this is set and read atomically,
+ // it is used to elide unnecessary interrupts due to invalidations.
+ active atomicAddressSpace
+
+ // vCPUArchState is the architecture-specific state.
+ vCPUArchState
+
+ dieState dieState
+}
+
+type dieState struct {
+ // message is thrown from die.
+ message string
+
+ // guestRegs is used to store register state during vCPU.die() to prevent
+ // allocation inside nosplit function.
+ guestRegs userRegs
+}
+
+// newVCPU creates a returns a new vCPU.
+//
+// Precondtion: mu must be held.
+func (m *machine) newVCPU() *vCPU {
+ id := len(m.vCPUs)
+
+ // Create the vCPU.
+ fd, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CREATE_VCPU, uintptr(id))
+ if errno != 0 {
+ panic(fmt.Sprintf("error creating new vCPU: %v", errno))
+ }
+
+ c := &vCPU{
+ id: id,
+ fd: int(fd),
+ machine: m,
+ }
+ c.CPU.Init(&m.kernel, c)
+ m.vCPUsByID[c.id] = c
+
+ // Ensure the signal mask is correct.
+ if err := c.setSignalMask(); err != nil {
+ panic(fmt.Sprintf("error setting signal mask: %v", err))
+ }
+
+ // Map the run data.
+ runData, err := mapRunData(int(fd))
+ if err != nil {
+ panic(fmt.Sprintf("error mapping run data: %v", err))
+ }
+ c.runData = runData
+
+ // Initialize architecture state.
+ if err := c.initArchState(); err != nil {
+ panic(fmt.Sprintf("error initialization vCPU state: %v", err))
+ }
+
+ return c // Done.
+}
+
+// newMachine returns a new VM context.
+func newMachine(vm int) (*machine, error) {
+ // Create the machine.
+ m := &machine{
+ fd: vm,
+ vCPUs: make(map[uint64]*vCPU),
+ vCPUsByID: make(map[int]*vCPU),
+ }
+ m.available.L = &m.mu
+ m.kernel.Init(ring0.KernelOpts{
+ PageTables: pagetables.New(newAllocator()),
+ })
+
+ maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS)
+ if errno != 0 {
+ m.maxVCPUs = _KVM_NR_VCPUS
+ } else {
+ m.maxVCPUs = int(maxVCPUs)
+ }
+ log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
+
+ // Apply the physical mappings. Note that these mappings may point to
+ // guest physical addresses that are not actually available. These
+ // physical pages are mapped on demand, see kernel_unsafe.go.
+ applyPhysicalRegions(func(pr physicalRegion) bool {
+ // Map everything in the lower half.
+ m.kernel.PageTables.Map(
+ usermem.Addr(pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+
+ // And keep everything in the upper half.
+ m.kernel.PageTables.Map(
+ usermem.Addr(ring0.KernelStartAddress|pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+
+ return true // Keep iterating.
+ })
+
+ // Ensure that the currently mapped virtual regions are actually
+ // available in the VM. Note that this doesn't guarantee no future
+ // faults, however it should guarantee that everything is available to
+ // ensure successful vCPU entry.
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return // skip region.
+ }
+ for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
+ physical, length, ok := translateToPhysical(virtual)
+ if !ok {
+ // This must be an invalid region that was
+ // knocked out by creation of the physical map.
+ return
+ }
+ if virtual+length > vr.virtual+vr.length {
+ // Cap the length to the end of the area.
+ length = vr.virtual + vr.length - virtual
+ }
+
+ // Ensure the physical range is mapped.
+ m.mapPhysical(physical, length)
+ virtual += length
+ }
+ })
+
+ // Initialize architecture state.
+ if err := m.initArchState(); err != nil {
+ m.Destroy()
+ return nil, err
+ }
+
+ // Ensure the machine is cleaned up properly.
+ runtime.SetFinalizer(m, (*machine).Destroy)
+ return m, nil
+}
+
+// mapPhysical checks for the mapping of a physical range, and installs one if
+// not available. This attempts to be efficient for calls in the hot path.
+//
+// This panics on error.
+func (m *machine) mapPhysical(physical, length uintptr) {
+ for end := physical + length; physical < end; {
+ _, physicalStart, length, ok := calculateBluepillFault(physical)
+ if !ok {
+ // Should never happen.
+ panic("mapPhysical on unknown physical address")
+ }
+
+ if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok {
+ // Not present in the cache; requires setting the slot.
+ if _, ok := handleBluepillFault(m, physical); !ok {
+ panic("handleBluepillFault failed")
+ }
+ }
+
+ // Move to the next chunk.
+ physical = physicalStart + length
+ }
+}
+
+// Destroy frees associated resources.
+//
+// Destroy should only be called once all active users of the machine are gone.
+// The machine object should not be used after calling Destroy.
+//
+// Precondition: all vCPUs must be returned to the machine.
+func (m *machine) Destroy() {
+ runtime.SetFinalizer(m, nil)
+
+ // Destroy vCPUs.
+ for _, c := range m.vCPUs {
+ // Ensure the vCPU is not still running in guest mode. This is
+ // possible iff teardown has been done by other threads, and
+ // somehow a single thread has not executed any system calls.
+ c.BounceToHost()
+
+ // Note that the runData may not be mapped if an error occurs
+ // during the middle of initialization.
+ if c.runData != nil {
+ if err := unmapRunData(c.runData); err != nil {
+ panic(fmt.Sprintf("error unmapping rundata: %v", err))
+ }
+ }
+ if err := syscall.Close(int(c.fd)); err != nil {
+ panic(fmt.Sprintf("error closing vCPU fd: %v", err))
+ }
+ }
+
+ // vCPUs are gone: teardown machine state.
+ if err := syscall.Close(m.fd); err != nil {
+ panic(fmt.Sprintf("error closing VM fd: %v", err))
+ }
+}
+
+// Get gets an available vCPU.
+func (m *machine) Get() *vCPU {
+ runtime.LockOSThread()
+ tid := procid.Current()
+ m.mu.RLock()
+
+ // Check for an exact match.
+ if c := m.vCPUs[tid]; c != nil {
+ c.lock()
+ m.mu.RUnlock()
+ return c
+ }
+
+ // The happy path failed. We now proceed to acquire an exclusive lock
+ // (because the vCPU map may change), and scan all available vCPUs.
+ m.mu.RUnlock()
+ m.mu.Lock()
+
+ for {
+ // Scan for an available vCPU.
+ for origTID, c := range m.vCPUs {
+ if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) {
+ delete(m.vCPUs, origTID)
+ m.vCPUs[tid] = c
+ m.mu.Unlock()
+ c.loadSegments(tid)
+ return c
+ }
+ }
+
+ // Create a new vCPU (maybe).
+ if len(m.vCPUs) < m.maxVCPUs {
+ c := m.newVCPU()
+ c.lock()
+ m.vCPUs[tid] = c
+ m.mu.Unlock()
+ c.loadSegments(tid)
+ return c
+ }
+
+ // Scan for something not in user mode.
+ for origTID, c := range m.vCPUs {
+ if !atomic.CompareAndSwapUint32(&c.state, vCPUGuest, vCPUGuest|vCPUWaiter) {
+ continue
+ }
+
+ // The vCPU is not be able to transition to
+ // vCPUGuest|vCPUUser or to vCPUUser because that
+ // transition requires holding the machine mutex, as we
+ // do now. There is no path to register a waiter on
+ // just the vCPUReady state.
+ for {
+ c.waitUntilNot(vCPUGuest | vCPUWaiter)
+ if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) {
+ break
+ }
+ }
+
+ // Steal the vCPU.
+ delete(m.vCPUs, origTID)
+ m.vCPUs[tid] = c
+ m.mu.Unlock()
+ c.loadSegments(tid)
+ return c
+ }
+
+ // Everything is executing in user mode. Wait until something
+ // is available. Note that signaling the condition variable
+ // will have the extra effect of kicking the vCPUs out of guest
+ // mode if that's where they were.
+ m.available.Wait()
+ }
+}
+
+// Put puts the current vCPU.
+func (m *machine) Put(c *vCPU) {
+ c.unlock()
+ runtime.UnlockOSThread()
+ m.available.Signal()
+}
+
+// newDirtySet returns a new dirty set.
+func (m *machine) newDirtySet() *dirtySet {
+ return &dirtySet{
+ vCPUs: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
+ }
+}
+
+// lock marks the vCPU as in user mode.
+//
+// This should only be called directly when known to be safe, i.e. when
+// the vCPU is owned by the current TID with no chance of theft.
+//
+//go:nosplit
+func (c *vCPU) lock() {
+ atomicbitops.OrUint32(&c.state, vCPUUser)
+}
+
+// unlock clears the vCPUUser bit.
+//
+//go:nosplit
+func (c *vCPU) unlock() {
+ if atomic.CompareAndSwapUint32(&c.state, vCPUUser|vCPUGuest, vCPUGuest) {
+ // Happy path: no exits are forced, and we can continue
+ // executing on our merry way with a single atomic access.
+ return
+ }
+
+ // Clear the lock.
+ origState := atomic.LoadUint32(&c.state)
+ atomicbitops.AndUint32(&c.state, ^vCPUUser)
+ switch origState {
+ case vCPUUser:
+ // Normal state.
+ case vCPUUser | vCPUGuest | vCPUWaiter:
+ // Force a transition: this must trigger a notification when we
+ // return from guest mode.
+ c.notify()
+ case vCPUUser | vCPUWaiter:
+ // Waiting for the lock to be released; the responsibility is
+ // on us to notify the waiter and clear the associated bit.
+ atomicbitops.AndUint32(&c.state, ^vCPUWaiter)
+ c.notify()
+ default:
+ panic("invalid state")
+ }
+}
+
+// NotifyInterrupt implements interrupt.Receiver.NotifyInterrupt.
+//
+//go:nosplit
+func (c *vCPU) NotifyInterrupt() {
+ c.BounceToKernel()
+}
+
+// pid is used below in bounce.
+var pid = syscall.Getpid()
+
+// bounce forces a return to the kernel or to host mode.
+//
+// This effectively unwinds the state machine.
+func (c *vCPU) bounce(forceGuestExit bool) {
+ for {
+ switch state := atomic.LoadUint32(&c.state); state {
+ case vCPUReady, vCPUWaiter:
+ // There is nothing to be done, we're already in the
+ // kernel pre-acquisition. The Bounce criteria have
+ // been satisfied.
+ return
+ case vCPUUser:
+ // We need to register a waiter for the actual guest
+ // transition. When the transition takes place, then we
+ // can inject an interrupt to ensure a return to host
+ // mode.
+ atomic.CompareAndSwapUint32(&c.state, state, state|vCPUWaiter)
+ case vCPUUser | vCPUWaiter:
+ // Wait for the transition to guest mode. This should
+ // come from the bluepill handler.
+ c.waitUntilNot(state)
+ case vCPUGuest, vCPUUser | vCPUGuest:
+ if state == vCPUGuest && !forceGuestExit {
+ // The vCPU is already not acquired, so there's
+ // no need to do a fresh injection here.
+ return
+ }
+ // The vCPU is in user or kernel mode. Attempt to
+ // register a notification on change.
+ if !atomic.CompareAndSwapUint32(&c.state, state, state|vCPUWaiter) {
+ break // Retry.
+ }
+ for {
+ // We need to spin here until the signal is
+ // delivered, because Tgkill can return EAGAIN
+ // under memory pressure. Since we already
+ // marked ourselves as a waiter, we need to
+ // ensure that a signal is actually delivered.
+ if err := syscall.Tgkill(pid, int(atomic.LoadUint64(&c.tid)), bounceSignal); err == nil {
+ break
+ } else if err.(syscall.Errno) == syscall.EAGAIN {
+ continue
+ } else {
+ // Nothing else should be returned by tgkill.
+ panic(fmt.Sprintf("unexpected tgkill error: %v", err))
+ }
+ }
+ case vCPUGuest | vCPUWaiter, vCPUUser | vCPUGuest | vCPUWaiter:
+ if state == vCPUGuest|vCPUWaiter && !forceGuestExit {
+ // See above.
+ return
+ }
+ // Wait for the transition. This again should happen
+ // from the bluepill handler, but on the way out.
+ c.waitUntilNot(state)
+ default:
+ // Should not happen: the above is exhaustive.
+ panic("invalid state")
+ }
+ }
+}
+
+// BounceToKernel ensures that the vCPU bounces back to the kernel.
+//
+//go:nosplit
+func (c *vCPU) BounceToKernel() {
+ c.bounce(false)
+}
+
+// BounceToHost ensures that the vCPU is in host mode.
+//
+//go:nosplit
+func (c *vCPU) BounceToHost() {
+ c.bounce(true)
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
new file mode 100644
index 000000000..b6821122a
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -0,0 +1,357 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "fmt"
+ "reflect"
+ "runtime/debug"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// initArchState initializes architecture-specific state.
+func (m *machine) initArchState() error {
+ // Set the legacy TSS address. This address is covered by the reserved
+ // range (up to 4GB). In fact, this is a main reason it exists.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_SET_TSS_ADDR,
+ uintptr(reservedMemory-(3*usermem.PageSize))); errno != 0 {
+ return errno
+ }
+
+ // Enable CPUID faulting, if possible. Note that this also serves as a
+ // basic platform sanity tests, since we will enter guest mode for the
+ // first time here. The recovery is necessary, since if we fail to read
+ // the platform info register, we will retry to host mode and
+ // ultimately need to handle a segmentation fault.
+ old := debug.SetPanicOnFault(true)
+ defer func() {
+ recover()
+ debug.SetPanicOnFault(old)
+ }()
+ m.retryInGuest(func() {
+ ring0.SetCPUIDFaulting(true)
+ })
+
+ return nil
+}
+
+type vCPUArchState struct {
+ // PCIDs is the set of PCIDs for this vCPU.
+ //
+ // This starts above fixedKernelPCID.
+ PCIDs *pagetables.PCIDs
+
+ // floatingPointState is the floating point state buffer used in guest
+ // to host transitions. See usage in bluepill_amd64.go.
+ floatingPointState *arch.FloatingPointData
+}
+
+const (
+ // fixedKernelPCID is a fixed kernel PCID used for the kernel page
+ // tables. We must start allocating user PCIDs above this in order to
+ // avoid any conflict (see below).
+ fixedKernelPCID = 1
+
+ // poolPCIDs is the number of PCIDs to record in the database. As this
+ // grows, assignment can take longer, since it is a simple linear scan.
+ // Beyond a relatively small number, there are likely few perform
+ // benefits, since the TLB has likely long since lost any translations
+ // from more than a few PCIDs past.
+ poolPCIDs = 8
+)
+
+// dropPageTables drops cached page table entries.
+func (m *machine) dropPageTables(pt *pagetables.PageTables) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ // Clear from all PCIDs.
+ for _, c := range m.vCPUs {
+ c.PCIDs.Drop(pt)
+ }
+}
+
+// initArchState initializes architecture-specific state.
+func (c *vCPU) initArchState() error {
+ var (
+ kernelSystemRegs systemRegs
+ kernelUserRegs userRegs
+ )
+
+ // Set base control registers.
+ kernelSystemRegs.CR0 = c.CR0()
+ kernelSystemRegs.CR4 = c.CR4()
+ kernelSystemRegs.EFER = c.EFER()
+
+ // Set the IDT & GDT in the registers.
+ kernelSystemRegs.IDT.base, kernelSystemRegs.IDT.limit = c.IDT()
+ kernelSystemRegs.GDT.base, kernelSystemRegs.GDT.limit = c.GDT()
+ kernelSystemRegs.CS.Load(&ring0.KernelCodeSegment, ring0.Kcode)
+ kernelSystemRegs.DS.Load(&ring0.UserDataSegment, ring0.Udata)
+ kernelSystemRegs.ES.Load(&ring0.UserDataSegment, ring0.Udata)
+ kernelSystemRegs.SS.Load(&ring0.KernelDataSegment, ring0.Kdata)
+ kernelSystemRegs.FS.Load(&ring0.UserDataSegment, ring0.Udata)
+ kernelSystemRegs.GS.Load(&ring0.UserDataSegment, ring0.Udata)
+ tssBase, tssLimit, tss := c.TSS()
+ kernelSystemRegs.TR.Load(tss, ring0.Tss)
+ kernelSystemRegs.TR.base = tssBase
+ kernelSystemRegs.TR.limit = uint32(tssLimit)
+
+ // Point to kernel page tables, with no initial PCID.
+ kernelSystemRegs.CR3 = c.machine.kernel.PageTables.CR3(false, 0)
+
+ // Initialize the PCID database.
+ if hasGuestPCID {
+ // Note that NewPCIDs may return a nil table here, in which
+ // case we simply don't use PCID support (see below). In
+ // practice, this should not happen, however.
+ c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
+ }
+
+ // Set the CPUID; this is required before setting system registers,
+ // since KVM will reject several CR4 bits if the CPUID does not
+ // indicate the support is available.
+ if err := c.setCPUID(); err != nil {
+ return err
+ }
+
+ // Set the entrypoint for the kernel.
+ kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer())
+ kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer())
+ kernelUserRegs.RFLAGS = ring0.KernelFlagsSet
+
+ // Set the system registers.
+ if err := c.setSystemRegisters(&kernelSystemRegs); err != nil {
+ return err
+ }
+
+ // Set the user registers.
+ if err := c.setUserRegisters(&kernelUserRegs); err != nil {
+ return err
+ }
+
+ // Allocate some floating point state save area for the local vCPU.
+ // This will be saved prior to leaving the guest, and we restore from
+ // this always. We cannot use the pointer in the context alone because
+ // we don't know how large the area there is in reality.
+ c.floatingPointState = arch.NewFloatingPointData()
+
+ // Set the time offset to the host native time.
+ return c.setSystemTime()
+}
+
+// nonCanonical generates a canonical address return.
+//
+//go:nosplit
+func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ *info = arch.SignalInfo{
+ Signo: signal,
+ Code: arch.SignalInfoKernel,
+ }
+ info.SetAddr(addr) // Include address.
+ return usermem.NoAccess, platform.ErrContextSignal
+}
+
+// fault generates an appropriate fault return.
+//
+//go:nosplit
+func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ bluepill(c) // Probably no-op, but may not be.
+ faultAddr := ring0.ReadCR2()
+ code, user := c.ErrorCode()
+ if !user {
+ // The last fault serviced by this CPU was not a user
+ // fault, so we can't reliably trust the faultAddr or
+ // the code provided here. We need to re-execute.
+ return usermem.NoAccess, platform.ErrContextInterrupt
+ }
+ // Reset the pointed SignalInfo.
+ *info = arch.SignalInfo{Signo: signal}
+ info.SetAddr(uint64(faultAddr))
+ accessType := usermem.AccessType{
+ Read: code&(1<<1) == 0,
+ Write: code&(1<<1) != 0,
+ Execute: code&(1<<4) != 0,
+ }
+ if !accessType.Write && !accessType.Execute {
+ info.Code = 1 // SEGV_MAPERR.
+ } else {
+ info.Code = 2 // SEGV_ACCERR.
+ }
+ return accessType, platform.ErrContextSignal
+}
+
+// SwitchToUser unpacks architectural-details.
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
+ // Check for canonical addresses.
+ if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Rip) {
+ return nonCanonical(regs.Rip, int32(syscall.SIGSEGV), info)
+ } else if !ring0.IsCanonical(regs.Rsp) {
+ return nonCanonical(regs.Rsp, int32(syscall.SIGBUS), info)
+ } else if !ring0.IsCanonical(regs.Fs_base) {
+ return nonCanonical(regs.Fs_base, int32(syscall.SIGBUS), info)
+ } else if !ring0.IsCanonical(regs.Gs_base) {
+ return nonCanonical(regs.Gs_base, int32(syscall.SIGBUS), info)
+ }
+
+ // Assign PCIDs.
+ if c.PCIDs != nil {
+ var requireFlushPCID bool // Force a flush?
+ switchOpts.UserPCID, requireFlushPCID = c.PCIDs.Assign(switchOpts.PageTables)
+ switchOpts.KernelPCID = fixedKernelPCID
+ switchOpts.Flush = switchOpts.Flush || requireFlushPCID
+ }
+
+ // See below.
+ var vector ring0.Vector
+
+ // Past this point, stack growth can cause system calls (and a break
+ // from guest mode). So we need to ensure that between the bluepill
+ // call here and the switch call immediately below, no additional
+ // allocations occur.
+ entersyscall()
+ bluepill(c)
+ vector = c.CPU.SwitchToUser(switchOpts)
+ exitsyscall()
+
+ switch vector {
+ case ring0.Syscall, ring0.SyscallInt80:
+ // Fast path: system call executed.
+ return usermem.NoAccess, nil
+
+ case ring0.PageFault:
+ return c.fault(int32(syscall.SIGSEGV), info)
+
+ case ring0.Debug, ring0.Breakpoint:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGTRAP),
+ Code: 1, // TRAP_BRKPT (breakpoint).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.GeneralProtectionFault,
+ ring0.SegmentNotPresent,
+ ring0.BoundRangeExceeded,
+ ring0.InvalidTSS,
+ ring0.StackSegmentFault:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGSEGV),
+ Code: arch.SignalInfoKernel,
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ if vector == ring0.GeneralProtectionFault {
+ // When CPUID faulting is enabled, we will generate a #GP(0) when
+ // userspace executes a CPUID instruction. This is handled above,
+ // because we need to be able to map and read user memory.
+ return usermem.AccessType{}, platform.ErrContextSignalCPUID
+ }
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.InvalidOpcode:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGILL),
+ Code: 1, // ILL_ILLOPC (illegal opcode).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.DivideByZero:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGFPE),
+ Code: 1, // FPE_INTDIV (divide by zero).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.Overflow:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGFPE),
+ Code: 2, // FPE_INTOVF (integer overflow).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.X87FloatingPointException,
+ ring0.SIMDFloatingPointException:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGFPE),
+ Code: 7, // FPE_FLTINV (invalid operation).
+ }
+ info.SetAddr(switchOpts.Registers.Rip) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+
+ case ring0.Vector(bounce): // ring0.VirtualizationException
+ return usermem.NoAccess, platform.ErrContextInterrupt
+
+ case ring0.AlignmentCheck:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGBUS),
+ Code: 2, // BUS_ADRERR (physical address does not exist).
+ }
+ return usermem.NoAccess, platform.ErrContextSignal
+
+ case ring0.NMI:
+ // An NMI is generated only when a fault is not servicable by
+ // KVM itself, so we think some mapping is writeable but it's
+ // really not. This could happen, e.g. if some file is
+ // truncated (and would generate a SIGBUS) and we map it
+ // directly into the instance.
+ return c.fault(int32(syscall.SIGBUS), info)
+
+ case ring0.DeviceNotAvailable,
+ ring0.DoubleFault,
+ ring0.CoprocessorSegmentOverrun,
+ ring0.MachineCheck,
+ ring0.SecurityException:
+ fallthrough
+ default:
+ panic(fmt.Sprintf("unexpected vector: 0x%x", vector))
+ }
+}
+
+// retryInGuest runs the given function in guest mode.
+//
+// If the function does not complete in guest mode (due to execution of a
+// system call due to a GC stall, for example), then it will be retried. The
+// given function must be idempotent as a result of the retry mechanism.
+func (m *machine) retryInGuest(fn func()) {
+ c := m.Get()
+ defer m.Put(c)
+ for {
+ c.ClearErrorCode() // See below.
+ bluepill(c) // Force guest mode.
+ fn() // Execute the given function.
+ _, user := c.ErrorCode()
+ if user {
+ // If user is set, then we haven't bailed back to host
+ // mode via a kernel exception or system call. We
+ // consider the full function to have executed in guest
+ // mode and we can return.
+ break
+ }
+ }
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
new file mode 100644
index 000000000..06a2e3b0c
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -0,0 +1,161 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/time"
+)
+
+// setMemoryRegion initializes a region.
+//
+// This may be called from bluepillHandler, and therefore returns an errno
+// directly (instead of wrapping in an error) to avoid allocations.
+//
+//go:nosplit
+func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno {
+ userRegion := userMemoryRegion{
+ slot: uint32(slot),
+ flags: 0,
+ guestPhysAddr: uint64(physical),
+ memorySize: uint64(length),
+ userspaceAddr: uint64(virtual),
+ }
+
+ // Set the region.
+ _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_SET_USER_MEMORY_REGION,
+ uintptr(unsafe.Pointer(&userRegion)))
+ return errno
+}
+
+// loadSegments copies the current segments.
+//
+// This may be called from within the signal context and throws on error.
+//
+//go:nosplit
+func (c *vCPU) loadSegments(tid uint64) {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_ARCH_PRCTL,
+ linux.ARCH_GET_FS,
+ uintptr(unsafe.Pointer(&c.CPU.Registers().Fs_base)),
+ 0); errno != 0 {
+ throw("getting FS segment")
+ }
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_ARCH_PRCTL,
+ linux.ARCH_GET_GS,
+ uintptr(unsafe.Pointer(&c.CPU.Registers().Gs_base)),
+ 0); errno != 0 {
+ throw("getting GS segment")
+ }
+ atomic.StoreUint64(&c.tid, tid)
+}
+
+// setCPUID sets the CPUID to be used by the guest.
+func (c *vCPU) setCPUID() error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_CPUID2,
+ uintptr(unsafe.Pointer(&cpuidSupported))); errno != 0 {
+ return fmt.Errorf("error setting CPUID: %v", errno)
+ }
+ return nil
+}
+
+// setSystemTime sets the TSC for the vCPU.
+//
+// This has to make the call many times in order to minimize the intrinstic
+// error in the offset. Unfortunately KVM does not expose a relative offset via
+// the API, so this is an approximation. We do this via an iterative algorithm.
+// This has the advantage that it can generally deal with highly variable
+// system call times and should converge on the correct offset.
+func (c *vCPU) setSystemTime() error {
+ const (
+ _MSR_IA32_TSC = 0x00000010
+ calibrateTries = 10
+ )
+ registers := modelControlRegisters{
+ nmsrs: 1,
+ }
+ registers.entries[0] = modelControlRegister{
+ index: _MSR_IA32_TSC,
+ }
+ target := uint64(^uint32(0))
+ for done := 0; done < calibrateTries; {
+ start := uint64(time.Rdtsc())
+ registers.entries[0].data = start + target
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_MSRS,
+ uintptr(unsafe.Pointer(&registers))); errno != 0 {
+ return fmt.Errorf("error setting system time: %v", errno)
+ }
+ // See if this is our new minimum call time. Note that this
+ // serves two functions: one, we make sure that we are
+ // accurately predicting the offset we need to set. Second, we
+ // don't want to do the final set on a slow call, which could
+ // produce a really bad result. So we only count attempts
+ // within +/- 6.25% of our minimum as an attempt.
+ end := uint64(time.Rdtsc())
+ if end < start {
+ continue // Totally bogus.
+ }
+ half := (end - start) / 2
+ if half < target {
+ target = half
+ }
+ if (half - target) < target/8 {
+ done++
+ }
+ }
+ return nil
+}
+
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The layout of this structure implies that it will not necessarily be
+ // the same layout chosen by the Go compiler. It gets fudged here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
new file mode 100644
index 000000000..1d3c6d2d6
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -0,0 +1,160 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.12
+// +build !go1.14
+
+// Check go:linkname function signatures when updating Go version.
+
+package kvm
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+//go:linkname entersyscall runtime.entersyscall
+func entersyscall()
+
+//go:linkname exitsyscall runtime.exitsyscall
+func exitsyscall()
+
+// mapRunData maps the vCPU run data.
+func mapRunData(fd int) (*runData, error) {
+ r, _, errno := syscall.RawSyscall6(
+ syscall.SYS_MMAP,
+ 0,
+ uintptr(runDataSize),
+ syscall.PROT_READ|syscall.PROT_WRITE,
+ syscall.MAP_SHARED,
+ uintptr(fd),
+ 0)
+ if errno != 0 {
+ return nil, fmt.Errorf("error mapping runData: %v", errno)
+ }
+ return (*runData)(unsafe.Pointer(r)), nil
+}
+
+// unmapRunData unmaps the vCPU run data.
+func unmapRunData(r *runData) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_MUNMAP,
+ uintptr(unsafe.Pointer(r)),
+ uintptr(runDataSize),
+ 0); errno != 0 {
+ return fmt.Errorf("error unmapping runData: %v", errno)
+ }
+ return nil
+}
+
+// setUserRegisters sets user registers in the vCPU.
+func (c *vCPU) setUserRegisters(uregs *userRegs) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_REGS,
+ uintptr(unsafe.Pointer(uregs))); errno != 0 {
+ return fmt.Errorf("error setting user registers: %v", errno)
+ }
+ return nil
+}
+
+// getUserRegisters reloads user registers in the vCPU.
+//
+// This is safe to call from a nosplit context.
+//
+//go:nosplit
+func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_REGS,
+ uintptr(unsafe.Pointer(uregs))); errno != 0 {
+ return errno
+ }
+ return 0
+}
+
+// setSystemRegisters sets system registers.
+func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SREGS,
+ uintptr(unsafe.Pointer(sregs))); errno != 0 {
+ return fmt.Errorf("error setting system registers: %v", errno)
+ }
+ return nil
+}
+
+// atomicAddressSpace is an atomic address space pointer.
+type atomicAddressSpace struct {
+ pointer unsafe.Pointer
+}
+
+// set sets the address space value.
+//
+//go:nosplit
+func (a *atomicAddressSpace) set(as *addressSpace) {
+ atomic.StorePointer(&a.pointer, unsafe.Pointer(as))
+}
+
+// get gets the address space value.
+//
+// Note that this should be considered best-effort, and may have changed by the
+// time this function returns.
+//
+//go:nosplit
+func (a *atomicAddressSpace) get() *addressSpace {
+ return (*addressSpace)(atomic.LoadPointer(&a.pointer))
+}
+
+// notify notifies that the vCPU has transitioned modes.
+//
+// This may be called by a signal handler and therefore throws on error.
+//
+//go:nosplit
+func (c *vCPU) notify() {
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_FUTEX,
+ uintptr(unsafe.Pointer(&c.state)),
+ linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG,
+ ^uintptr(0), // Number of waiters.
+ 0, 0, 0)
+ if errno != 0 {
+ throw("futex wake error")
+ }
+}
+
+// waitUntilNot waits for the vCPU to transition modes.
+//
+// The state should have been previously set to vCPUWaiter after performing an
+// appropriate action to cause a transition (e.g. interrupt injection).
+//
+// This panics on error.
+func (c *vCPU) waitUntilNot(state uint32) {
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_FUTEX,
+ uintptr(unsafe.Pointer(&c.state)),
+ linux.FUTEX_WAIT|linux.FUTEX_PRIVATE_FLAG,
+ uintptr(state),
+ 0, 0, 0)
+ if errno != 0 && errno != syscall.EINTR && errno != syscall.EAGAIN {
+ panic("futex wait error")
+ }
+}
diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go
new file mode 100644
index 000000000..450eb8201
--- /dev/null
+++ b/pkg/sentry/platform/kvm/physical_map.go
@@ -0,0 +1,224 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "fmt"
+ "sort"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+const (
+ // reservedMemory is a chunk of physical memory reserved starting at
+ // physical address zero. There are some special pages in this region,
+ // so we just call the whole thing off.
+ //
+ // Other architectures may define this to be zero.
+ reservedMemory = 0x100000000
+)
+
+type region struct {
+ virtual uintptr
+ length uintptr
+}
+
+type physicalRegion struct {
+ region
+ physical uintptr
+}
+
+// physicalRegions contains a list of available physical regions.
+//
+// The physical value used in physicalRegions is a number indicating the
+// physical offset, aligned appropriately and starting above reservedMemory.
+var physicalRegions []physicalRegion
+
+// fillAddressSpace fills the host address space with PROT_NONE mappings until
+// we have a host address space size that is less than or equal to the physical
+// address space. This allows us to have an injective host virtual to guest
+// physical mapping.
+//
+// The excluded regions are returned.
+func fillAddressSpace() (excludedRegions []region) {
+ // We can cut vSize in half, because the kernel will be using the top
+ // half and we ignore it while constructing mappings. It's as if we've
+ // already excluded half the possible addresses.
+ vSize := uintptr(1) << ring0.VirtualAddressBits()
+ vSize = vSize >> 1
+
+ // We exclude reservedMemory below from our physical memory size, so it
+ // needs to be dropped here as well. Otherwise, we could end up with
+ // physical addresses that are beyond what is mapped.
+ pSize := uintptr(1) << ring0.PhysicalAddressBits()
+ pSize -= reservedMemory
+
+ // Add specifically excluded regions; see excludeVirtualRegion.
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ excludedRegions = append(excludedRegions, vr.region)
+ vSize -= vr.length
+ log.Infof("excluded: virtual [%x,%x)", vr.virtual, vr.virtual+vr.length)
+ }
+ })
+
+ // Do we need any more work?
+ if vSize < pSize {
+ return excludedRegions
+ }
+
+ // Calculate the required space and fill it.
+ //
+ // Note carefully that we add faultBlockSize to required up front, and
+ // on each iteration of the loop below (i.e. each new physical region
+ // we define), we add faultBlockSize again. This is done because the
+ // computation of physical regions will ensure proper alignments with
+ // faultBlockSize, potentially causing up to faultBlockSize bytes in
+ // internal fragmentation for each physical region. So we need to
+ // account for this properly during allocation.
+ requiredAddr, ok := usermem.Addr(vSize - pSize + faultBlockSize).RoundUp()
+ if !ok {
+ panic(fmt.Sprintf(
+ "overflow for vSize (%x) - pSize (%x) + faultBlockSize (%x)",
+ vSize, pSize, faultBlockSize))
+ }
+ required := uintptr(requiredAddr)
+ current := required // Attempted mmap size.
+ for filled := uintptr(0); filled < required && current > 0; {
+ addr, _, errno := syscall.RawSyscall6(
+ syscall.SYS_MMAP,
+ 0, // Suggested address.
+ current,
+ syscall.PROT_NONE,
+ syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE|syscall.MAP_NORESERVE,
+ 0, 0)
+ if errno != 0 {
+ // Attempt half the size; overflow not possible.
+ currentAddr, _ := usermem.Addr(current >> 1).RoundUp()
+ current = uintptr(currentAddr)
+ continue
+ }
+ // We filled a block.
+ filled += current
+ excludedRegions = append(excludedRegions, region{
+ virtual: addr,
+ length: current,
+ })
+ // See comment above.
+ if filled != required {
+ required += faultBlockSize
+ }
+ }
+ if current == 0 {
+ panic("filling address space failed")
+ }
+ sort.Slice(excludedRegions, func(i, j int) bool {
+ return excludedRegions[i].virtual < excludedRegions[j].virtual
+ })
+ for _, r := range excludedRegions {
+ log.Infof("region: virtual [%x,%x)", r.virtual, r.virtual+r.length)
+ }
+ return excludedRegions
+}
+
+// computePhysicalRegions computes physical regions.
+func computePhysicalRegions(excludedRegions []region) (physicalRegions []physicalRegion) {
+ physical := uintptr(reservedMemory)
+ addValidRegion := func(virtual, length uintptr) {
+ if length == 0 {
+ return
+ }
+ if virtual == 0 {
+ virtual += usermem.PageSize
+ length -= usermem.PageSize
+ }
+ if end := virtual + length; end > ring0.MaximumUserAddress {
+ length -= (end - ring0.MaximumUserAddress)
+ }
+ if length == 0 {
+ return
+ }
+ // Round physical up to the same alignment as the virtual
+ // address (with respect to faultBlockSize).
+ if offset := virtual &^ faultBlockMask; physical&^faultBlockMask != offset {
+ if newPhysical := (physical & faultBlockMask) + offset; newPhysical > physical {
+ physical = newPhysical // Round up by only a little bit.
+ } else {
+ physical = ((physical + faultBlockSize) & faultBlockMask) + offset
+ }
+ }
+ physicalRegions = append(physicalRegions, physicalRegion{
+ region: region{
+ virtual: virtual,
+ length: length,
+ },
+ physical: physical,
+ })
+ physical += length
+ }
+ lastExcludedEnd := uintptr(0)
+ for _, r := range excludedRegions {
+ addValidRegion(lastExcludedEnd, r.virtual-lastExcludedEnd)
+ lastExcludedEnd = r.virtual + r.length
+ }
+ addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd)
+
+ // Dump our all physical regions.
+ for _, r := range physicalRegions {
+ log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)",
+ r.virtual, r.virtual+r.length, r.physical, r.physical+r.length)
+ }
+ return physicalRegions
+}
+
+// physicalInit initializes physical address mappings.
+func physicalInit() {
+ physicalRegions = computePhysicalRegions(fillAddressSpace())
+}
+
+// applyPhysicalRegions applies the given function on physical regions.
+//
+// Iteration continues as long as true is returned. The return value is the
+// return from the last call to fn, or true if there are no entries.
+//
+// Precondition: physicalInit must have been called.
+func applyPhysicalRegions(fn func(pr physicalRegion) bool) bool {
+ for _, pr := range physicalRegions {
+ if !fn(pr) {
+ return false
+ }
+ }
+ return true
+}
+
+// translateToPhysical translates the given virtual address.
+//
+// Precondition: physicalInit must have been called.
+//
+//go:nosplit
+func translateToPhysical(virtual uintptr) (physical uintptr, length uintptr, ok bool) {
+ for _, pr := range physicalRegions {
+ if pr.virtual <= virtual && virtual < pr.virtual+pr.length {
+ physical = pr.physical + (virtual - pr.virtual)
+ length = pr.length - (virtual - pr.virtual)
+ ok = true
+ return
+ }
+ }
+ return
+}
diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go
new file mode 100644
index 000000000..28a1b4414
--- /dev/null
+++ b/pkg/sentry/platform/kvm/virtual_map.go
@@ -0,0 +1,113 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "os"
+ "regexp"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+type virtualRegion struct {
+ region
+ accessType usermem.AccessType
+ shared bool
+ offset uintptr
+ filename string
+}
+
+// mapsLine matches a single line from /proc/PID/maps.
+var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2}:[0-9a-f]{2,} [0-9]+\\s+(.*)")
+
+// excludeRegion returns true if these regions should be excluded from the
+// physical map. Virtual regions need to be excluded if get_user_pages will
+// fail on those addresses, preventing KVM from satisfying EPT faults.
+//
+// This includes the VVAR page because the VVAR page may be mapped as I/O
+// memory. And the VDSO page is knocked out because the VVAR page is not even
+// recorded in /proc/self/maps on older kernels; knocking out the VDSO page
+// prevents code in the VDSO from accessing the VVAR address.
+//
+// This is called by the physical map functions, not applyVirtualRegions.
+func excludeVirtualRegion(r virtualRegion) bool {
+ return r.filename == "[vvar]" || r.filename == "[vdso]"
+}
+
+// applyVirtualRegions parses the process maps file.
+//
+// Unlike mappedRegions, these are not consistent over time.
+func applyVirtualRegions(fn func(vr virtualRegion)) error {
+ // Open /proc/self/maps.
+ f, err := os.Open("/proc/self/maps")
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ // Parse all entries.
+ r := bufio.NewReader(f)
+ for {
+ b, err := r.ReadBytes('\n')
+ if b != nil && len(b) > 0 {
+ m := mapsLine.FindSubmatch(b)
+ if m == nil {
+ // This should not happen: kernel bug?
+ return fmt.Errorf("badly formed line: %v", string(b))
+ }
+ start, err := strconv.ParseUint(string(m[1]), 16, 64)
+ if err != nil {
+ return fmt.Errorf("bad start address: %v", string(b))
+ }
+ end, err := strconv.ParseUint(string(m[2]), 16, 64)
+ if err != nil {
+ return fmt.Errorf("bad end address: %v", string(b))
+ }
+ read := m[3][0] == 'r'
+ write := m[3][1] == 'w'
+ execute := m[3][2] == 'x'
+ shared := m[3][3] == 's'
+ offset, err := strconv.ParseUint(string(m[4]), 16, 64)
+ if err != nil {
+ return fmt.Errorf("bad offset: %v", string(b))
+ }
+ fn(virtualRegion{
+ region: region{
+ virtual: uintptr(start),
+ length: uintptr(end - start),
+ },
+ accessType: usermem.AccessType{
+ Read: read,
+ Write: write,
+ Execute: execute,
+ },
+ shared: shared,
+ offset: uintptr(offset),
+ filename: string(m[5]),
+ })
+ }
+ if err != nil && err == io.EOF {
+ break
+ } else if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/platform/mmap_min_addr.go b/pkg/sentry/platform/mmap_min_addr.go
new file mode 100644
index 000000000..90976735b
--- /dev/null
+++ b/pkg/sentry/platform/mmap_min_addr.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package platform
+
+import (
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// systemMMapMinAddrSource is the source file.
+const systemMMapMinAddrSource = "/proc/sys/vm/mmap_min_addr"
+
+// systemMMapMinAddr is the system's minimum map address.
+var systemMMapMinAddr uint64
+
+// SystemMMapMinAddr returns the minimum system address.
+func SystemMMapMinAddr() usermem.Addr {
+ return usermem.Addr(systemMMapMinAddr)
+}
+
+// MMapMinAddr is a size zero struct that implements MinUserAddress based on
+// the system minimum address. It is suitable for embedding in platforms that
+// rely on the system mmap, and thus require the system minimum.
+type MMapMinAddr struct {
+}
+
+// MinUserAddress implements platform.MinUserAddresss.
+func (*MMapMinAddr) MinUserAddress() usermem.Addr {
+ return SystemMMapMinAddr()
+}
+
+func init() {
+ // Open the source file.
+ b, err := ioutil.ReadFile(systemMMapMinAddrSource)
+ if err != nil {
+ panic(fmt.Sprintf("couldn't open %s: %v", systemMMapMinAddrSource, err))
+ }
+
+ // Parse the result.
+ systemMMapMinAddr, err = strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64)
+ if err != nil {
+ panic(fmt.Sprintf("couldn't parse %s from %s: %v", string(b), systemMMapMinAddrSource, err))
+ }
+}
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
new file mode 100644
index 000000000..ae37276ad
--- /dev/null
+++ b/pkg/sentry/platform/platform.go
@@ -0,0 +1,349 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package platform provides a Platform abstraction.
+//
+// See Platform for more information.
+package platform
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// Platform provides abstractions for execution contexts (Context,
+// AddressSpace).
+type Platform interface {
+ // SupportsAddressSpaceIO returns true if AddressSpaces returned by this
+ // Platform support AddressSpaceIO methods.
+ //
+ // The value returned by SupportsAddressSpaceIO is guaranteed to remain
+ // unchanged over the lifetime of the Platform.
+ SupportsAddressSpaceIO() bool
+
+ // CooperativelySchedulesAddressSpace returns true if the Platform has a
+ // limited number of AddressSpaces, such that mm.MemoryManager.Deactivate
+ // should call AddressSpace.Release when there are no goroutines that
+ // require the mm.MemoryManager to have an active AddressSpace.
+ //
+ // The value returned by CooperativelySchedulesAddressSpace is guaranteed
+ // to remain unchanged over the lifetime of the Platform.
+ CooperativelySchedulesAddressSpace() bool
+
+ // DetectsCPUPreemption returns true if Contexts returned by the Platform
+ // can reliably return ErrContextCPUPreempted.
+ DetectsCPUPreemption() bool
+
+ // MapUnit returns the alignment used for optional mappings into this
+ // platform's AddressSpaces. Higher values indicate lower per-page costs
+ // for AddressSpace.MapFile. As a special case, a MapUnit of 0 indicates
+ // that the cost of AddressSpace.MapFile is effectively independent of the
+ // number of pages mapped. If MapUnit is non-zero, it must be a power-of-2
+ // multiple of usermem.PageSize.
+ MapUnit() uint64
+
+ // MinUserAddress returns the minimum mappable address on this
+ // platform.
+ MinUserAddress() usermem.Addr
+
+ // MaxUserAddress returns the maximum mappable address on this
+ // platform.
+ MaxUserAddress() usermem.Addr
+
+ // NewAddressSpace returns a new memory context for this platform.
+ //
+ // If mappingsID is not nil, the platform may assume that (1) all calls
+ // to NewAddressSpace with the same mappingsID represent the same
+ // (mutable) set of mappings, and (2) the set of mappings has not
+ // changed since the last time AddressSpace.Release was called on an
+ // AddressSpace returned by a call to NewAddressSpace with the same
+ // mappingsID.
+ //
+ // If a new AddressSpace cannot be created immediately, a nil
+ // AddressSpace is returned, along with channel that is closed when
+ // the caller should retry a call to NewAddressSpace.
+ //
+ // In general, this blocking behavior only occurs when
+ // CooperativelySchedulesAddressSpace (above) returns false.
+ NewAddressSpace(mappingsID interface{}) (AddressSpace, <-chan struct{}, error)
+
+ // NewContext returns a new execution context.
+ NewContext() Context
+
+ // PreemptAllCPUs causes all concurrent calls to Context.Switch(), as well
+ // as the first following call to Context.Switch() for each Context, to
+ // return ErrContextCPUPreempted.
+ //
+ // PreemptAllCPUs is only supported if DetectsCPUPremption() == true.
+ // Platforms for which this does not hold may panic if PreemptAllCPUs is
+ // called.
+ PreemptAllCPUs() error
+}
+
+// NoCPUPreemptionDetection implements Platform.DetectsCPUPreemption and
+// dependent methods for Platforms that do not support this feature.
+type NoCPUPreemptionDetection struct{}
+
+// DetectsCPUPreemption implements Platform.DetectsCPUPreemption.
+func (NoCPUPreemptionDetection) DetectsCPUPreemption() bool {
+ return false
+}
+
+// PreemptAllCPUs implements Platform.PreemptAllCPUs.
+func (NoCPUPreemptionDetection) PreemptAllCPUs() error {
+ panic("This platform does not support CPU preemption detection")
+}
+
+// Context represents the execution context for a single thread.
+type Context interface {
+ // Switch resumes execution of the thread specified by the arch.Context
+ // in the provided address space. This call will block while the thread
+ // is executing.
+ //
+ // If cpu is non-negative, and it is not the number of the CPU that the
+ // thread executes on, Context should return ErrContextCPUPreempted. cpu
+ // can only be non-negative if Platform.DetectsCPUPreemption() is true;
+ // Contexts from Platforms for which this does not hold may ignore cpu, or
+ // panic if cpu is non-negative.
+ //
+ // Switch may return one of the following special errors:
+ //
+ // - nil: The Context invoked a system call.
+ //
+ // - ErrContextSignal: The Context was interrupted by a signal. The
+ // returned *arch.SignalInfo contains information about the signal. If
+ // arch.SignalInfo.Signo == SIGSEGV, the returned usermem.AccessType
+ // contains the access type of the triggering fault. The caller owns
+ // the returned SignalInfo.
+ //
+ // - ErrContextInterrupt: The Context was interrupted by a call to
+ // Interrupt(). Switch() may return ErrContextInterrupt spuriously. In
+ // particular, most implementations of Interrupt() will cause the first
+ // following call to Switch() to return ErrContextInterrupt if there is no
+ // concurrent call to Switch().
+ //
+ // - ErrContextCPUPreempted: See the definition of that error for details.
+ Switch(as AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error)
+
+ // Interrupt interrupts a concurrent call to Switch(), causing it to return
+ // ErrContextInterrupt.
+ Interrupt()
+}
+
+var (
+ // ErrContextSignal is returned by Context.Switch() to indicate that the
+ // Context was interrupted by a signal.
+ ErrContextSignal = fmt.Errorf("interrupted by signal")
+
+ // ErrContextSignalCPUID is equivalent to ErrContextSignal, except that
+ // a check should be done for execution of the CPUID instruction. If
+ // the current instruction pointer is a CPUID instruction, then this
+ // should be emulated appropriately. If not, then the given signal
+ // should be handled per above.
+ ErrContextSignalCPUID = fmt.Errorf("interrupted by signal, possible CPUID")
+
+ // ErrContextInterrupt is returned by Context.Switch() to indicate that the
+ // Context was interrupted by a call to Context.Interrupt().
+ ErrContextInterrupt = fmt.Errorf("interrupted by platform.Context.Interrupt()")
+
+ // ErrContextCPUPreempted is returned by Context.Switch() to indicate that
+ // one of the following occurred:
+ //
+ // - The CPU executing the Context is not the CPU passed to
+ // Context.Switch().
+ //
+ // - The CPU executing the Context may have executed another Context since
+ // the last time it executed this one; or the CPU has previously executed
+ // another Context, and has never executed this one.
+ //
+ // - Platform.PreemptAllCPUs() was called since the last return from
+ // Context.Switch().
+ ErrContextCPUPreempted = fmt.Errorf("interrupted by CPU preemption")
+)
+
+// SignalInterrupt is a signal reserved for use by implementations of
+// Context.Interrupt(). The sentry guarantees that it will ignore delivery of
+// this signal both to Contexts and to the sentry itself, under the assumption
+// that they originate from races with Context.Interrupt().
+//
+// NOTE(b/23420492): The Go runtime only guarantees that a small subset
+// of signals will be always be unblocked on all threads, one of which
+// is SIGCHLD.
+const SignalInterrupt = linux.SIGCHLD
+
+// AddressSpace represents a virtual address space in which a Context can
+// execute.
+type AddressSpace interface {
+ // MapFile creates a shared mapping of offsets fr from f at address addr.
+ // Any existing overlapping mappings are silently replaced.
+ //
+ // If precommit is true, the platform should eagerly commit resources (e.g.
+ // physical memory) to the mapping. The precommit flag is advisory and
+ // implementations may choose to ignore it.
+ //
+ // Preconditions: addr and fr must be page-aligned. fr.Length() > 0.
+ // at.Any() == true. At least one reference must be held on all pages in
+ // fr, and must continue to be held as long as pages are mapped.
+ MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error
+
+ // Unmap unmaps the given range.
+ //
+ // Preconditions: addr is page-aligned. length > 0.
+ Unmap(addr usermem.Addr, length uint64)
+
+ // Release releases this address space. After releasing, a new AddressSpace
+ // must be acquired via platform.NewAddressSpace().
+ Release()
+
+ // AddressSpaceIO methods are supported iff the associated platform's
+ // Platform.SupportsAddressSpaceIO() == true. AddressSpaces for which this
+ // does not hold may panic if AddressSpaceIO methods are invoked.
+ AddressSpaceIO
+}
+
+// AddressSpaceIO supports IO through the memory mappings installed in an
+// AddressSpace.
+//
+// AddressSpaceIO implementors are responsible for ensuring that address ranges
+// are application-mappable.
+type AddressSpaceIO interface {
+ // CopyOut copies len(src) bytes from src to the memory mapped at addr. It
+ // returns the number of bytes copied. If the number of bytes copied is <
+ // len(src), it returns a non-nil error explaining why.
+ CopyOut(addr usermem.Addr, src []byte) (int, error)
+
+ // CopyIn copies len(dst) bytes from the memory mapped at addr to dst.
+ // It returns the number of bytes copied. If the number of bytes copied is
+ // < len(dst), it returns a non-nil error explaining why.
+ CopyIn(addr usermem.Addr, dst []byte) (int, error)
+
+ // ZeroOut sets toZero bytes to 0, starting at addr. It returns the number
+ // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a
+ // non-nil error explaining why.
+ ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error)
+
+ // SwapUint32 atomically sets the uint32 value at addr to new and returns
+ // the previous value.
+ //
+ // Preconditions: addr must be aligned to a 4-byte boundary.
+ SwapUint32(addr usermem.Addr, new uint32) (uint32, error)
+
+ // CompareAndSwapUint32 atomically compares the uint32 value at addr to
+ // old; if they are equal, the value in memory is replaced by new. In
+ // either case, the previous value stored in memory is returned.
+ //
+ // Preconditions: addr must be aligned to a 4-byte boundary.
+ CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error)
+
+ // LoadUint32 atomically loads the uint32 value at addr and returns it.
+ //
+ // Preconditions: addr must be aligned to a 4-byte boundary.
+ LoadUint32(addr usermem.Addr) (uint32, error)
+}
+
+// NoAddressSpaceIO implements AddressSpaceIO methods by panicing.
+type NoAddressSpaceIO struct{}
+
+// CopyOut implements AddressSpaceIO.CopyOut.
+func (NoAddressSpaceIO) CopyOut(addr usermem.Addr, src []byte) (int, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// CopyIn implements AddressSpaceIO.CopyIn.
+func (NoAddressSpaceIO) CopyIn(addr usermem.Addr, dst []byte) (int, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// ZeroOut implements AddressSpaceIO.ZeroOut.
+func (NoAddressSpaceIO) ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// SwapUint32 implements AddressSpaceIO.SwapUint32.
+func (NoAddressSpaceIO) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// CompareAndSwapUint32 implements AddressSpaceIO.CompareAndSwapUint32.
+func (NoAddressSpaceIO) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// LoadUint32 implements AddressSpaceIO.LoadUint32.
+func (NoAddressSpaceIO) LoadUint32(addr usermem.Addr) (uint32, error) {
+ panic("This platform does not support AddressSpaceIO")
+}
+
+// SegmentationFault is an error returned by AddressSpaceIO methods when IO
+// fails due to access of an unmapped page, or a mapped page with insufficient
+// permissions.
+type SegmentationFault struct {
+ // Addr is the address at which the fault occurred.
+ Addr usermem.Addr
+}
+
+// Error implements error.Error.
+func (f SegmentationFault) Error() string {
+ return fmt.Sprintf("segmentation fault at %#x", f.Addr)
+}
+
+// File represents a host file that may be mapped into an AddressSpace.
+type File interface {
+ // All pages in a File are reference-counted.
+
+ // IncRef increments the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr. (The File
+ // interface does not provide a way to acquire an initial reference;
+ // implementors may define mechanisms for doing so.)
+ IncRef(fr FileRange)
+
+ // DecRef decrements the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr.
+ DecRef(fr FileRange)
+
+ // MapInternal returns a mapping of the given file offsets in the invoking
+ // process' address space for reading and writing.
+ //
+ // Note that fr.Start and fr.End need not be page-aligned.
+ //
+ // Preconditions: fr.Length() > 0. At least one reference must be held on
+ // all pages in fr.
+ //
+ // Postconditions: The returned mapping is valid as long as at least one
+ // reference is held on the mapped pages.
+ MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
+
+ // FD returns the file descriptor represented by the File.
+ //
+ // The only permitted operation on the returned file descriptor is to map
+ // pages from it consistent with the requirements of AddressSpace.MapFile.
+ FD() int
+}
+
+// FileRange represents a range of uint64 offsets into a File.
+//
+// type FileRange <generated using go_generics>
+
+// String implements fmt.Stringer.String.
+func (fr FileRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
+}
diff --git a/pkg/sentry/platform/platform_state_autogen.go b/pkg/sentry/platform/platform_state_autogen.go
new file mode 100755
index 000000000..13ea50daf
--- /dev/null
+++ b/pkg/sentry/platform/platform_state_autogen.go
@@ -0,0 +1,24 @@
+// automatically generated by stateify.
+
+package platform
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *FileRange) beforeSave() {}
+func (x *FileRange) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *FileRange) afterLoad() {}
+func (x *FileRange) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func init() {
+ state.Register("platform.FileRange", (*FileRange)(nil), state.Fns{Save: (*FileRange).save, Load: (*FileRange).load})
+}
diff --git a/pkg/sentry/platform/procid/procid.go b/pkg/sentry/platform/procid/procid.go
new file mode 100644
index 000000000..78b92422c
--- /dev/null
+++ b/pkg/sentry/platform/procid/procid.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package procid provides a way to get the current system thread identifier.
+package procid
+
+// Current returns the current system thread identifier.
+//
+// Precondition: This should only be called with the runtime OS thread locked.
+func Current() uint64
diff --git a/pkg/sentry/platform/procid/procid_amd64.s b/pkg/sentry/platform/procid/procid_amd64.s
new file mode 100644
index 000000000..30ec8e6e2
--- /dev/null
+++ b/pkg/sentry/platform/procid/procid_amd64.s
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+// +build go1.8
+// +build !go1.14
+
+#include "textflag.h"
+
+TEXT ·Current(SB),NOSPLIT,$0-8
+ // The offset specified here is the m_procid offset for Go1.8+.
+ // Changes to this offset should be caught by the tests, and major
+ // version changes require an explicit tag change above.
+ MOVQ TLS, AX
+ MOVQ 0(AX)(TLS*1), AX
+ MOVQ 48(AX), AX // g_m (may change in future versions)
+ MOVQ 72(AX), AX // m_procid (may change in future versions)
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/sentry/platform/procid/procid_arm64.s b/pkg/sentry/platform/procid/procid_arm64.s
new file mode 100644
index 000000000..e340d9f98
--- /dev/null
+++ b/pkg/sentry/platform/procid/procid_arm64.s
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+// +build go1.8
+// +build !go1.14
+
+#include "textflag.h"
+
+TEXT ·Current(SB),NOSPLIT,$0-8
+ // The offset specified here is the m_procid offset for Go1.8+.
+ // Changes to this offset should be caught by the tests, and major
+ // version changes require an explicit tag change above.
+ MOVD g, R0 // g
+ MOVD 48(R0), R0 // g_m (may change in future versions)
+ MOVD 72(R0), R0 // m_procid (may change in future versions)
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/platform/procid/procid_state_autogen.go b/pkg/sentry/platform/procid/procid_state_autogen.go
new file mode 100755
index 000000000..f27a7c510
--- /dev/null
+++ b/pkg/sentry/platform/procid/procid_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package procid
+
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
new file mode 100644
index 000000000..6a890dd81
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -0,0 +1,238 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ptrace provides a ptrace-based implementation of the platform
+// interface. This is useful for development and testing purposes primarily,
+// and runs on stock kernels without special permissions.
+//
+// In a nutshell, it works as follows:
+//
+// The creation of a new address space creates a new child processes with a
+// single thread which is traced by a single goroutine.
+//
+// A context is just a collection of temporary variables. Calling Switch on a
+// context does the following:
+//
+// Locks the runtime thread.
+//
+// Looks up a traced subprocess thread for the current runtime thread. If
+// none exists, the dedicated goroutine is asked to create a new stopped
+// thread in the subprocess. This stopped subprocess thread is then traced
+// by the current thread and this information is stored for subsequent
+// switches.
+//
+// The context is then bound with information about the subprocess thread
+// so that the context may be appropriately interrupted via a signal.
+//
+// The requested operation is performed in the traced subprocess thread
+// (e.g. set registers, execute, return).
+//
+// Lock order:
+//
+// subprocess.mu
+// context.mu
+package ptrace
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/interrupt"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+var (
+ // stubStart is the link address for our stub, and determines the
+ // maximum user address. This is valid only after a call to stubInit.
+ //
+ // We attempt to link the stub here, and adjust downward as needed.
+ stubStart uintptr = 0x7fffffff0000
+
+ // stubEnd is the first byte past the end of the stub, as with
+ // stubStart this is valid only after a call to stubInit.
+ stubEnd uintptr
+
+ // stubInitialized controls one-time stub initialization.
+ stubInitialized sync.Once
+)
+
+type context struct {
+ // signalInfo is the signal info, if and when a signal is received.
+ signalInfo arch.SignalInfo
+
+ // interrupt is the interrupt context.
+ interrupt interrupt.Forwarder
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // If lastFaultSP is non-nil, the last context switch was due to a fault
+ // received while executing lastFaultSP. Only context.Switch may set
+ // lastFaultSP to a non-nil value.
+ lastFaultSP *subprocess
+
+ // lastFaultAddr is the last faulting address; this is only meaningful if
+ // lastFaultSP is non-nil.
+ lastFaultAddr usermem.Addr
+
+ // lastFaultIP is the address of the last faulting instruction;
+ // this is also only meaningful if lastFaultSP is non-nil.
+ lastFaultIP usermem.Addr
+}
+
+// Switch runs the provided context in the given address space.
+func (c *context) Switch(as platform.AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) {
+ s := as.(*subprocess)
+ isSyscall := s.switchToApp(c, ac)
+
+ var (
+ faultSP *subprocess
+ faultAddr usermem.Addr
+ faultIP usermem.Addr
+ )
+ if !isSyscall && linux.Signal(c.signalInfo.Signo) == linux.SIGSEGV {
+ faultSP = s
+ faultAddr = usermem.Addr(c.signalInfo.Addr())
+ faultIP = usermem.Addr(ac.IP())
+ }
+
+ // Update the context to reflect the outcome of this context switch.
+ c.mu.Lock()
+ lastFaultSP := c.lastFaultSP
+ lastFaultAddr := c.lastFaultAddr
+ lastFaultIP := c.lastFaultIP
+ // At this point, c may not yet be in s.contexts, so c.lastFaultSP won't be
+ // updated by s.Unmap(). This is fine; we only need to synchronize with
+ // calls to s.Unmap() that occur after the handling of this fault.
+ c.lastFaultSP = faultSP
+ c.lastFaultAddr = faultAddr
+ c.lastFaultIP = faultIP
+ c.mu.Unlock()
+
+ // Update subprocesses to reflect the outcome of this context switch.
+ if lastFaultSP != faultSP {
+ if lastFaultSP != nil {
+ lastFaultSP.mu.Lock()
+ delete(lastFaultSP.contexts, c)
+ lastFaultSP.mu.Unlock()
+ }
+ if faultSP != nil {
+ faultSP.mu.Lock()
+ faultSP.contexts[c] = struct{}{}
+ faultSP.mu.Unlock()
+ }
+ }
+
+ if isSyscall {
+ return nil, usermem.NoAccess, nil
+ }
+
+ si := c.signalInfo
+
+ if faultSP == nil {
+ // Non-fault signal.
+ return &si, usermem.NoAccess, platform.ErrContextSignal
+ }
+
+ // Got a page fault. Ideally, we'd get real fault type here, but ptrace
+ // doesn't expose this information. Instead, we use a simple heuristic:
+ //
+ // It was an instruction fault iff the faulting addr == instruction
+ // pointer.
+ //
+ // It was a write fault if the fault is immediately repeated.
+ at := usermem.Read
+ if faultAddr == faultIP {
+ at.Execute = true
+ }
+ if lastFaultSP == faultSP &&
+ lastFaultAddr == faultAddr &&
+ lastFaultIP == faultIP {
+ at.Write = true
+ }
+
+ // Unfortunately, we have to unilaterally return ErrContextSignalCPUID
+ // here, in case this fault was generated by a CPUID exception. There
+ // is no way to distinguish between CPUID-generated faults and regular
+ // page faults.
+ return &si, at, platform.ErrContextSignalCPUID
+}
+
+// Interrupt interrupts the running guest application associated with this context.
+func (c *context) Interrupt() {
+ c.interrupt.NotifyInterrupt()
+}
+
+// PTrace represents a collection of ptrace subprocesses.
+type PTrace struct {
+ platform.MMapMinAddr
+ platform.NoCPUPreemptionDetection
+}
+
+// New returns a new ptrace-based implementation of the platform interface.
+func New() (*PTrace, error) {
+ stubInitialized.Do(func() {
+ // Initialize the stub.
+ stubInit()
+
+ // Create the master process for the global pool. This must be
+ // done before initializing any other processes.
+ master, err := newSubprocess(createStub)
+ if err != nil {
+ // Should never happen.
+ panic("unable to initialize ptrace master: " + err.Error())
+ }
+
+ // Set the master on the globalPool.
+ globalPool.master = master
+ })
+
+ return &PTrace{}, nil
+}
+
+// SupportsAddressSpaceIO implements platform.Platform.SupportsAddressSpaceIO.
+func (*PTrace) SupportsAddressSpaceIO() bool {
+ return false
+}
+
+// CooperativelySchedulesAddressSpace implements platform.Platform.CooperativelySchedulesAddressSpace.
+func (*PTrace) CooperativelySchedulesAddressSpace() bool {
+ return false
+}
+
+// MapUnit implements platform.Platform.MapUnit.
+func (*PTrace) MapUnit() uint64 {
+ // The host kernel manages page tables and arbitrary-sized mappings
+ // have effectively the same cost.
+ return 0
+}
+
+// MaxUserAddress returns the first address that may not be used by user
+// applications.
+func (*PTrace) MaxUserAddress() usermem.Addr {
+ return usermem.Addr(stubStart)
+}
+
+// NewAddressSpace returns a new subprocess.
+func (p *PTrace) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
+ as, err := newSubprocess(globalPool.master.createStub)
+ return as, nil, err
+}
+
+// NewContext returns an interruptible context.
+func (*PTrace) NewContext() platform.Context {
+ return &context{}
+}
diff --git a/pkg/sentry/platform/ptrace/ptrace_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_state_autogen.go
new file mode 100755
index 000000000..ac83a71e7
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package ptrace
+
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
new file mode 100644
index 000000000..585f6c1fb
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -0,0 +1,166 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// GETREGSET/SETREGSET register set types.
+//
+// See include/uapi/linux/elf.h.
+const (
+ // _NT_PRFPREG is for x86 floating-point state without using xsave.
+ _NT_PRFPREG = 0x2
+
+ // _NT_X86_XSTATE is for x86 extended state using xsave.
+ _NT_X86_XSTATE = 0x202
+)
+
+// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
+func fpRegSet(useXsave bool) uintptr {
+ if useXsave {
+ return _NT_X86_XSTATE
+ }
+ return _NT_PRFPREG
+}
+
+// getRegs sets the regular register set.
+func (t *thread) getRegs(regs *syscall.PtraceRegs) error {
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETREGS,
+ uintptr(t.tid),
+ 0,
+ uintptr(unsafe.Pointer(regs)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// setRegs sets the regular register set.
+func (t *thread) setRegs(regs *syscall.PtraceRegs) error {
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETREGS,
+ uintptr(t.tid),
+ 0,
+ uintptr(unsafe.Pointer(regs)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// getFPRegs gets the floating-point data via the GETREGSET ptrace syscall.
+func (t *thread) getFPRegs(fpState *arch.FloatingPointData, fpLen uint64, useXsave bool) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(fpState),
+ Len: fpLen,
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETREGSET,
+ uintptr(t.tid),
+ fpRegSet(useXsave),
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// setFPRegs sets the floating-point data via the SETREGSET ptrace syscall.
+func (t *thread) setFPRegs(fpState *arch.FloatingPointData, fpLen uint64, useXsave bool) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(fpState),
+ Len: fpLen,
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETREGSET,
+ uintptr(t.tid),
+ fpRegSet(useXsave),
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// getSignalInfo retrieves information about the signal that caused the stop.
+func (t *thread) getSignalInfo(si *arch.SignalInfo) error {
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETSIGINFO,
+ uintptr(t.tid),
+ 0,
+ uintptr(unsafe.Pointer(si)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// clone creates a new thread from this one.
+//
+// The returned thread will be stopped and available for any system thread to
+// call attach on it.
+//
+// Precondition: the OS thread must be locked and own t.
+func (t *thread) clone() (*thread, error) {
+ r, ok := usermem.Addr(t.initRegs.Rsp).RoundUp()
+ if !ok {
+ return nil, syscall.EINVAL
+ }
+ rval, err := t.syscallIgnoreInterrupt(
+ &t.initRegs,
+ syscall.SYS_CLONE,
+ arch.SyscallArgument{Value: uintptr(
+ syscall.CLONE_FILES |
+ syscall.CLONE_FS |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_THREAD |
+ syscall.CLONE_PTRACE |
+ syscall.CLONE_VM)},
+ // The stack pointer is just made up, but we have it be
+ // something sensible so the kernel doesn't think we're
+ // up to no good. Which we are.
+ arch.SyscallArgument{Value: uintptr(r)},
+ arch.SyscallArgument{},
+ arch.SyscallArgument{},
+ // We use these registers initially, but really they
+ // could be anything. We're going to stop immediately.
+ arch.SyscallArgument{Value: uintptr(unsafe.Pointer(&t.initRegs))})
+ if err != nil {
+ return nil, err
+ }
+
+ return &thread{
+ tgid: t.tgid,
+ tid: int32(rval),
+ cpu: ^uint32(0),
+ }, nil
+}
diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s
new file mode 100644
index 000000000..64c718d21
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/stub_amd64.s
@@ -0,0 +1,114 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+#define SYS_GETPID 39
+#define SYS_EXIT 60
+#define SYS_KILL 62
+#define SYS_GETPPID 110
+#define SYS_PRCTL 157
+
+#define SIGKILL 9
+#define SIGSTOP 19
+
+#define PR_SET_PDEATHSIG 1
+
+// stub bootstraps the child and sends itself SIGSTOP to wait for attach.
+//
+// R15 contains the expected PPID. R15 is used instead of a more typical DI
+// since syscalls will clobber DI and createStub wants to pass a new PPID to
+// grandchildren.
+//
+// This should not be used outside the context of a new ptrace child (as the
+// function is otherwise a bunch of nonsense).
+TEXT ·stub(SB),NOSPLIT,$0
+begin:
+ // N.B. This loop only executes in the context of a single-threaded
+ // fork child.
+
+ MOVQ $SYS_PRCTL, AX
+ MOVQ $PR_SET_PDEATHSIG, DI
+ MOVQ $SIGKILL, SI
+ SYSCALL
+
+ CMPQ AX, $0
+ JNE error
+
+ // If the parent already died before we called PR_SET_DEATHSIG then
+ // we'll have an unexpected PPID.
+ MOVQ $SYS_GETPPID, AX
+ SYSCALL
+
+ CMPQ AX, $0
+ JL error
+
+ CMPQ AX, R15
+ JNE parent_dead
+
+ MOVQ $SYS_GETPID, AX
+ SYSCALL
+
+ CMPQ AX, $0
+ JL error
+
+ // SIGSTOP to wait for attach.
+ //
+ // The SYSCALL instruction will be used for future syscall injection by
+ // thread.syscall.
+ MOVQ AX, DI
+ MOVQ $SYS_KILL, AX
+ MOVQ $SIGSTOP, SI
+ SYSCALL
+
+ // The tracer may "detach" and/or allow code execution here in three cases:
+ //
+ // 1. New (traced) stub threads are explicitly detached by the
+ // goroutine in newSubprocess. However, they are detached while in
+ // group-stop, so they do not execute code here.
+ //
+ // 2. If a tracer thread exits, it implicitly detaches from the stub,
+ // potentially allowing code execution here. However, the Go runtime
+ // never exits individual threads, so this case never occurs.
+ //
+ // 3. subprocess.createStub clones a new stub process that is untraced,
+ // thus executing this code. We setup the PDEATHSIG before SIGSTOPing
+ // ourselves for attach by the tracer.
+ //
+ // R15 has been updated with the expected PPID.
+ JMP begin
+
+error:
+ // Exit with -errno.
+ MOVQ AX, DI
+ NEGQ DI
+ MOVQ $SYS_EXIT, AX
+ SYSCALL
+ HLT
+
+parent_dead:
+ MOVQ $SYS_EXIT, AX
+ MOVQ $1, DI
+ SYSCALL
+ HLT
+
+// stubCall calls the stub function at the given address with the given PPID.
+//
+// This is a distinct function because stub, above, may be mapped at any
+// arbitrary location, and stub has a specific binary API (see above).
+TEXT ·stubCall(SB),NOSPLIT,$0-16
+ MOVQ addr+0(FP), AX
+ MOVQ pid+8(FP), R15
+ JMP AX
diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go
new file mode 100644
index 000000000..54d5021a9
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/stub_unsafe.go
@@ -0,0 +1,98 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "reflect"
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// stub is defined in arch-specific assembly.
+func stub()
+
+// stubCall calls the stub at the given address with the given pid.
+func stubCall(addr, pid uintptr)
+
+// unsafeSlice returns a slice for the given address and length.
+func unsafeSlice(addr uintptr, length int) (slice []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ sh.Data = addr
+ sh.Len = length
+ sh.Cap = length
+ return
+}
+
+// stubInit initializes the stub.
+func stubInit() {
+ // Grab the existing stub.
+ stubBegin := reflect.ValueOf(stub).Pointer()
+ stubLen := int(safecopy.FindEndAddress(stubBegin) - stubBegin)
+ stubSlice := unsafeSlice(stubBegin, stubLen)
+ mapLen := uintptr(stubLen)
+ if offset := mapLen % usermem.PageSize; offset != 0 {
+ mapLen += usermem.PageSize - offset
+ }
+
+ for stubStart > 0 {
+ // Map the target address for the stub.
+ //
+ // We don't use FIXED here because we don't want to unmap
+ // something that may have been there already. We just walk
+ // down the address space until we find a place where the stub
+ // can be placed.
+ addr, _, errno := syscall.RawSyscall6(
+ syscall.SYS_MMAP,
+ stubStart,
+ mapLen,
+ syscall.PROT_WRITE|syscall.PROT_READ,
+ syscall.MAP_PRIVATE|syscall.MAP_ANONYMOUS,
+ 0 /* fd */, 0 /* offset */)
+ if addr != stubStart || errno != 0 {
+ if addr != 0 {
+ // Unmap the region we've mapped accidentally.
+ syscall.RawSyscall(syscall.SYS_MUNMAP, addr, mapLen, 0)
+ }
+
+ // Attempt to begin at a lower address.
+ stubStart -= uintptr(usermem.PageSize)
+ continue
+ }
+
+ // Copy the stub to the address.
+ targetSlice := unsafeSlice(addr, stubLen)
+ copy(targetSlice, stubSlice)
+
+ // Make the stub executable.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_MPROTECT,
+ stubStart,
+ mapLen,
+ syscall.PROT_EXEC|syscall.PROT_READ); errno != 0 {
+ panic("mprotect failed: " + errno.Error())
+ }
+
+ // Set the end.
+ stubEnd = stubStart + mapLen
+ return
+ }
+
+ // This will happen only if we exhaust the entire address
+ // space, and it will take a long, long time.
+ panic("failed to map stub")
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
new file mode 100644
index 000000000..83b43057f
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -0,0 +1,610 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "fmt"
+ "os"
+ "runtime"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// globalPool exists to solve two distinct problems:
+//
+// 1) Subprocesses can't always be killed properly (see Release).
+//
+// 2) Any seccomp filters that have been installed will apply to subprocesses
+// created here. Therefore we use the intermediary (master), which is created
+// on initialization of the platform.
+var globalPool struct {
+ mu sync.Mutex
+ master *subprocess
+ available []*subprocess
+}
+
+// thread is a traced thread; it is a thread identifier.
+//
+// This is a convenience type for defining ptrace operations.
+type thread struct {
+ tgid int32
+ tid int32
+ cpu uint32
+
+ // initRegs are the initial registers for the first thread.
+ //
+ // These are used for the register set for system calls.
+ initRegs syscall.PtraceRegs
+}
+
+// threadPool is a collection of threads.
+type threadPool struct {
+ // mu protects below.
+ mu sync.Mutex
+
+ // threads is the collection of threads.
+ //
+ // This map is indexed by system TID (the calling thread); which will
+ // be the tracer for the given *thread, and therefore capable of using
+ // relevant ptrace calls.
+ threads map[int32]*thread
+}
+
+// lookupOrCreate looks up a given thread or creates one.
+//
+// newThread will generally be subprocess.newThread.
+//
+// Precondition: the runtime OS thread must be locked.
+func (tp *threadPool) lookupOrCreate(currentTID int32, newThread func() *thread) *thread {
+ tp.mu.Lock()
+ t, ok := tp.threads[currentTID]
+ if !ok {
+ // Before creating a new thread, see if we can find a thread
+ // whose system tid has disappeared.
+ //
+ // TODO(b/77216482): Other parts of this package depend on
+ // threads never exiting.
+ for origTID, t := range tp.threads {
+ // Signal zero is an easy existence check.
+ if err := syscall.Tgkill(syscall.Getpid(), int(origTID), 0); err != nil {
+ // This thread has been abandoned; reuse it.
+ delete(tp.threads, origTID)
+ tp.threads[currentTID] = t
+ tp.mu.Unlock()
+ return t
+ }
+ }
+
+ // Create a new thread.
+ t = newThread()
+ tp.threads[currentTID] = t
+ }
+ tp.mu.Unlock()
+ return t
+}
+
+// subprocess is a collection of threads being traced.
+type subprocess struct {
+ platform.NoAddressSpaceIO
+
+ // requests is used to signal creation of new threads.
+ requests chan chan *thread
+
+ // sysemuThreads are reserved for emulation.
+ sysemuThreads threadPool
+
+ // syscallThreads are reserved for syscalls (except clone, which is
+ // handled in the dedicated goroutine corresponding to requests above).
+ syscallThreads threadPool
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // contexts is the set of contexts for which it's possible that
+ // context.lastFaultSP == this subprocess.
+ contexts map[*context]struct{}
+}
+
+// newSubprocess returns a useable subprocess.
+//
+// This will either be a newly created subprocess, or one from the global pool.
+// The create function will be called in the latter case, which is guaranteed
+// to happen with the runtime thread locked.
+func newSubprocess(create func() (*thread, error)) (*subprocess, error) {
+ // See Release.
+ globalPool.mu.Lock()
+ if len(globalPool.available) > 0 {
+ sp := globalPool.available[len(globalPool.available)-1]
+ globalPool.available = globalPool.available[:len(globalPool.available)-1]
+ globalPool.mu.Unlock()
+ return sp, nil
+ }
+ globalPool.mu.Unlock()
+
+ // The following goroutine is responsible for creating the first traced
+ // thread, and responding to requests to make additional threads in the
+ // traced process. The process will be killed and reaped when the
+ // request channel is closed, which happens in Release below.
+ errChan := make(chan error)
+ requests := make(chan chan *thread)
+ go func() { // S/R-SAFE: Platform-related.
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ // Initialize the first thread.
+ firstThread, err := create()
+ if err != nil {
+ errChan <- err
+ return
+ }
+
+ // Ready to handle requests.
+ errChan <- nil
+
+ // Wait for requests to create threads.
+ for r := range requests {
+ t, err := firstThread.clone()
+ if err != nil {
+ // Should not happen: not recoverable.
+ panic(fmt.Sprintf("error initializing first thread: %v", err))
+ }
+
+ // Since the new thread was created with
+ // clone(CLONE_PTRACE), it will begin execution with
+ // SIGSTOP pending and with this thread as its tracer.
+ // (Hopefully nobody tgkilled it with a signal <
+ // SIGSTOP before the SIGSTOP was delivered, in which
+ // case that signal would be delivered before SIGSTOP.)
+ if sig := t.wait(stopped); sig != syscall.SIGSTOP {
+ panic(fmt.Sprintf("error waiting for new clone: expected SIGSTOP, got %v", sig))
+ }
+
+ // Detach the thread.
+ t.detach()
+
+ // Return the thread.
+ r <- t
+ }
+
+ // Requests should never be closed.
+ panic("unreachable")
+ }()
+
+ // Wait until error or readiness.
+ if err := <-errChan; err != nil {
+ return nil, err
+ }
+
+ // Ready.
+ sp := &subprocess{
+ requests: requests,
+ sysemuThreads: threadPool{
+ threads: make(map[int32]*thread),
+ },
+ syscallThreads: threadPool{
+ threads: make(map[int32]*thread),
+ },
+ contexts: make(map[*context]struct{}),
+ }
+
+ sp.unmap()
+ return sp, nil
+}
+
+// unmap unmaps non-stub regions of the process.
+//
+// This will panic on failure (which should never happen).
+func (s *subprocess) unmap() {
+ s.Unmap(0, uint64(stubStart))
+ if maximumUserAddress != stubEnd {
+ s.Unmap(usermem.Addr(stubEnd), uint64(maximumUserAddress-stubEnd))
+ }
+}
+
+// Release kills the subprocess.
+//
+// Just kidding! We can't safely co-ordinate the detaching of all the
+// tracees (since the tracers are random runtime threads, and the process
+// won't exit until tracers have been notifier).
+//
+// Therefore we simply unmap everything in the subprocess and return it to the
+// globalPool. This has the added benefit of reducing creation time for new
+// subprocesses.
+func (s *subprocess) Release() {
+ go func() { // S/R-SAFE: Platform.
+ s.unmap()
+ globalPool.mu.Lock()
+ globalPool.available = append(globalPool.available, s)
+ globalPool.mu.Unlock()
+ }()
+}
+
+// newThread creates a new traced thread.
+//
+// Precondition: the OS thread must be locked.
+func (s *subprocess) newThread() *thread {
+ // Ask the first thread to create a new one.
+ r := make(chan *thread)
+ s.requests <- r
+ t := <-r
+
+ // Attach the subprocess to this one.
+ t.attach()
+
+ // Return the new thread, which is now bound.
+ return t
+}
+
+// attach attachs to the thread.
+func (t *thread) attach() {
+ if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0); errno != 0 {
+ panic(fmt.Sprintf("unable to attach: %v", errno))
+ }
+
+ // PTRACE_ATTACH sends SIGSTOP, and wakes the tracee if it was already
+ // stopped from the SIGSTOP queued by CLONE_PTRACE (see inner loop of
+ // newSubprocess), so we always expect to see signal-delivery-stop with
+ // SIGSTOP.
+ if sig := t.wait(stopped); sig != syscall.SIGSTOP {
+ panic(fmt.Sprintf("wait failed: expected SIGSTOP, got %v", sig))
+ }
+
+ // Initialize options.
+ t.init()
+
+ // Grab registers.
+ //
+ // Note that we adjust the current register RIP value to be just before
+ // the current system call executed. This depends on the definition of
+ // the stub itself.
+ if err := t.getRegs(&t.initRegs); err != nil {
+ panic(fmt.Sprintf("ptrace get regs failed: %v", err))
+ }
+ t.initRegs.Rip -= initRegsRipAdjustment
+}
+
+// detach detachs from the thread.
+//
+// Because the SIGSTOP is not supressed, the thread will enter group-stop.
+func (t *thread) detach() {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_DETACH, uintptr(t.tid), 0, uintptr(syscall.SIGSTOP), 0, 0); errno != 0 {
+ panic(fmt.Sprintf("can't detach new clone: %v", errno))
+ }
+}
+
+// waitOutcome is used for wait below.
+type waitOutcome int
+
+const (
+ // stopped indicates that the process was stopped.
+ stopped waitOutcome = iota
+
+ // killed indicates that the process was killed.
+ killed
+)
+
+// wait waits for a stop event.
+//
+// Precondition: outcome is a valid waitOutcome.
+func (t *thread) wait(outcome waitOutcome) syscall.Signal {
+ var status syscall.WaitStatus
+
+ for {
+ r, err := syscall.Wait4(int(t.tid), &status, syscall.WALL|syscall.WUNTRACED, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ // Wait was interrupted; wait again.
+ continue
+ } else if err != nil {
+ panic(fmt.Sprintf("ptrace wait failed: %v", err))
+ }
+ if int(r) != int(t.tid) {
+ panic(fmt.Sprintf("ptrace wait returned %v, expected %v", r, t.tid))
+ }
+ switch outcome {
+ case stopped:
+ if !status.Stopped() {
+ panic(fmt.Sprintf("ptrace status unexpected: got %v, wanted stopped", status))
+ }
+ stopSig := status.StopSignal()
+ if stopSig == 0 {
+ continue // Spurious stop.
+ }
+ if stopSig == syscall.SIGTRAP {
+ // Re-encode the trap cause the way it's expected.
+ return stopSig | syscall.Signal(status.TrapCause()<<8)
+ }
+ // Not a trap signal.
+ return stopSig
+ case killed:
+ if !status.Exited() && !status.Signaled() {
+ panic(fmt.Sprintf("ptrace status unexpected: got %v, wanted exited", status))
+ }
+ return syscall.Signal(status.ExitStatus())
+ default:
+ // Should not happen.
+ panic(fmt.Sprintf("unknown outcome: %v", outcome))
+ }
+ }
+}
+
+// destroy kills the thread.
+//
+// Note that this should not be used in the general case; the death of threads
+// will typically cause the death of the parent. This is a utility method for
+// manually created threads.
+func (t *thread) destroy() {
+ t.detach()
+ syscall.Tgkill(int(t.tgid), int(t.tid), syscall.Signal(syscall.SIGKILL))
+ t.wait(killed)
+}
+
+// init initializes trace options.
+func (t *thread) init() {
+ // Set our TRACESYSGOOD option to differeniate real SIGTRAP.
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETOPTIONS,
+ uintptr(t.tid),
+ 0,
+ syscall.PTRACE_O_TRACESYSGOOD,
+ 0, 0)
+ if errno != 0 {
+ panic(fmt.Sprintf("ptrace set options failed: %v", errno))
+ }
+}
+
+// syscall executes a system call cycle in the traced context.
+//
+// This is _not_ for use by application system calls, rather it is for use when
+// a system call must be injected into the remote context (e.g. mmap, munmap).
+// Note that clones are handled separately.
+func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
+ // Set registers.
+ if err := t.setRegs(regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs failed: %v", err))
+ }
+
+ for {
+ // Execute the syscall instruction.
+ if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
+ }
+
+ sig := t.wait(stopped)
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ // Reached syscall-enter-stop.
+ break
+ } else {
+ // Some other signal caused a thread stop; ignore.
+ continue
+ }
+ }
+
+ // Complete the actual system call.
+ if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
+ }
+
+ // Wait for syscall-exit-stop. "[Signal-delivery-stop] never happens
+ // between syscall-enter-stop and syscall-exit-stop; it happens *after*
+ // syscall-exit-stop.)" - ptrace(2), "Syscall-stops"
+ if sig := t.wait(stopped); sig != (syscallEvent | syscall.SIGTRAP) {
+ panic(fmt.Sprintf("wait failed: expected SIGTRAP, got %v [%d]", sig, sig))
+ }
+
+ // Grab registers.
+ if err := t.getRegs(regs); err != nil {
+ panic(fmt.Sprintf("ptrace get regs failed: %v", err))
+ }
+
+ return syscallReturnValue(regs)
+}
+
+// syscallIgnoreInterrupt ignores interrupts on the system call thread and
+// restarts the syscall if the kernel indicates that should happen.
+func (t *thread) syscallIgnoreInterrupt(
+ initRegs *syscall.PtraceRegs,
+ sysno uintptr,
+ args ...arch.SyscallArgument) (uintptr, error) {
+ for {
+ regs := createSyscallRegs(initRegs, sysno, args...)
+ rval, err := t.syscall(&regs)
+ switch err {
+ case ERESTARTSYS:
+ continue
+ case ERESTARTNOINTR:
+ continue
+ case ERESTARTNOHAND:
+ continue
+ default:
+ return rval, err
+ }
+ }
+}
+
+// NotifyInterrupt implements interrupt.Receiver.NotifyInterrupt.
+func (t *thread) NotifyInterrupt() {
+ syscall.Tgkill(int(t.tgid), int(t.tid), syscall.Signal(platform.SignalInterrupt))
+}
+
+// switchToApp is called from the main SwitchToApp entrypoint.
+//
+// This function returns true on a system call, false on a signal.
+func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
+ // Lock the thread for ptrace operations.
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ // Extract floating point state.
+ fpState := ac.FloatingPointData()
+ fpLen, _ := ac.FeatureSet().ExtendedStateSize()
+ useXsave := ac.FeatureSet().UseXsave()
+
+ // Grab our thread from the pool.
+ currentTID := int32(procid.Current())
+ t := s.sysemuThreads.lookupOrCreate(currentTID, s.newThread)
+
+ // Reset necessary registers.
+ regs := &ac.StateData().Regs
+ t.resetSysemuRegs(regs)
+
+ // Check for interrupts, and ensure that future interrupts will signal t.
+ if !c.interrupt.Enable(t) {
+ // Pending interrupt; simulate.
+ c.signalInfo = arch.SignalInfo{Signo: int32(platform.SignalInterrupt)}
+ return false
+ }
+ defer c.interrupt.Disable()
+
+ // Ensure that the CPU set is bound appropriately; this makes the
+ // emulation below several times faster, presumably by avoiding
+ // interprocessor wakeups and by simplifying the schedule.
+ t.bind()
+
+ // Set registers.
+ if err := t.setRegs(regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs (%+v) failed: %v", regs, err))
+ }
+ if err := t.setFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
+ panic(fmt.Sprintf("ptrace set fpregs (%+v) failed: %v", fpState, err))
+ }
+
+ for {
+ // Start running until the next system call.
+ if isSingleStepping(regs) {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SYSEMU_SINGLESTEP,
+ uintptr(t.tid), 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
+ }
+ } else {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SYSEMU,
+ uintptr(t.tid), 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
+ }
+ }
+
+ // Wait for the syscall-enter stop.
+ sig := t.wait(stopped)
+
+ // Refresh all registers.
+ if err := t.getRegs(regs); err != nil {
+ panic(fmt.Sprintf("ptrace get regs failed: %v", err))
+ }
+ if err := t.getFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
+ panic(fmt.Sprintf("ptrace get fpregs failed: %v", err))
+ }
+
+ // Is it a system call?
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ // Ensure registers are sane.
+ updateSyscallRegs(regs)
+ return true
+ } else if sig == syscall.SIGSTOP {
+ // SIGSTOP was delivered to another thread in the same thread
+ // group, which initiated another group stop. Just ignore it.
+ continue
+ }
+
+ // Grab signal information.
+ if err := t.getSignalInfo(&c.signalInfo); err != nil {
+ // Should never happen.
+ panic(fmt.Sprintf("ptrace get signal info failed: %v", err))
+ }
+
+ // We have a signal. We verify however, that the signal was
+ // either delivered from the kernel or from this process. We
+ // don't respect other signals.
+ if c.signalInfo.Code > 0 {
+ // The signal was generated by the kernel. We inspect
+ // the signal information, and may patch it in order to
+ // faciliate vsyscall emulation. See patchSignalInfo.
+ patchSignalInfo(regs, &c.signalInfo)
+ return false
+ } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) {
+ // The signal was generated by this process. That means
+ // that it was an interrupt or something else that we
+ // should bail for. Note that we ignore signals
+ // generated by other processes.
+ return false
+ }
+ }
+}
+
+// syscall executes the given system call without handling interruptions.
+func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintptr, error) {
+ // Grab a thread.
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+ currentTID := int32(procid.Current())
+ t := s.syscallThreads.lookupOrCreate(currentTID, s.newThread)
+
+ return t.syscallIgnoreInterrupt(&t.initRegs, sysno, args...)
+}
+
+// MapFile implements platform.AddressSpace.MapFile.
+func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+ var flags int
+ if precommit {
+ flags |= syscall.MAP_POPULATE
+ }
+ _, err := s.syscall(
+ syscall.SYS_MMAP,
+ arch.SyscallArgument{Value: uintptr(addr)},
+ arch.SyscallArgument{Value: uintptr(fr.Length())},
+ arch.SyscallArgument{Value: uintptr(at.Prot())},
+ arch.SyscallArgument{Value: uintptr(flags | syscall.MAP_SHARED | syscall.MAP_FIXED)},
+ arch.SyscallArgument{Value: uintptr(f.FD())},
+ arch.SyscallArgument{Value: uintptr(fr.Start)})
+ return err
+}
+
+// Unmap implements platform.AddressSpace.Unmap.
+func (s *subprocess) Unmap(addr usermem.Addr, length uint64) {
+ ar, ok := addr.ToRange(length)
+ if !ok {
+ panic(fmt.Sprintf("addr %#x + length %#x overflows", addr, length))
+ }
+ s.mu.Lock()
+ for c := range s.contexts {
+ c.mu.Lock()
+ if c.lastFaultSP == s && ar.Contains(c.lastFaultAddr) {
+ // Forget the last fault so that if c faults again, the fault isn't
+ // incorrectly reported as a write fault. If this is being called
+ // due to munmap() of the corresponding vma, handling of the second
+ // fault will fail anyway.
+ c.lastFaultSP = nil
+ delete(s.contexts, c)
+ }
+ c.mu.Unlock()
+ }
+ s.mu.Unlock()
+ _, err := s.syscall(
+ syscall.SYS_MUNMAP,
+ arch.SyscallArgument{Value: uintptr(addr)},
+ arch.SyscallArgument{Value: uintptr(length)})
+ if err != nil {
+ // We never expect this to happen.
+ panic(fmt.Sprintf("munmap(%x, %x)) failed: %v", addr, length, err))
+ }
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
new file mode 100644
index 000000000..77a0e908f
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -0,0 +1,104 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ptrace
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+)
+
+const (
+ // maximumUserAddress is the largest possible user address.
+ maximumUserAddress = 0x7ffffffff000
+
+ // initRegsRipAdjustment is the size of the syscall instruction.
+ initRegsRipAdjustment = 2
+)
+
+// Linux kernel errnos which "should never be seen by user programs", but will
+// be revealed to ptrace syscall exit tracing.
+//
+// These constants are used in subprocess.go.
+const (
+ ERESTARTSYS = syscall.Errno(512)
+ ERESTARTNOINTR = syscall.Errno(513)
+ ERESTARTNOHAND = syscall.Errno(514)
+)
+
+// resetSysemuRegs sets up emulation registers.
+//
+// This should be called prior to calling sysemu.
+func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) {
+ regs.Cs = t.initRegs.Cs
+ regs.Ss = t.initRegs.Ss
+ regs.Ds = t.initRegs.Ds
+ regs.Es = t.initRegs.Es
+ regs.Fs = t.initRegs.Fs
+ regs.Gs = t.initRegs.Gs
+}
+
+// createSyscallRegs sets up syscall registers.
+//
+// This should be called to generate registers for a system call.
+func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs {
+ // Copy initial registers.
+ regs := *initRegs
+
+ // Set our syscall number.
+ regs.Rax = uint64(sysno)
+ if len(args) >= 1 {
+ regs.Rdi = args[0].Uint64()
+ }
+ if len(args) >= 2 {
+ regs.Rsi = args[1].Uint64()
+ }
+ if len(args) >= 3 {
+ regs.Rdx = args[2].Uint64()
+ }
+ if len(args) >= 4 {
+ regs.R10 = args[3].Uint64()
+ }
+ if len(args) >= 5 {
+ regs.R8 = args[4].Uint64()
+ }
+ if len(args) >= 6 {
+ regs.R9 = args[5].Uint64()
+ }
+
+ return regs
+}
+
+// isSingleStepping determines if the registers indicate single-stepping.
+func isSingleStepping(regs *syscall.PtraceRegs) bool {
+ return (regs.Eflags & arch.X86TrapFlag) != 0
+}
+
+// updateSyscallRegs updates registers after finishing sysemu.
+func updateSyscallRegs(regs *syscall.PtraceRegs) {
+ // Ptrace puts -ENOSYS in rax on syscall-enter-stop.
+ regs.Rax = regs.Orig_rax
+}
+
+// syscallReturnValue extracts a sensible return from registers.
+func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
+ rval := int64(regs.Rax)
+ if rval < 0 {
+ return 0, syscall.Errno(-rval)
+ }
+ return uintptr(rval), nil
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
new file mode 100644
index 000000000..2c07b4ac3
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -0,0 +1,338 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package ptrace
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/seccomp"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid"
+)
+
+const syscallEvent syscall.Signal = 0x80
+
+// probeSeccomp returns true iff seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8. This check is dynamic
+// because kernels have be backported behavior.
+//
+// See createStub for more information.
+//
+// Precondition: the runtime OS thread must be locked.
+func probeSeccomp() bool {
+ // Create a completely new, destroyable process.
+ t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
+ if err != nil {
+ panic(fmt.Sprintf("seccomp probe failed: %v", err))
+ }
+ defer t.destroy()
+
+ // Set registers to the yield system call. This call is not allowed
+ // by the filters specified in the attachThread function.
+ regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
+ if err := t.setRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs failed: %v", err))
+ }
+
+ for {
+ // Attempt an emulation.
+ if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
+ }
+
+ sig := t.wait(stopped)
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ // Did the seccomp errno hook already run? This would
+ // indicate that seccomp is first in line and we're
+ // less than 4.8.
+ if err := t.getRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
+ }
+ if _, err := syscallReturnValue(&regs); err == nil {
+ // The seccomp errno mode ran first, and reset
+ // the error in the registers.
+ return false
+ }
+ // The seccomp hook did not run yet, and therefore it
+ // is safe to use RET_KILL mode for dispatched calls.
+ return true
+ }
+ }
+}
+
+// patchSignalInfo patches the signal info to account for hitting the seccomp
+// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
+// synchronous trap, but patch the structure to appear like a SIGSEGV with the
+// Rip as the faulting address.
+//
+// Note that this should only be called after verifying that the signalInfo has
+// been generated by the kernel.
+func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
+ if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
+ signalInfo.Signo = int32(linux.SIGSEGV)
+
+ // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
+ // with the si_call_addr field pointing to the current RIP. This field
+ // aligns with the si_addr field for a SIGSEGV, so we don't need to touch
+ // anything there. We do need to unwind emulation however, so we set the
+ // instruction pointer to the faulting value, and "unpop" the stack.
+ regs.Rip = signalInfo.Addr()
+ regs.Rsp -= 8
+ }
+}
+
+// createStub creates a fresh stub processes.
+//
+// Precondition: the runtime OS thread must be locked.
+func createStub() (*thread, error) {
+ // The exact interactions of ptrace and seccomp are complex, and
+ // changed in recent kernel versions. Before commit 93e35efb8de45, the
+ // seccomp check is done before the ptrace emulation check. This means
+ // that any calls not matching this list will trigger the seccomp
+ // default action instead of notifying ptrace.
+ //
+ // After commit 93e35efb8de45, the seccomp check is done after the
+ // ptrace emulation check. This simplifies using SYSEMU, since seccomp
+ // will never run for emulation. Seccomp will only run for injected
+ // system calls, and thus we can use RET_KILL as our violation action.
+ var defaultAction linux.BPFAction
+ if probeSeccomp() {
+ log.Infof("Latest seccomp behavior found (kernel >= 4.8 likely)")
+ defaultAction = linux.SECCOMP_RET_KILL_THREAD
+ } else {
+ // We must rely on SYSEMU behavior; tracing with SYSEMU is broken.
+ log.Infof("Legacy seccomp behavior found (kernel < 4.8 likely)")
+ defaultAction = linux.SECCOMP_RET_ALLOW
+ }
+
+ // When creating the new child process, we specify SIGKILL as the
+ // signal to deliver when the child exits. We never expect a subprocess
+ // to exit; they are pooled and reused. This is done to ensure that if
+ // a subprocess is OOM-killed, this process (and all other stubs,
+ // transitively) will be killed as well. It's simply not possible to
+ // safely handle a single stub getting killed: the exact state of
+ // execution is unknown and not recoverable.
+ return attachedThread(uintptr(syscall.SIGKILL)|syscall.CLONE_FILES, defaultAction)
+}
+
+// attachedThread returns a new attached thread.
+//
+// Precondition: the runtime OS thread must be locked.
+func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, error) {
+ // Create a BPF program that allows only the system calls needed by the
+ // stub and all its children. This is used to create child stubs
+ // (below), so we must include the ability to fork, but otherwise lock
+ // down available calls only to what is needed.
+ rules := []seccomp.RuleSet{
+ // Rules for trapping vsyscall access.
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_GETTIMEOFDAY: {},
+ syscall.SYS_TIME: {},
+ 309: {}, // SYS_GETCPU.
+ },
+ Action: linux.SECCOMP_RET_TRAP,
+ Vsyscall: true,
+ },
+ }
+ if defaultAction != linux.SECCOMP_RET_ALLOW {
+ rules = append(rules, seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_CLONE: []seccomp.Rule{
+ // Allow creation of new subprocesses (used by the master).
+ {seccomp.AllowValue(syscall.CLONE_FILES | syscall.SIGKILL)},
+ // Allow creation of new threads within a single address space (used by addresss spaces).
+ {seccomp.AllowValue(
+ syscall.CLONE_FILES |
+ syscall.CLONE_FS |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_THREAD |
+ syscall.CLONE_PTRACE |
+ syscall.CLONE_VM)},
+ },
+
+ // For the initial process creation.
+ syscall.SYS_WAIT4: {},
+ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
+ {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+ },
+ syscall.SYS_EXIT: {},
+
+ // For the stub prctl dance (all).
+ syscall.SYS_PRCTL: []seccomp.Rule{
+ {seccomp.AllowValue(syscall.PR_SET_PDEATHSIG), seccomp.AllowValue(syscall.SIGKILL)},
+ },
+ syscall.SYS_GETPPID: {},
+
+ // For the stub to stop itself (all).
+ syscall.SYS_GETPID: {},
+ syscall.SYS_KILL: []seccomp.Rule{
+ {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SIGSTOP)},
+ },
+
+ // Injected to support the address space operations.
+ syscall.SYS_MMAP: {},
+ syscall.SYS_MUNMAP: {},
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ })
+ }
+ instrs, err := seccomp.BuildProgram(rules, defaultAction)
+ if err != nil {
+ return nil, err
+ }
+
+ // Declare all variables up front in order to ensure that there's no
+ // need for allocations between beforeFork & afterFork.
+ var (
+ pid uintptr
+ ppid uintptr
+ errno syscall.Errno
+ )
+
+ // Remember the current ppid for the pdeathsig race.
+ ppid, _, _ = syscall.RawSyscall(syscall.SYS_GETPID, 0, 0, 0)
+
+ // Among other things, beforeFork masks all signals.
+ beforeFork()
+
+ // Do the clone.
+ pid, _, errno = syscall.RawSyscall6(syscall.SYS_CLONE, flags, 0, 0, 0, 0, 0)
+ if errno != 0 {
+ afterFork()
+ return nil, errno
+ }
+
+ // Is this the parent?
+ if pid != 0 {
+ // Among other things, restore signal mask.
+ afterFork()
+
+ // Initialize the first thread.
+ t := &thread{
+ tgid: int32(pid),
+ tid: int32(pid),
+ cpu: ^uint32(0),
+ }
+ if sig := t.wait(stopped); sig != syscall.SIGSTOP {
+ return nil, fmt.Errorf("wait failed: expected SIGSTOP, got %v", sig)
+ }
+ t.attach()
+
+ return t, nil
+ }
+
+ // Move the stub to a new session (and thus a new process group). This
+ // prevents the stub from getting PTY job control signals intended only
+ // for the sentry process. We must call this before restoring signal
+ // mask.
+ if _, _, errno := syscall.RawSyscall(syscall.SYS_SETSID, 0, 0, 0); errno != 0 {
+ syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0)
+ }
+
+ // afterForkInChild resets all signals to their default dispositions
+ // and restores the signal mask to its pre-fork state.
+ afterForkInChild()
+
+ // Explicitly unmask all signals to ensure that the tracer can see
+ // them.
+ if errno := unmaskAllSignals(); errno != 0 {
+ syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0)
+ }
+
+ // Set an aggressive BPF filter for the stub and all it's children. See
+ // the description of the BPF program built above.
+ if errno := seccomp.SetFilter(instrs); errno != 0 {
+ syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0)
+ }
+
+ // Enable cpuid-faulting; this may fail on older kernels or hardware,
+ // so we just disregard the result. Host CPUID will be enabled.
+ syscall.RawSyscall(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0)
+
+ // Call the stub; should not return.
+ stubCall(stubStart, ppid)
+ panic("unreachable")
+}
+
+// createStub creates a stub processes as a child of an existing subprocesses.
+//
+// Precondition: the runtime OS thread must be locked.
+func (s *subprocess) createStub() (*thread, error) {
+ // There's no need to lock the runtime thread here, as this can only be
+ // called from a context that is already locked.
+ currentTID := int32(procid.Current())
+ t := s.syscallThreads.lookupOrCreate(currentTID, s.newThread)
+
+ // Pass the expected PPID to the child via R15.
+ regs := t.initRegs
+ regs.R15 = uint64(t.tgid)
+
+ // Call fork in a subprocess.
+ //
+ // The new child must set up PDEATHSIG to ensure it dies if this
+ // process dies. Since this process could die at any time, this cannot
+ // be done via instrumentation from here.
+ //
+ // Instead, we create the child untraced, which will do the PDEATHSIG
+ // setup and then SIGSTOP itself for our attach below.
+ //
+ // See above re: SIGKILL.
+ pid, err := t.syscallIgnoreInterrupt(
+ &regs,
+ syscall.SYS_CLONE,
+ arch.SyscallArgument{Value: uintptr(syscall.SIGKILL | syscall.CLONE_FILES)},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0})
+ if err != nil {
+ return nil, err
+ }
+
+ // Wait for child to enter group-stop, so we don't stop its
+ // bootstrapping work with t.attach below.
+ //
+ // We unfortunately don't have a handy part of memory to write the wait
+ // status. If the wait succeeds, we'll assume that it was the SIGSTOP.
+ // If the child actually exited, the attach below will fail.
+ _, err = t.syscallIgnoreInterrupt(
+ &t.initRegs,
+ syscall.SYS_WAIT4,
+ arch.SyscallArgument{Value: uintptr(pid)},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: syscall.WALL | syscall.WUNTRACED},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0},
+ arch.SyscallArgument{Value: 0})
+ if err != nil {
+ return nil, err
+ }
+
+ childT := &thread{
+ tgid: int32(pid),
+ tid: int32(pid),
+ cpu: ^uint32(0),
+ }
+ childT.attach()
+
+ return childT, nil
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go
new file mode 100644
index 000000000..1bf7eab28
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go
@@ -0,0 +1,109 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 linux
+
+package ptrace
+
+import (
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// maskPool contains reusable CPU masks for setting affinity. Unfortunately,
+// runtime.NumCPU doesn't actually record the number of CPUs on the system, it
+// just records the number of CPUs available in the scheduler affinity set at
+// startup. This may a) change over time and b) gives a number far lower than
+// the maximum indexable CPU. To prevent lots of allocation in the hot path, we
+// use a pool to store large masks that we can reuse during bind.
+var maskPool = sync.Pool{
+ New: func() interface{} {
+ const maxCPUs = 1024 // Not a hard limit; see below.
+ return make([]uintptr, maxCPUs/64)
+ },
+}
+
+// unmaskAllSignals unmasks all signals on the current thread.
+//
+//go:nosplit
+func unmaskAllSignals() syscall.Errno {
+ var set linux.SignalSet
+ _, _, errno := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_SETMASK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0)
+ return errno
+}
+
+// getCPU gets the current CPU.
+//
+// Precondition: the current runtime thread should be locked.
+func getCPU() (uint32, error) {
+ var cpu uintptr
+ if _, _, errno := syscall.RawSyscall(
+ unix.SYS_GETCPU,
+ uintptr(unsafe.Pointer(&cpu)),
+ 0, 0); errno != 0 {
+ return 0, errno
+ }
+ return uint32(cpu), nil
+}
+
+// setCPU sets the CPU affinity.
+func (t *thread) setCPU(cpu uint32) error {
+ mask := maskPool.Get().([]uintptr)
+ n := int(cpu / 64)
+ v := uintptr(1 << uintptr(cpu%64))
+ if n >= len(mask) {
+ // See maskPool note above. We've actually exceeded the number
+ // of available cores. Grow the mask and return it.
+ mask = make([]uintptr, n+1)
+ }
+ mask[n] |= v
+ if _, _, errno := syscall.RawSyscall(
+ unix.SYS_SCHED_SETAFFINITY,
+ uintptr(t.tid),
+ uintptr(len(mask)*8),
+ uintptr(unsafe.Pointer(&mask[0]))); errno != 0 {
+ return errno
+ }
+ mask[n] &^= v
+ maskPool.Put(mask)
+ return nil
+}
+
+// bind attempts to ensure that the thread is on the same CPU as the current
+// thread. This provides no guarantees as it is fundamentally a racy operation:
+// CPU sets may change and we may be rescheduled in the middle of this
+// operation. As a result, no failures are reported.
+//
+// Precondition: the current runtime thread should be locked.
+func (t *thread) bind() {
+ currentCPU, err := getCPU()
+ if err != nil {
+ return
+ }
+ if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU {
+ // Set the affinity on the thread and save the CPU for next
+ // round; we don't expect CPUs to bounce around too frequently.
+ //
+ // (It's worth noting that we could move CPUs between this point
+ // and when the tracee finishes executing. But that would be
+ // roughly the status quo anyways -- we're just maximizing our
+ // chances of colocation, not guaranteeing it.)
+ t.setCPU(currentCPU)
+ }
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
new file mode 100644
index 000000000..b80a3604d
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.12
+// +build !go1.14
+
+// Check go:linkname function signatures when updating Go version.
+
+package ptrace
+
+import (
+ _ "unsafe" // required for go:linkname.
+)
+
+//go:linkname beforeFork syscall.runtime_BeforeFork
+func beforeFork()
+
+//go:linkname afterFork syscall.runtime_AfterFork
+func afterFork()
+
+//go:linkname afterForkInChild syscall.runtime_AfterForkInChild
+func afterForkInChild()
diff --git a/pkg/sentry/platform/ring0/defs_impl.go b/pkg/sentry/platform/ring0/defs_impl.go
new file mode 100755
index 000000000..582553bc7
--- /dev/null
+++ b/pkg/sentry/platform/ring0/defs_impl.go
@@ -0,0 +1,538 @@
+package ring0
+
+import (
+ "syscall"
+
+ "fmt"
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "io"
+ "reflect"
+)
+
+var (
+ // UserspaceSize is the total size of userspace.
+ UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1)
+
+ // MaximumUserAddress is the largest possible user address.
+ MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
+
+ // KernelStartAddress is the starting kernel address.
+ KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
+)
+
+// Kernel is a global kernel object.
+//
+// This contains global state, shared by multiple CPUs.
+type Kernel struct {
+ KernelArchState
+}
+
+// Hooks are hooks for kernel functions.
+type Hooks interface {
+ // KernelSyscall is called for kernel system calls.
+ //
+ // Return from this call will restore registers and return to the kernel: the
+ // registers must be modified directly.
+ //
+ // If this function is not provided, a kernel exception results in halt.
+ //
+ // This must be go:nosplit, as this will be on the interrupt stack.
+ // Closures are permitted, as the pointer to the closure frame is not
+ // passed on the stack.
+ KernelSyscall()
+
+ // KernelException handles an exception during kernel execution.
+ //
+ // Return from this call will restore registers and return to the kernel: the
+ // registers must be modified directly.
+ //
+ // If this function is not provided, a kernel exception results in halt.
+ //
+ // This must be go:nosplit, as this will be on the interrupt stack.
+ // Closures are permitted, as the pointer to the closure frame is not
+ // passed on the stack.
+ KernelException(Vector)
+}
+
+// CPU is the per-CPU struct.
+type CPU struct {
+ // self is a self reference.
+ //
+ // This is always guaranteed to be at offset zero.
+ self *CPU
+
+ // kernel is reference to the kernel that this CPU was initialized
+ // with. This reference is kept for garbage collection purposes: CPU
+ // registers may refer to objects within the Kernel object that cannot
+ // be safely freed.
+ kernel *Kernel
+
+ // CPUArchState is architecture-specific state.
+ CPUArchState
+
+ // registers is a set of registers; these may be used on kernel system
+ // calls and exceptions via the Registers function.
+ registers syscall.PtraceRegs
+
+ // hooks are kernel hooks.
+ hooks Hooks
+}
+
+// Registers returns a modifiable-copy of the kernel registers.
+//
+// This is explicitly safe to call during KernelException and KernelSyscall.
+//
+//go:nosplit
+func (c *CPU) Registers() *syscall.PtraceRegs {
+ return &c.registers
+}
+
+// SwitchOpts are passed to the Switch function.
+type SwitchOpts struct {
+ // Registers are the user register state.
+ Registers *syscall.PtraceRegs
+
+ // FloatingPointState is a byte pointer where floating point state is
+ // saved and restored.
+ FloatingPointState *byte
+
+ // PageTables are the application page tables.
+ PageTables *pagetables.PageTables
+
+ // Flush indicates that a TLB flush should be forced on switch.
+ Flush bool
+
+ // FullRestore indicates that an iret-based restore should be used.
+ FullRestore bool
+
+ // SwitchArchOpts are architecture-specific options.
+ SwitchArchOpts
+}
+
+// Segment indices and Selectors.
+const (
+ // Index into GDT array.
+ _ = iota // Null descriptor first.
+ _ // Reserved (Linux is kernel 32).
+ segKcode // Kernel code (64-bit).
+ segKdata // Kernel data.
+ segUcode32 // User code (32-bit).
+ segUdata // User data.
+ segUcode64 // User code (64-bit).
+ segTss // Task segment descriptor.
+ segTssHi // Upper bits for TSS.
+ segLast // Last segment (terminal, not included).
+)
+
+// Selectors.
+const (
+ Kcode Selector = segKcode << 3
+ Kdata Selector = segKdata << 3
+ Ucode32 Selector = (segUcode32 << 3) | 3
+ Udata Selector = (segUdata << 3) | 3
+ Ucode64 Selector = (segUcode64 << 3) | 3
+ Tss Selector = segTss << 3
+)
+
+// Standard segments.
+var (
+ UserCodeSegment32 SegmentDescriptor
+ UserDataSegment SegmentDescriptor
+ UserCodeSegment64 SegmentDescriptor
+ KernelCodeSegment SegmentDescriptor
+ KernelDataSegment SegmentDescriptor
+)
+
+// KernelOpts has initialization options for the kernel.
+type KernelOpts struct {
+ // PageTables are the kernel pagetables; this must be provided.
+ PageTables *pagetables.PageTables
+}
+
+// KernelArchState contains architecture-specific state.
+type KernelArchState struct {
+ KernelOpts
+
+ // globalIDT is our set of interrupt gates.
+ globalIDT idt64
+}
+
+// CPUArchState contains CPU-specific arch state.
+type CPUArchState struct {
+ // stack is the stack used for interrupts on this CPU.
+ stack [256]byte
+
+ // errorCode is the error code from the last exception.
+ errorCode uintptr
+
+ // errorType indicates the type of error code here, it is always set
+ // along with the errorCode value above.
+ //
+ // It will either by 1, which indicates a user error, or 0 indicating a
+ // kernel error. If the error code below returns false (kernel error),
+ // then it cannot provide relevant information about the last
+ // exception.
+ errorType uintptr
+
+ // gdt is the CPU's descriptor table.
+ gdt descriptorTable
+
+ // tss is the CPU's task state.
+ tss TaskState64
+}
+
+// ErrorCode returns the last error code.
+//
+// The returned boolean indicates whether the error code corresponds to the
+// last user error or not. If it does not, then fault information must be
+// ignored. This is generally the result of a kernel fault while servicing a
+// user fault.
+//
+//go:nosplit
+func (c *CPU) ErrorCode() (value uintptr, user bool) {
+ return c.errorCode, c.errorType != 0
+}
+
+// ClearErrorCode resets the error code.
+//
+//go:nosplit
+func (c *CPU) ClearErrorCode() {
+ c.errorCode = 0
+ c.errorType = 1
+}
+
+// SwitchArchOpts are embedded in SwitchOpts.
+type SwitchArchOpts struct {
+ // UserPCID indicates that the application PCID to be used on switch,
+ // assuming that PCIDs are supported.
+ //
+ // Per pagetables_x86.go, a zero PCID implies a flush.
+ UserPCID uint16
+
+ // KernelPCID indicates that the kernel PCID to be used on return,
+ // assuming that PCIDs are supported.
+ //
+ // Per pagetables_x86.go, a zero PCID implies a flush.
+ KernelPCID uint16
+}
+
+func init() {
+ KernelCodeSegment.setCode64(0, 0, 0)
+ KernelDataSegment.setData(0, 0xffffffff, 0)
+ UserCodeSegment32.setCode64(0, 0, 3)
+ UserDataSegment.setData(0, 0xffffffff, 3)
+ UserCodeSegment64.setCode64(0, 0, 3)
+}
+
+// Emit prints architecture-specific offsets.
+func Emit(w io.Writer) {
+ fmt.Fprintf(w, "// Automatically generated, do not edit.\n")
+
+ c := &CPU{}
+ fmt.Fprintf(w, "\n// CPU offsets.\n")
+ fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack)))
+ fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
+
+ fmt.Fprintf(w, "\n// Bits.\n")
+ fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF)
+ fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
+
+ fmt.Fprintf(w, "\n// Vectors.\n")
+ fmt.Fprintf(w, "#define DivideByZero 0x%02x\n", DivideByZero)
+ fmt.Fprintf(w, "#define Debug 0x%02x\n", Debug)
+ fmt.Fprintf(w, "#define NMI 0x%02x\n", NMI)
+ fmt.Fprintf(w, "#define Breakpoint 0x%02x\n", Breakpoint)
+ fmt.Fprintf(w, "#define Overflow 0x%02x\n", Overflow)
+ fmt.Fprintf(w, "#define BoundRangeExceeded 0x%02x\n", BoundRangeExceeded)
+ fmt.Fprintf(w, "#define InvalidOpcode 0x%02x\n", InvalidOpcode)
+ fmt.Fprintf(w, "#define DeviceNotAvailable 0x%02x\n", DeviceNotAvailable)
+ fmt.Fprintf(w, "#define DoubleFault 0x%02x\n", DoubleFault)
+ fmt.Fprintf(w, "#define CoprocessorSegmentOverrun 0x%02x\n", CoprocessorSegmentOverrun)
+ fmt.Fprintf(w, "#define InvalidTSS 0x%02x\n", InvalidTSS)
+ fmt.Fprintf(w, "#define SegmentNotPresent 0x%02x\n", SegmentNotPresent)
+ fmt.Fprintf(w, "#define StackSegmentFault 0x%02x\n", StackSegmentFault)
+ fmt.Fprintf(w, "#define GeneralProtectionFault 0x%02x\n", GeneralProtectionFault)
+ fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault)
+ fmt.Fprintf(w, "#define X87FloatingPointException 0x%02x\n", X87FloatingPointException)
+ fmt.Fprintf(w, "#define AlignmentCheck 0x%02x\n", AlignmentCheck)
+ fmt.Fprintf(w, "#define MachineCheck 0x%02x\n", MachineCheck)
+ fmt.Fprintf(w, "#define SIMDFloatingPointException 0x%02x\n", SIMDFloatingPointException)
+ fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException)
+ fmt.Fprintf(w, "#define SecurityException 0x%02x\n", SecurityException)
+ fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80)
+ fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
+
+ p := &syscall.PtraceRegs{}
+ fmt.Fprintf(w, "\n// Ptrace registers.\n")
+ fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.R13).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.R12).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RBP 0x%02x\n", reflect.ValueOf(&p.Rbp).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RBX 0x%02x\n", reflect.ValueOf(&p.Rbx).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.R11).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.R10).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.R9).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.R8).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RAX 0x%02x\n", reflect.ValueOf(&p.Rax).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RCX 0x%02x\n", reflect.ValueOf(&p.Rcx).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RDX 0x%02x\n", reflect.ValueOf(&p.Rdx).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RSI 0x%02x\n", reflect.ValueOf(&p.Rsi).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RDI 0x%02x\n", reflect.ValueOf(&p.Rdi).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_ORIGRAX 0x%02x\n", reflect.ValueOf(&p.Orig_rax).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RIP 0x%02x\n", reflect.ValueOf(&p.Rip).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_CS 0x%02x\n", reflect.ValueOf(&p.Cs).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_FLAGS 0x%02x\n", reflect.ValueOf(&p.Eflags).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_RSP 0x%02x\n", reflect.ValueOf(&p.Rsp).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_SS 0x%02x\n", reflect.ValueOf(&p.Ss).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_FS 0x%02x\n", reflect.ValueOf(&p.Fs_base).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_GS 0x%02x\n", reflect.ValueOf(&p.Gs_base).Pointer()-reflect.ValueOf(p).Pointer())
+}
+
+// Useful bits.
+const (
+ _CR0_PE = 1 << 0
+ _CR0_ET = 1 << 4
+ _CR0_AM = 1 << 18
+ _CR0_PG = 1 << 31
+
+ _CR4_PSE = 1 << 4
+ _CR4_PAE = 1 << 5
+ _CR4_PGE = 1 << 7
+ _CR4_OSFXSR = 1 << 9
+ _CR4_OSXMMEXCPT = 1 << 10
+ _CR4_FSGSBASE = 1 << 16
+ _CR4_PCIDE = 1 << 17
+ _CR4_OSXSAVE = 1 << 18
+ _CR4_SMEP = 1 << 20
+
+ _RFLAGS_AC = 1 << 18
+ _RFLAGS_NT = 1 << 14
+ _RFLAGS_IOPL = 3 << 12
+ _RFLAGS_DF = 1 << 10
+ _RFLAGS_IF = 1 << 9
+ _RFLAGS_STEP = 1 << 8
+ _RFLAGS_RESERVED = 1 << 1
+
+ _EFER_SCE = 0x001
+ _EFER_LME = 0x100
+ _EFER_LMA = 0x400
+ _EFER_NX = 0x800
+
+ _MSR_STAR = 0xc0000081
+ _MSR_LSTAR = 0xc0000082
+ _MSR_CSTAR = 0xc0000083
+ _MSR_SYSCALL_MASK = 0xc0000084
+ _MSR_PLATFORM_INFO = 0xce
+ _MSR_MISC_FEATURES = 0x140
+
+ _PLATFORM_INFO_CPUID_FAULT = 1 << 31
+
+ _MISC_FEATURE_CPUID_TRAP = 0x1
+)
+
+const (
+ // KernelFlagsSet should always be set in the kernel.
+ KernelFlagsSet = _RFLAGS_RESERVED
+
+ // UserFlagsSet are always set in userspace.
+ UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF
+
+ // KernelFlagsClear should always be clear in the kernel.
+ KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT
+
+ // UserFlagsClear are always cleared in userspace.
+ UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL
+)
+
+// Vector is an exception vector.
+type Vector uintptr
+
+// Exception vectors.
+const (
+ DivideByZero Vector = iota
+ Debug
+ NMI
+ Breakpoint
+ Overflow
+ BoundRangeExceeded
+ InvalidOpcode
+ DeviceNotAvailable
+ DoubleFault
+ CoprocessorSegmentOverrun
+ InvalidTSS
+ SegmentNotPresent
+ StackSegmentFault
+ GeneralProtectionFault
+ PageFault
+ _
+ X87FloatingPointException
+ AlignmentCheck
+ MachineCheck
+ SIMDFloatingPointException
+ VirtualizationException
+ SecurityException = 0x1e
+ SyscallInt80 = 0x80
+ _NR_INTERRUPTS = SyscallInt80 + 1
+)
+
+// System call vectors.
+const (
+ Syscall Vector = _NR_INTERRUPTS
+)
+
+// VirtualAddressBits returns the number bits available for virtual addresses.
+//
+// Note that sign-extension semantics apply to the highest order bit.
+//
+// FIXME(b/69382326): This should use the cpuid passed to Init.
+func VirtualAddressBits() uint32 {
+ ax, _, _, _ := cpuid.HostID(0x80000008, 0)
+ return (ax >> 8) & 0xff
+}
+
+// PhysicalAddressBits returns the number of bits available for physical addresses.
+//
+// FIXME(b/69382326): This should use the cpuid passed to Init.
+func PhysicalAddressBits() uint32 {
+ ax, _, _, _ := cpuid.HostID(0x80000008, 0)
+ return ax & 0xff
+}
+
+// Selector is a segment Selector.
+type Selector uint16
+
+// SegmentDescriptor is a segment descriptor.
+type SegmentDescriptor struct {
+ bits [2]uint32
+}
+
+// descriptorTable is a collection of descriptors.
+type descriptorTable [32]SegmentDescriptor
+
+// SegmentDescriptorFlags are typed flags within a descriptor.
+type SegmentDescriptorFlags uint32
+
+// SegmentDescriptorFlag declarations.
+const (
+ SegmentDescriptorAccess SegmentDescriptorFlags = 1 << 8 // Access bit (always set).
+ SegmentDescriptorWrite = 1 << 9 // Write permission.
+ SegmentDescriptorExpandDown = 1 << 10 // Grows down, not used.
+ SegmentDescriptorExecute = 1 << 11 // Execute permission.
+ SegmentDescriptorSystem = 1 << 12 // Zero => system, 1 => user code/data.
+ SegmentDescriptorPresent = 1 << 15 // Present.
+ SegmentDescriptorAVL = 1 << 20 // Available.
+ SegmentDescriptorLong = 1 << 21 // Long mode.
+ SegmentDescriptorDB = 1 << 22 // 16 or 32-bit.
+ SegmentDescriptorG = 1 << 23 // Granularity: page or byte.
+)
+
+// Base returns the descriptor's base linear address.
+func (d *SegmentDescriptor) Base() uint32 {
+ return d.bits[1]&0xFF000000 | (d.bits[1]&0x000000FF)<<16 | d.bits[0]>>16
+}
+
+// Limit returns the descriptor size.
+func (d *SegmentDescriptor) Limit() uint32 {
+ l := d.bits[0]&0xFFFF | d.bits[1]&0xF0000
+ if d.bits[1]&uint32(SegmentDescriptorG) != 0 {
+ l <<= 12
+ l |= 0xFFF
+ }
+ return l
+}
+
+// Flags returns descriptor flags.
+func (d *SegmentDescriptor) Flags() SegmentDescriptorFlags {
+ return SegmentDescriptorFlags(d.bits[1] & 0x00F09F00)
+}
+
+// DPL returns the descriptor privilege level.
+func (d *SegmentDescriptor) DPL() int {
+ return int((d.bits[1] >> 13) & 3)
+}
+
+func (d *SegmentDescriptor) setNull() {
+ d.bits[0] = 0
+ d.bits[1] = 0
+}
+
+func (d *SegmentDescriptor) set(base, limit uint32, dpl int, flags SegmentDescriptorFlags) {
+ flags |= SegmentDescriptorPresent
+ if limit>>12 != 0 {
+ limit >>= 12
+ flags |= SegmentDescriptorG
+ }
+ d.bits[0] = base<<16 | limit&0xFFFF
+ d.bits[1] = base&0xFF000000 | (base>>16)&0xFF | limit&0x000F0000 | uint32(flags) | uint32(dpl)<<13
+}
+
+func (d *SegmentDescriptor) setCode32(base, limit uint32, dpl int) {
+ d.set(base, limit, dpl,
+ SegmentDescriptorDB|
+ SegmentDescriptorExecute|
+ SegmentDescriptorSystem)
+}
+
+func (d *SegmentDescriptor) setCode64(base, limit uint32, dpl int) {
+ d.set(base, limit, dpl,
+ SegmentDescriptorG|
+ SegmentDescriptorLong|
+ SegmentDescriptorExecute|
+ SegmentDescriptorSystem)
+}
+
+func (d *SegmentDescriptor) setData(base, limit uint32, dpl int) {
+ d.set(base, limit, dpl,
+ SegmentDescriptorWrite|
+ SegmentDescriptorSystem)
+}
+
+// setHi is only used for the TSS segment, which is magically 64-bits.
+func (d *SegmentDescriptor) setHi(base uint32) {
+ d.bits[0] = base
+ d.bits[1] = 0
+}
+
+// Gate64 is a 64-bit task, trap, or interrupt gate.
+type Gate64 struct {
+ bits [4]uint32
+}
+
+// idt64 is a 64-bit interrupt descriptor table.
+type idt64 [_NR_INTERRUPTS]Gate64
+
+func (g *Gate64) setInterrupt(cs Selector, rip uint64, dpl int, ist int) {
+ g.bits[0] = uint32(cs)<<16 | uint32(rip)&0xFFFF
+ g.bits[1] = uint32(rip)&0xFFFF0000 | SegmentDescriptorPresent | uint32(dpl)<<13 | 14<<8 | uint32(ist)&0x7
+ g.bits[2] = uint32(rip >> 32)
+}
+
+func (g *Gate64) setTrap(cs Selector, rip uint64, dpl int, ist int) {
+ g.setInterrupt(cs, rip, dpl, ist)
+ g.bits[1] |= 1 << 8
+}
+
+// TaskState64 is a 64-bit task state structure.
+type TaskState64 struct {
+ _ uint32
+ rsp0Lo, rsp0Hi uint32
+ rsp1Lo, rsp1Hi uint32
+ rsp2Lo, rsp2Hi uint32
+ _ [2]uint32
+ ist1Lo, ist1Hi uint32
+ ist2Lo, ist2Hi uint32
+ ist3Lo, ist3Hi uint32
+ ist4Lo, ist4Hi uint32
+ ist5Lo, ist5Hi uint32
+ ist6Lo, ist6Hi uint32
+ ist7Lo, ist7Hi uint32
+ _ [2]uint32
+ _ uint16
+ ioPerm uint16
+}
diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go
new file mode 100644
index 000000000..a5ce67885
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_amd64.go
@@ -0,0 +1,128 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ring0
+
+import (
+ "syscall"
+)
+
+// This is an assembly function.
+//
+// The sysenter function is invoked in two situations:
+//
+// (1) The guest kernel has executed a system call.
+// (2) The guest application has executed a system call.
+//
+// The interrupt flag is examined to determine whether the system call was
+// executed from kernel mode or not and the appropriate stub is called.
+func sysenter()
+
+// swapgs swaps the current GS value.
+//
+// This must be called prior to sysret/iret.
+func swapgs()
+
+// sysret returns to userspace from a system call.
+//
+// The return code is the vector that interrupted execution.
+//
+// See stubs.go for a note regarding the frame size of this function.
+func sysret(*CPU, *syscall.PtraceRegs) Vector
+
+// "iret is the cadillac of CPL switching."
+//
+// -- Neel Natu
+//
+// iret is nearly identical to sysret, except an iret is used to fully restore
+// all user state. This must be called in cases where all registers need to be
+// restored.
+func iret(*CPU, *syscall.PtraceRegs) Vector
+
+// exception is the generic exception entry.
+//
+// This is called by the individual stub definitions.
+func exception()
+
+// resume is a stub that restores the CPU kernel registers.
+//
+// This is used when processing kernel exceptions and syscalls.
+func resume()
+
+// Start is the CPU entrypoint.
+//
+// The following start conditions must be satisfied:
+//
+// * AX should contain the CPU pointer.
+// * c.GDT() should be loaded as the GDT.
+// * c.IDT() should be loaded as the IDT.
+// * c.CR0() should be the current CR0 value.
+// * c.CR3() should be set to the kernel PageTables.
+// * c.CR4() should be the current CR4 value.
+// * c.EFER() should be the current EFER value.
+//
+// The CPU state will be set to c.Registers().
+func Start()
+
+// Exception stubs.
+func divideByZero()
+func debug()
+func nmi()
+func breakpoint()
+func overflow()
+func boundRangeExceeded()
+func invalidOpcode()
+func deviceNotAvailable()
+func doubleFault()
+func coprocessorSegmentOverrun()
+func invalidTSS()
+func segmentNotPresent()
+func stackSegmentFault()
+func generalProtectionFault()
+func pageFault()
+func x87FloatingPointException()
+func alignmentCheck()
+func machineCheck()
+func simdFloatingPointException()
+func virtualizationException()
+func securityException()
+func syscallInt80()
+
+// Exception handler index.
+var handlers = map[Vector]func(){
+ DivideByZero: divideByZero,
+ Debug: debug,
+ NMI: nmi,
+ Breakpoint: breakpoint,
+ Overflow: overflow,
+ BoundRangeExceeded: boundRangeExceeded,
+ InvalidOpcode: invalidOpcode,
+ DeviceNotAvailable: deviceNotAvailable,
+ DoubleFault: doubleFault,
+ CoprocessorSegmentOverrun: coprocessorSegmentOverrun,
+ InvalidTSS: invalidTSS,
+ SegmentNotPresent: segmentNotPresent,
+ StackSegmentFault: stackSegmentFault,
+ GeneralProtectionFault: generalProtectionFault,
+ PageFault: pageFault,
+ X87FloatingPointException: x87FloatingPointException,
+ AlignmentCheck: alignmentCheck,
+ MachineCheck: machineCheck,
+ SIMDFloatingPointException: simdFloatingPointException,
+ VirtualizationException: virtualizationException,
+ SecurityException: securityException,
+ SyscallInt80: syscallInt80,
+}
diff --git a/pkg/sentry/platform/ring0/entry_impl_amd64.s b/pkg/sentry/platform/ring0/entry_impl_amd64.s
new file mode 100755
index 000000000..d082d06a9
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_impl_amd64.s
@@ -0,0 +1,383 @@
+// build +amd64
+
+// Automatically generated, do not edit.
+
+// CPU offsets.
+#define CPU_SELF 0x00
+#define CPU_REGISTERS 0x288
+#define CPU_STACK_TOP 0x110
+#define CPU_ERROR_CODE 0x110
+#define CPU_ERROR_TYPE 0x118
+
+// Bits.
+#define _RFLAGS_IF 0x200
+#define _KERNEL_FLAGS 0x02
+
+// Vectors.
+#define DivideByZero 0x00
+#define Debug 0x01
+#define NMI 0x02
+#define Breakpoint 0x03
+#define Overflow 0x04
+#define BoundRangeExceeded 0x05
+#define InvalidOpcode 0x06
+#define DeviceNotAvailable 0x07
+#define DoubleFault 0x08
+#define CoprocessorSegmentOverrun 0x09
+#define InvalidTSS 0x0a
+#define SegmentNotPresent 0x0b
+#define StackSegmentFault 0x0c
+#define GeneralProtectionFault 0x0d
+#define PageFault 0x0e
+#define X87FloatingPointException 0x10
+#define AlignmentCheck 0x11
+#define MachineCheck 0x12
+#define SIMDFloatingPointException 0x13
+#define VirtualizationException 0x14
+#define SecurityException 0x1e
+#define SyscallInt80 0x80
+#define Syscall 0x81
+
+// Ptrace registers.
+#define PTRACE_R15 0x00
+#define PTRACE_R14 0x08
+#define PTRACE_R13 0x10
+#define PTRACE_R12 0x18
+#define PTRACE_RBP 0x20
+#define PTRACE_RBX 0x28
+#define PTRACE_R11 0x30
+#define PTRACE_R10 0x38
+#define PTRACE_R9 0x40
+#define PTRACE_R8 0x48
+#define PTRACE_RAX 0x50
+#define PTRACE_RCX 0x58
+#define PTRACE_RDX 0x60
+#define PTRACE_RSI 0x68
+#define PTRACE_RDI 0x70
+#define PTRACE_ORIGRAX 0x78
+#define PTRACE_RIP 0x80
+#define PTRACE_CS 0x88
+#define PTRACE_FLAGS 0x90
+#define PTRACE_RSP 0x98
+#define PTRACE_SS 0xa0
+#define PTRACE_FS 0xa8
+#define PTRACE_GS 0xb0
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+// NB: Offsets are programatically generated (see BUILD).
+//
+// This file is concatenated with the definitions.
+
+// Saves a register set.
+//
+// This is a macro because it may need to executed in contents where a stack is
+// not available for calls.
+//
+// The following registers are not saved: AX, SP, IP, FLAGS, all segments.
+#define REGISTERS_SAVE(reg, offset) \
+ MOVQ R15, offset+PTRACE_R15(reg); \
+ MOVQ R14, offset+PTRACE_R14(reg); \
+ MOVQ R13, offset+PTRACE_R13(reg); \
+ MOVQ R12, offset+PTRACE_R12(reg); \
+ MOVQ BP, offset+PTRACE_RBP(reg); \
+ MOVQ BX, offset+PTRACE_RBX(reg); \
+ MOVQ CX, offset+PTRACE_RCX(reg); \
+ MOVQ DX, offset+PTRACE_RDX(reg); \
+ MOVQ R11, offset+PTRACE_R11(reg); \
+ MOVQ R10, offset+PTRACE_R10(reg); \
+ MOVQ R9, offset+PTRACE_R9(reg); \
+ MOVQ R8, offset+PTRACE_R8(reg); \
+ MOVQ SI, offset+PTRACE_RSI(reg); \
+ MOVQ DI, offset+PTRACE_RDI(reg);
+
+// Loads a register set.
+//
+// This is a macro because it may need to executed in contents where a stack is
+// not available for calls.
+//
+// The following registers are not loaded: AX, SP, IP, FLAGS, all segments.
+#define REGISTERS_LOAD(reg, offset) \
+ MOVQ offset+PTRACE_R15(reg), R15; \
+ MOVQ offset+PTRACE_R14(reg), R14; \
+ MOVQ offset+PTRACE_R13(reg), R13; \
+ MOVQ offset+PTRACE_R12(reg), R12; \
+ MOVQ offset+PTRACE_RBP(reg), BP; \
+ MOVQ offset+PTRACE_RBX(reg), BX; \
+ MOVQ offset+PTRACE_RCX(reg), CX; \
+ MOVQ offset+PTRACE_RDX(reg), DX; \
+ MOVQ offset+PTRACE_R11(reg), R11; \
+ MOVQ offset+PTRACE_R10(reg), R10; \
+ MOVQ offset+PTRACE_R9(reg), R9; \
+ MOVQ offset+PTRACE_R8(reg), R8; \
+ MOVQ offset+PTRACE_RSI(reg), SI; \
+ MOVQ offset+PTRACE_RDI(reg), DI;
+
+// SWAP_GS swaps the kernel GS (CPU).
+#define SWAP_GS() \
+ BYTE $0x0F; BYTE $0x01; BYTE $0xf8;
+
+// IRET returns from an interrupt frame.
+#define IRET() \
+ BYTE $0x48; BYTE $0xcf;
+
+// SYSRET64 executes the sysret instruction.
+#define SYSRET64() \
+ BYTE $0x48; BYTE $0x0f; BYTE $0x07;
+
+// LOAD_KERNEL_ADDRESS loads a kernel address.
+#define LOAD_KERNEL_ADDRESS(from, to) \
+ MOVQ from, to; \
+ ORQ ·KernelStartAddress(SB), to;
+
+// LOAD_KERNEL_STACK loads the kernel stack.
+#define LOAD_KERNEL_STACK(from) \
+ LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \
+ LEAQ CPU_STACK_TOP(SP), SP;
+
+// See kernel.go.
+TEXT ·Halt(SB),NOSPLIT,$0
+ HLT
+ RET
+
+// See entry_amd64.go.
+TEXT ·swapgs(SB),NOSPLIT,$0
+ SWAP_GS()
+ RET
+
+// See entry_amd64.go.
+TEXT ·sysret(SB),NOSPLIT,$0-24
+ // Save original state.
+ LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
+ LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
+ MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
+ MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
+
+ // Restore user register state.
+ REGISTERS_LOAD(AX, 0)
+ MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET.
+ MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET.
+ MOVQ PTRACE_RSP(AX), SP // Restore the stack directly.
+ MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+ SYSRET64()
+
+// See entry_amd64.go.
+TEXT ·iret(SB),NOSPLIT,$0-24
+ // Save original state.
+ LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
+ LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
+ MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
+ MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
+
+ // Build an IRET frame & restore state.
+ LOAD_KERNEL_STACK(BX)
+ MOVQ PTRACE_SS(AX), BX; PUSHQ BX
+ MOVQ PTRACE_RSP(AX), CX; PUSHQ CX
+ MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX
+ MOVQ PTRACE_CS(AX), DI; PUSHQ DI
+ MOVQ PTRACE_RIP(AX), SI; PUSHQ SI
+ REGISTERS_LOAD(AX, 0) // Restore most registers.
+ MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+ IRET()
+
+// See entry_amd64.go.
+TEXT ·resume(SB),NOSPLIT,$0
+ // See iret, above.
+ MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX
+ MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX
+ MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX
+ MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI
+ MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI
+ REGISTERS_LOAD(GS, CPU_REGISTERS)
+ MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX
+ IRET()
+
+// See entry_amd64.go.
+TEXT ·Start(SB),NOSPLIT,$0
+ LOAD_KERNEL_STACK(AX) // Set the stack.
+ PUSHQ $0x0 // Previous frame pointer.
+ MOVQ SP, BP // Set frame pointer.
+ PUSHQ AX // First argument (CPU).
+ CALL ·start(SB) // Call Go hook.
+ JMP ·resume(SB) // Restore to registers.
+
+// See entry_amd64.go.
+TEXT ·sysenter(SB),NOSPLIT,$0
+ // Interrupts are always disabled while we're executing in kernel mode
+ // and always enabled while executing in user mode. Therefore, we can
+ // reliably look at the flags in R11 to determine where this syscall
+ // was from.
+ TESTL $_RFLAGS_IF, R11
+ JZ kernel
+
+user:
+ SWAP_GS()
+ XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks.
+ XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs).
+ REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value.
+ MOVQ BX, PTRACE_RAX(AX) // Save everything else.
+ MOVQ BX, PTRACE_ORIGRAX(AX)
+ MOVQ CX, PTRACE_RIP(AX)
+ MOVQ R11, PTRACE_FLAGS(AX)
+ MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX)
+ MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
+ MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
+
+ // Return to the kernel, where the frame is:
+ //
+ // vector (sp+24)
+ // regs (sp+16)
+ // cpu (sp+8)
+ // vcpu.Switch (sp+0)
+ //
+ MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
+ MOVQ $Syscall, 24(SP) // Output vector.
+ RET
+
+kernel:
+ // We can't restore the original stack, but we can access the registers
+ // in the CPU state directly. No need for temporary juggling.
+ MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
+ MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
+ REGISTERS_SAVE(GS, CPU_REGISTERS)
+ MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS)
+ MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS)
+ MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
+ MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+
+ // Call the syscall trampoline.
+ LOAD_KERNEL_STACK(GS)
+ MOVQ CPU_SELF(GS), AX // Load vCPU.
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelSyscall(SB) // Call the trampoline.
+ POPQ AX // Pop vCPU.
+ JMP ·resume(SB)
+
+// exception is a generic exception handler.
+//
+// There are two cases handled:
+//
+// 1) An exception in kernel mode: this results in saving the state at the time
+// of the exception and calling the defined hook.
+//
+// 2) An exception in guest mode: the original kernel frame is restored, and
+// the vector & error codes are pushed as return values.
+//
+// See below for the stubs that call exception.
+TEXT ·exception(SB),NOSPLIT,$0
+ // Determine whether the exception occurred in kernel mode or user
+ // mode, based on the flags. We expect the following stack:
+ //
+ // SS (sp+48)
+ // SP (sp+40)
+ // FLAGS (sp+32)
+ // CS (sp+24)
+ // IP (sp+16)
+ // ERROR_CODE (sp+8)
+ // VECTOR (sp+0)
+ //
+ TESTL $_RFLAGS_IF, 32(SP)
+ JZ kernel
+
+user:
+ SWAP_GS()
+ ADDQ $-8, SP // Adjust for flags.
+ MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ).
+ XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs.
+ REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX.
+ MOVQ BX, PTRACE_RAX(AX) // Save it.
+ MOVQ BX, PTRACE_ORIGRAX(AX)
+ MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX)
+ MOVQ 24(SP), CX; MOVQ CX, PTRACE_CS(AX)
+ MOVQ 32(SP), DX; MOVQ DX, PTRACE_FLAGS(AX)
+ MOVQ 40(SP), DI; MOVQ DI, PTRACE_RSP(AX)
+ MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX)
+
+ // Copy out and return.
+ MOVQ 0(SP), BX // Load vector.
+ MOVQ 8(SP), CX // Load error code.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version).
+ MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
+ MOVQ CX, CPU_ERROR_CODE(GS) // Set error code.
+ MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
+ MOVQ BX, 24(SP) // Output vector.
+ RET
+
+kernel:
+ // As per above, we can save directly.
+ MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
+ MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
+ REGISTERS_SAVE(GS, CPU_REGISTERS)
+ MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS)
+ MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS)
+ MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS)
+
+ // Set the error code and adjust the stack.
+ MOVQ 8(SP), AX // Load the error code.
+ MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU.
+ MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ 0(SP), BX // BX contains the vector.
+ ADDQ $48, SP // Drop the exception frame.
+
+ // Call the exception trampoline.
+ LOAD_KERNEL_STACK(GS)
+ MOVQ CPU_SELF(GS), AX // Load vCPU.
+ PUSHQ BX // Second argument (vector).
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelException(SB) // Call the trampoline.
+ POPQ BX // Pop vector.
+ POPQ AX // Pop vCPU.
+ JMP ·resume(SB)
+
+#define EXCEPTION_WITH_ERROR(value, symbol) \
+TEXT symbol,NOSPLIT,$0; \
+ PUSHQ $value; \
+ JMP ·exception(SB);
+
+#define EXCEPTION_WITHOUT_ERROR(value, symbol) \
+TEXT symbol,NOSPLIT,$0; \
+ PUSHQ $0x0; \
+ PUSHQ $value; \
+ JMP ·exception(SB);
+
+EXCEPTION_WITHOUT_ERROR(DivideByZero, ·divideByZero(SB))
+EXCEPTION_WITHOUT_ERROR(Debug, ·debug(SB))
+EXCEPTION_WITHOUT_ERROR(NMI, ·nmi(SB))
+EXCEPTION_WITHOUT_ERROR(Breakpoint, ·breakpoint(SB))
+EXCEPTION_WITHOUT_ERROR(Overflow, ·overflow(SB))
+EXCEPTION_WITHOUT_ERROR(BoundRangeExceeded, ·boundRangeExceeded(SB))
+EXCEPTION_WITHOUT_ERROR(InvalidOpcode, ·invalidOpcode(SB))
+EXCEPTION_WITHOUT_ERROR(DeviceNotAvailable, ·deviceNotAvailable(SB))
+EXCEPTION_WITH_ERROR(DoubleFault, ·doubleFault(SB))
+EXCEPTION_WITHOUT_ERROR(CoprocessorSegmentOverrun, ·coprocessorSegmentOverrun(SB))
+EXCEPTION_WITH_ERROR(InvalidTSS, ·invalidTSS(SB))
+EXCEPTION_WITH_ERROR(SegmentNotPresent, ·segmentNotPresent(SB))
+EXCEPTION_WITH_ERROR(StackSegmentFault, ·stackSegmentFault(SB))
+EXCEPTION_WITH_ERROR(GeneralProtectionFault, ·generalProtectionFault(SB))
+EXCEPTION_WITH_ERROR(PageFault, ·pageFault(SB))
+EXCEPTION_WITHOUT_ERROR(X87FloatingPointException, ·x87FloatingPointException(SB))
+EXCEPTION_WITH_ERROR(AlignmentCheck, ·alignmentCheck(SB))
+EXCEPTION_WITHOUT_ERROR(MachineCheck, ·machineCheck(SB))
+EXCEPTION_WITHOUT_ERROR(SIMDFloatingPointException, ·simdFloatingPointException(SB))
+EXCEPTION_WITHOUT_ERROR(VirtualizationException, ·virtualizationException(SB))
+EXCEPTION_WITH_ERROR(SecurityException, ·securityException(SB))
+EXCEPTION_WITHOUT_ERROR(SyscallInt80, ·syscallInt80(SB))
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
new file mode 100644
index 000000000..900c0bba7
--- /dev/null
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ring0
+
+// Init initializes a new kernel.
+//
+// N.B. that constraints on KernelOpts must be satisfied.
+//
+//go:nosplit
+func (k *Kernel) Init(opts KernelOpts) {
+ k.init(opts)
+}
+
+// Halt halts execution.
+func Halt()
+
+// defaultHooks implements hooks.
+type defaultHooks struct{}
+
+// KernelSyscall implements Hooks.KernelSyscall.
+//
+//go:nosplit
+func (defaultHooks) KernelSyscall() { Halt() }
+
+// KernelException implements Hooks.KernelException.
+//
+//go:nosplit
+func (defaultHooks) KernelException(Vector) { Halt() }
+
+// kernelSyscall is a trampoline.
+//
+//go:nosplit
+func kernelSyscall(c *CPU) { c.hooks.KernelSyscall() }
+
+// kernelException is a trampoline.
+//
+//go:nosplit
+func kernelException(c *CPU, vector Vector) { c.hooks.KernelException(vector) }
+
+// Init initializes a new CPU.
+//
+// Init allows embedding in other objects.
+func (c *CPU) Init(k *Kernel, hooks Hooks) {
+ c.self = c // Set self reference.
+ c.kernel = k // Set kernel reference.
+ c.init() // Perform architectural init.
+
+ // Require hooks.
+ if hooks != nil {
+ c.hooks = hooks
+ } else {
+ c.hooks = defaultHooks{}
+ }
+}
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
new file mode 100644
index 000000000..3577b5127
--- /dev/null
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -0,0 +1,271 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ring0
+
+import (
+ "encoding/binary"
+)
+
+// init initializes architecture-specific state.
+func (k *Kernel) init(opts KernelOpts) {
+ // Save the root page tables.
+ k.PageTables = opts.PageTables
+
+ // Setup the IDT, which is uniform.
+ for v, handler := range handlers {
+ // Allow Breakpoint and Overflow to be called from all
+ // privilege levels.
+ dpl := 0
+ if v == Breakpoint || v == Overflow {
+ dpl = 3
+ }
+ // Note that we set all traps to use the interrupt stack, this
+ // is defined below when setting up the TSS.
+ k.globalIDT[v].setInterrupt(Kcode, uint64(kernelFunc(handler)), dpl, 1 /* ist */)
+ }
+}
+
+// init initializes architecture-specific state.
+func (c *CPU) init() {
+ // Null segment.
+ c.gdt[0].setNull()
+
+ // Kernel & user segments.
+ c.gdt[segKcode] = KernelCodeSegment
+ c.gdt[segKdata] = KernelDataSegment
+ c.gdt[segUcode32] = UserCodeSegment32
+ c.gdt[segUdata] = UserDataSegment
+ c.gdt[segUcode64] = UserCodeSegment64
+
+ // The task segment, this spans two entries.
+ tssBase, tssLimit, _ := c.TSS()
+ c.gdt[segTss].set(
+ uint32(tssBase),
+ uint32(tssLimit),
+ 0, // Privilege level zero.
+ SegmentDescriptorPresent|
+ SegmentDescriptorAccess|
+ SegmentDescriptorWrite|
+ SegmentDescriptorExecute)
+ c.gdt[segTssHi].setHi(uint32((tssBase) >> 32))
+
+ // Set the kernel stack pointer in the TSS (virtual address).
+ stackAddr := c.StackTop()
+ c.tss.rsp0Lo = uint32(stackAddr)
+ c.tss.rsp0Hi = uint32(stackAddr >> 32)
+ c.tss.ist1Lo = uint32(stackAddr)
+ c.tss.ist1Hi = uint32(stackAddr >> 32)
+
+ // Permanently set the kernel segments.
+ c.registers.Cs = uint64(Kcode)
+ c.registers.Ds = uint64(Kdata)
+ c.registers.Es = uint64(Kdata)
+ c.registers.Ss = uint64(Kdata)
+ c.registers.Fs = uint64(Kdata)
+ c.registers.Gs = uint64(Kdata)
+
+ // Set mandatory flags.
+ c.registers.Eflags = KernelFlagsSet
+}
+
+// StackTop returns the kernel's stack address.
+//
+//go:nosplit
+func (c *CPU) StackTop() uint64 {
+ return uint64(kernelAddr(&c.stack[0])) + uint64(len(c.stack))
+}
+
+// IDT returns the CPU's IDT base and limit.
+//
+//go:nosplit
+func (c *CPU) IDT() (uint64, uint16) {
+ return uint64(kernelAddr(&c.kernel.globalIDT[0])), uint16(binary.Size(&c.kernel.globalIDT) - 1)
+}
+
+// GDT returns the CPU's GDT base and limit.
+//
+//go:nosplit
+func (c *CPU) GDT() (uint64, uint16) {
+ return uint64(kernelAddr(&c.gdt[0])), uint16(8*segLast - 1)
+}
+
+// TSS returns the CPU's TSS base, limit and value.
+//
+//go:nosplit
+func (c *CPU) TSS() (uint64, uint16, *SegmentDescriptor) {
+ return uint64(kernelAddr(&c.tss)), uint16(binary.Size(&c.tss) - 1), &c.gdt[segTss]
+}
+
+// CR0 returns the CPU's CR0 value.
+//
+//go:nosplit
+func (c *CPU) CR0() uint64 {
+ return _CR0_PE | _CR0_PG | _CR0_AM | _CR0_ET
+}
+
+// CR4 returns the CPU's CR4 value.
+//
+//go:nosplit
+func (c *CPU) CR4() uint64 {
+ cr4 := uint64(_CR4_PAE | _CR4_PSE | _CR4_OSFXSR | _CR4_OSXMMEXCPT)
+ if hasPCID {
+ cr4 |= _CR4_PCIDE
+ }
+ if hasXSAVE {
+ cr4 |= _CR4_OSXSAVE
+ }
+ if hasSMEP {
+ cr4 |= _CR4_SMEP
+ }
+ if hasFSGSBASE {
+ cr4 |= _CR4_FSGSBASE
+ }
+ return cr4
+}
+
+// EFER returns the CPU's EFER value.
+//
+//go:nosplit
+func (c *CPU) EFER() uint64 {
+ return _EFER_LME | _EFER_LMA | _EFER_SCE | _EFER_NX
+}
+
+// IsCanonical indicates whether addr is canonical per the amd64 spec.
+//
+//go:nosplit
+func IsCanonical(addr uint64) bool {
+ return addr <= 0x00007fffffffffff || addr > 0xffff800000000000
+}
+
+// SwitchToUser performs either a sysret or an iret.
+//
+// The return value is the vector that interrupted execution.
+//
+// This function will not split the stack. Callers will probably want to call
+// runtime.entersyscall (and pair with a call to runtime.exitsyscall) prior to
+// calling this function.
+//
+// When this is done, this region is quite sensitive to things like system
+// calls. After calling entersyscall, any memory used must have been allocated
+// and no function calls without go:nosplit are permitted. Any calls made here
+// are protected appropriately (e.g. IsCanonical and CR3).
+//
+// Also note that this function transitively depends on the compiler generating
+// code that uses IP-relative addressing inside of absolute addresses. That's
+// the case for amd64, but may not be the case for other architectures.
+//
+// Precondition: the Rip, Rsp, Fs and Gs registers must be canonical.
+//
+//go:nosplit
+func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
+ userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
+ kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)
+
+ // Sanitize registers.
+ regs := switchOpts.Registers
+ regs.Eflags &= ^uint64(UserFlagsClear)
+ regs.Eflags |= UserFlagsSet
+ regs.Cs = uint64(Ucode64) // Required for iret.
+ regs.Ss = uint64(Udata) // Ditto.
+
+ // Perform the switch.
+ swapgs() // GS will be swapped on return.
+ WriteFS(uintptr(regs.Fs_base)) // Set application FS.
+ WriteGS(uintptr(regs.Gs_base)) // Set application GS.
+ LoadFloatingPoint(switchOpts.FloatingPointState) // Copy in floating point.
+ jumpToKernel() // Switch to upper half.
+ writeCR3(uintptr(userCR3)) // Change to user address space.
+ if switchOpts.FullRestore {
+ vector = iret(c, regs)
+ } else {
+ vector = sysret(c, regs)
+ }
+ writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
+ jumpToUser() // Return to lower half.
+ SaveFloatingPoint(switchOpts.FloatingPointState) // Copy out floating point.
+ WriteFS(uintptr(c.registers.Fs_base)) // Restore kernel FS.
+ return
+}
+
+// start is the CPU entrypoint.
+//
+// This is called from the Start asm stub (see entry_amd64.go); on return the
+// registers in c.registers will be restored (not segments).
+//
+//go:nosplit
+func start(c *CPU) {
+ // Save per-cpu & FS segment.
+ WriteGS(kernelAddr(c))
+ WriteFS(uintptr(c.registers.Fs_base))
+
+ // Initialize floating point.
+ //
+ // Note that on skylake, the valid XCR0 mask reported seems to be 0xff.
+ // This breaks down as:
+ //
+ // bit0 - x87
+ // bit1 - SSE
+ // bit2 - AVX
+ // bit3-4 - MPX
+ // bit5-7 - AVX512
+ //
+ // For some reason, enabled MPX & AVX512 on platforms that report them
+ // seems to be cause a general protection fault. (Maybe there are some
+ // virtualization issues and these aren't exported to the guest cpuid.)
+ // This needs further investigation, but we can limit the floating
+ // point operations to x87, SSE & AVX for now.
+ fninit()
+ xsetbv(0, validXCR0Mask&0x7)
+
+ // Set the syscall target.
+ wrmsr(_MSR_LSTAR, kernelFunc(sysenter))
+ wrmsr(_MSR_SYSCALL_MASK, KernelFlagsClear|_RFLAGS_DF)
+
+ // NOTE: This depends on having the 64-bit segments immediately
+ // following the 32-bit user segments. This is simply the way the
+ // sysret instruction is designed to work (it assumes they follow).
+ wrmsr(_MSR_STAR, uintptr(uint64(Kcode)<<32|uint64(Ucode32)<<48))
+ wrmsr(_MSR_CSTAR, kernelFunc(sysenter))
+}
+
+// SetCPUIDFaulting sets CPUID faulting per the boolean value.
+//
+// True is returned if faulting could be set.
+//
+//go:nosplit
+func SetCPUIDFaulting(on bool) bool {
+ // Per the SDM (Vol 3, Table 2-43), PLATFORM_INFO bit 31 denotes support
+ // for CPUID faulting, and we enable and disable via the MISC_FEATURES MSR.
+ if rdmsr(_MSR_PLATFORM_INFO)&_PLATFORM_INFO_CPUID_FAULT != 0 {
+ features := rdmsr(_MSR_MISC_FEATURES)
+ if on {
+ features |= _MISC_FEATURE_CPUID_TRAP
+ } else {
+ features &^= _MISC_FEATURE_CPUID_TRAP
+ }
+ wrmsr(_MSR_MISC_FEATURES, features)
+ return true // Setting successful.
+ }
+ return false
+}
+
+// ReadCR2 reads the current CR2 value.
+//
+//go:nosplit
+func ReadCR2() uintptr {
+ return readCR2()
+}
diff --git a/pkg/sentry/platform/ring0/kernel_unsafe.go b/pkg/sentry/platform/ring0/kernel_unsafe.go
new file mode 100644
index 000000000..16955ad91
--- /dev/null
+++ b/pkg/sentry/platform/ring0/kernel_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ring0
+
+import (
+ "unsafe"
+)
+
+// eface mirrors runtime.eface.
+type eface struct {
+ typ uintptr
+ data unsafe.Pointer
+}
+
+// kernelAddr returns the kernel virtual address for the given object.
+//
+//go:nosplit
+func kernelAddr(obj interface{}) uintptr {
+ e := (*eface)(unsafe.Pointer(&obj))
+ return KernelStartAddress | uintptr(e.data)
+}
+
+// kernelFunc returns the address of the given function.
+//
+//go:nosplit
+func kernelFunc(fn func()) uintptr {
+ fnptr := (**uintptr)(unsafe.Pointer(&fn))
+ return KernelStartAddress | **fnptr
+}
diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go
new file mode 100644
index 000000000..9c5f26962
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_amd64.go
@@ -0,0 +1,131 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package ring0
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/cpuid"
+)
+
+// LoadFloatingPoint loads floating point state by the most efficient mechanism
+// available (set by Init).
+var LoadFloatingPoint func(*byte)
+
+// SaveFloatingPoint saves floating point state by the most efficient mechanism
+// available (set by Init).
+var SaveFloatingPoint func(*byte)
+
+// fxrstor uses fxrstor64 to load floating point state.
+func fxrstor(*byte)
+
+// xrstor uses xrstor to load floating point state.
+func xrstor(*byte)
+
+// fxsave uses fxsave64 to save floating point state.
+func fxsave(*byte)
+
+// xsave uses xsave to save floating point state.
+func xsave(*byte)
+
+// xsaveopt uses xsaveopt to save floating point state.
+func xsaveopt(*byte)
+
+// WriteFS sets the GS address (set by init).
+var WriteFS func(addr uintptr)
+
+// wrfsbase writes to the GS base address.
+func wrfsbase(addr uintptr)
+
+// wrfsmsr writes to the GS_BASE MSR.
+func wrfsmsr(addr uintptr)
+
+// WriteGS sets the GS address (set by init).
+var WriteGS func(addr uintptr)
+
+// wrgsbase writes to the GS base address.
+func wrgsbase(addr uintptr)
+
+// wrgsmsr writes to the GS_BASE MSR.
+func wrgsmsr(addr uintptr)
+
+// writeCR3 writes the CR3 value.
+func writeCR3(phys uintptr)
+
+// readCR3 reads the current CR3 value.
+func readCR3() uintptr
+
+// readCR2 reads the current CR2 value.
+func readCR2() uintptr
+
+// jumpToKernel jumps to the kernel version of the current RIP.
+func jumpToKernel()
+
+// jumpToUser jumps to the user version of the current RIP.
+func jumpToUser()
+
+// fninit initializes the floating point unit.
+func fninit()
+
+// xsetbv writes to an extended control register.
+func xsetbv(reg, value uintptr)
+
+// xgetbv reads an extended control register.
+func xgetbv(reg uintptr) uintptr
+
+// wrmsr reads to the given MSR.
+func wrmsr(reg, value uintptr)
+
+// rdmsr reads the given MSR.
+func rdmsr(reg uintptr) uintptr
+
+// Mostly-constants set by Init.
+var (
+ hasSMEP bool
+ hasPCID bool
+ hasXSAVEOPT bool
+ hasXSAVE bool
+ hasFSGSBASE bool
+ validXCR0Mask uintptr
+)
+
+// Init sets function pointers based on architectural features.
+//
+// This must be called prior to using ring0.
+func Init(featureSet *cpuid.FeatureSet) {
+ hasSMEP = featureSet.HasFeature(cpuid.X86FeatureSMEP)
+ hasPCID = featureSet.HasFeature(cpuid.X86FeaturePCID)
+ hasXSAVEOPT = featureSet.UseXsaveopt()
+ hasXSAVE = featureSet.UseXsave()
+ hasFSGSBASE = featureSet.HasFeature(cpuid.X86FeatureFSGSBase)
+ validXCR0Mask = uintptr(featureSet.ValidXCR0Mask())
+ if hasXSAVEOPT {
+ SaveFloatingPoint = xsaveopt
+ LoadFloatingPoint = xrstor
+ } else if hasXSAVE {
+ SaveFloatingPoint = xsave
+ LoadFloatingPoint = xrstor
+ } else {
+ SaveFloatingPoint = fxsave
+ LoadFloatingPoint = fxrstor
+ }
+ if hasFSGSBASE {
+ WriteFS = wrfsbase
+ WriteGS = wrgsbase
+ } else {
+ WriteFS = wrfsmsr
+ WriteGS = wrgsmsr
+ }
+}
diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s
new file mode 100644
index 000000000..75d742750
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_amd64.s
@@ -0,0 +1,247 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+// fxrstor loads floating point state.
+//
+// The code corresponds to:
+//
+// fxrstor64 (%rbx)
+//
+TEXT ·fxrstor(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), BX
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x0b;
+ RET
+
+// xrstor loads floating point state.
+//
+// The code corresponds to:
+//
+// xrstor (%rdi)
+//
+TEXT ·xrstor(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), DI
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f;
+ RET
+
+// fxsave saves floating point state.
+//
+// The code corresponds to:
+//
+// fxsave64 (%rbx)
+//
+TEXT ·fxsave(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), BX
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x03;
+ RET
+
+// xsave saves floating point state.
+//
+// The code corresponds to:
+//
+// xsave (%rdi)
+//
+TEXT ·xsave(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), DI
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27;
+ RET
+
+// xsaveopt saves floating point state.
+//
+// The code corresponds to:
+//
+// xsaveopt (%rdi)
+//
+TEXT ·xsaveopt(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), DI
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37;
+ RET
+
+// wrfsbase writes to the FS base.
+//
+// The code corresponds to:
+//
+// wrfsbase %rax
+//
+TEXT ·wrfsbase(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), AX
+ BYTE $0xf3; BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0xd0;
+ RET
+
+// wrfsmsr writes to the FSBASE MSR.
+//
+// The code corresponds to:
+//
+// wrmsr (writes EDX:EAX to the MSR in ECX)
+//
+TEXT ·wrfsmsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), AX
+ MOVQ AX, DX
+ SHRQ $32, DX
+ MOVQ $0xc0000100, CX // MSR_FS_BASE
+ BYTE $0x0f; BYTE $0x30;
+ RET
+
+// wrgsbase writes to the GS base.
+//
+// The code corresponds to:
+//
+// wrgsbase %rax
+//
+TEXT ·wrgsbase(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), AX
+ BYTE $0xf3; BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0xd8;
+ RET
+
+// wrgsmsr writes to the GSBASE MSR.
+//
+// See wrfsmsr.
+TEXT ·wrgsmsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), AX
+ MOVQ AX, DX
+ SHRQ $32, DX
+ MOVQ $0xc0000101, CX // MSR_GS_BASE
+ BYTE $0x0f; BYTE $0x30; // WRMSR
+ RET
+
+// jumpToUser changes execution to the user address.
+//
+// This works by changing the return value to the user version.
+TEXT ·jumpToUser(SB),NOSPLIT,$0
+ MOVQ 0(SP), AX
+ MOVQ ·KernelStartAddress(SB), BX
+ NOTQ BX
+ ANDQ BX, SP // Switch the stack.
+ ANDQ BX, BP // Switch the frame pointer.
+ ANDQ BX, AX // Future return value.
+ MOVQ AX, 0(SP)
+ RET
+
+// jumpToKernel changes execution to the kernel address space.
+//
+// This works by changing the return value to the kernel version.
+TEXT ·jumpToKernel(SB),NOSPLIT,$0
+ MOVQ 0(SP), AX
+ MOVQ ·KernelStartAddress(SB), BX
+ ORQ BX, SP // Switch the stack.
+ ORQ BX, BP // Switch the frame pointer.
+ ORQ BX, AX // Future return value.
+ MOVQ AX, 0(SP)
+ RET
+
+// writeCR3 writes the given CR3 value.
+//
+// The code corresponds to:
+//
+// mov %rax, %cr3
+//
+TEXT ·writeCR3(SB),NOSPLIT,$0-8
+ MOVQ cr3+0(FP), AX
+ BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
+ RET
+
+// readCR3 reads the current CR3 value.
+//
+// The code corresponds to:
+//
+// mov %cr3, %rax
+//
+TEXT ·readCR3(SB),NOSPLIT,$0-8
+ BYTE $0x0f; BYTE $0x20; BYTE $0xd8;
+ MOVQ AX, ret+0(FP)
+ RET
+
+// readCR2 reads the current CR2 value.
+//
+// The code corresponds to:
+//
+// mov %cr2, %rax
+//
+TEXT ·readCR2(SB),NOSPLIT,$0-8
+ BYTE $0x0f; BYTE $0x20; BYTE $0xd0;
+ MOVQ AX, ret+0(FP)
+ RET
+
+// fninit initializes the floating point unit.
+//
+// The code corresponds to:
+//
+// fninit
+TEXT ·fninit(SB),NOSPLIT,$0
+ BYTE $0xdb; BYTE $0xe3;
+ RET
+
+// xsetbv writes to an extended control register.
+//
+// The code corresponds to:
+//
+// xsetbv
+//
+TEXT ·xsetbv(SB),NOSPLIT,$0-16
+ MOVL reg+0(FP), CX
+ MOVL value+8(FP), AX
+ MOVL value+12(FP), DX
+ BYTE $0x0f; BYTE $0x01; BYTE $0xd1;
+ RET
+
+// xgetbv reads an extended control register.
+//
+// The code corresponds to:
+//
+// xgetbv
+//
+TEXT ·xgetbv(SB),NOSPLIT,$0-16
+ MOVL reg+0(FP), CX
+ BYTE $0x0f; BYTE $0x01; BYTE $0xd0;
+ MOVL AX, ret+8(FP)
+ MOVL DX, ret+12(FP)
+ RET
+
+// wrmsr writes to a control register.
+//
+// The code corresponds to:
+//
+// wrmsr
+//
+TEXT ·wrmsr(SB),NOSPLIT,$0-16
+ MOVL reg+0(FP), CX
+ MOVL value+8(FP), AX
+ MOVL value+12(FP), DX
+ BYTE $0x0f; BYTE $0x30;
+ RET
+
+// rdmsr reads a control register.
+//
+// The code corresponds to:
+//
+// rdmsr
+//
+TEXT ·rdmsr(SB),NOSPLIT,$0-16
+ MOVL reg+0(FP), CX
+ BYTE $0x0f; BYTE $0x32;
+ MOVL AX, ret+8(FP)
+ MOVL DX, ret+12(FP)
+ RET
diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/sentry/platform/ring0/pagetables/allocator.go
new file mode 100644
index 000000000..23fd5c352
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/allocator.go
@@ -0,0 +1,122 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+// Allocator is used to allocate and map PTEs.
+//
+// Note that allocators may be called concurrently.
+type Allocator interface {
+ // NewPTEs returns a new set of PTEs and their physical address.
+ NewPTEs() *PTEs
+
+ // PhysicalFor gives the physical address for a set of PTEs.
+ PhysicalFor(ptes *PTEs) uintptr
+
+ // LookupPTEs looks up PTEs by physical address.
+ LookupPTEs(physical uintptr) *PTEs
+
+ // FreePTEs marks a set of PTEs a freed, although they may not be available
+ // for use again until Recycle is called, below.
+ FreePTEs(ptes *PTEs)
+
+ // Recycle makes freed PTEs available for use again.
+ Recycle()
+}
+
+// RuntimeAllocator is a trivial allocator.
+type RuntimeAllocator struct {
+ // used is the set of PTEs that have been allocated. This includes any
+ // PTEs that may be in the pool below. PTEs are only freed from this
+ // map by the Drain call.
+ //
+ // This exists to prevent accidental garbage collection.
+ used map[*PTEs]struct{}
+
+ // pool is the set of free-to-use PTEs.
+ pool []*PTEs
+
+ // freed is the set of recently-freed PTEs.
+ freed []*PTEs
+}
+
+// NewRuntimeAllocator returns an allocator that uses runtime allocation.
+func NewRuntimeAllocator() *RuntimeAllocator {
+ return &RuntimeAllocator{
+ used: make(map[*PTEs]struct{}),
+ }
+}
+
+// Recycle returns freed pages to the pool.
+func (r *RuntimeAllocator) Recycle() {
+ r.pool = append(r.pool, r.freed...)
+ r.freed = r.freed[:0]
+}
+
+// Drain empties the pool.
+func (r *RuntimeAllocator) Drain() {
+ r.Recycle()
+ for i, ptes := range r.pool {
+ // Zap the entry in the underlying array to ensure that it can
+ // be properly garbage collected.
+ r.pool[i] = nil
+ // Similarly, free the reference held by the used map (these
+ // also apply for the pool entries).
+ delete(r.used, ptes)
+ }
+ r.pool = r.pool[:0]
+}
+
+// NewPTEs implements Allocator.NewPTEs.
+//
+// Note that the "physical" address here is actually the virtual address of the
+// PTEs structure. The entries are tracked only to avoid garbage collection.
+//
+// This is guaranteed not to split as long as the pool is sufficiently full.
+//
+//go:nosplit
+func (r *RuntimeAllocator) NewPTEs() *PTEs {
+ // Pull from the pool if we can.
+ if len(r.pool) > 0 {
+ ptes := r.pool[len(r.pool)-1]
+ r.pool = r.pool[:len(r.pool)-1]
+ return ptes
+ }
+
+ // Allocate a new entry.
+ ptes := newAlignedPTEs()
+ r.used[ptes] = struct{}{}
+ return ptes
+}
+
+// PhysicalFor returns the physical address for the given PTEs.
+//
+//go:nosplit
+func (r *RuntimeAllocator) PhysicalFor(ptes *PTEs) uintptr {
+ return physicalFor(ptes)
+}
+
+// LookupPTEs implements Allocator.LookupPTEs.
+//
+//go:nosplit
+func (r *RuntimeAllocator) LookupPTEs(physical uintptr) *PTEs {
+ return fromPhysical(physical)
+}
+
+// FreePTEs implements Allocator.FreePTEs.
+//
+//go:nosplit
+func (r *RuntimeAllocator) FreePTEs(ptes *PTEs) {
+ r.freed = append(r.freed, ptes)
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go b/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go
new file mode 100644
index 000000000..1b996b4e2
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+import (
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// newAlignedPTEs returns a set of aligned PTEs.
+func newAlignedPTEs() *PTEs {
+ ptes := new(PTEs)
+ offset := physicalFor(ptes) & (usermem.PageSize - 1)
+ if offset == 0 {
+ // Already aligned.
+ return ptes
+ }
+
+ // Need to force an aligned allocation.
+ unaligned := make([]byte, (2*usermem.PageSize)-1)
+ offset = uintptr(unsafe.Pointer(&unaligned[0])) & (usermem.PageSize - 1)
+ if offset != 0 {
+ offset = usermem.PageSize - offset
+ }
+ return (*PTEs)(unsafe.Pointer(&unaligned[offset]))
+}
+
+// physicalFor returns the "physical" address for PTEs.
+//
+//go:nosplit
+func physicalFor(ptes *PTEs) uintptr {
+ return uintptr(unsafe.Pointer(ptes))
+}
+
+// fromPhysical returns the PTEs from the "physical" address.
+//
+//go:nosplit
+func fromPhysical(physical uintptr) *PTEs {
+ return (*PTEs)(unsafe.Pointer(physical))
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
new file mode 100644
index 000000000..e5dcaada7
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -0,0 +1,221 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pagetables provides a generic implementation of pagetables.
+//
+// The core functions must be safe to call from a nosplit context. Furthermore,
+// this pagetables implementation goes to lengths to ensure that all functions
+// are free from runtime allocation. Calls to NewPTEs/FreePTEs may be made
+// during walks, but these can be cached elsewhere if required.
+package pagetables
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// PageTables is a set of page tables.
+type PageTables struct {
+ // Allocator is used to allocate nodes.
+ Allocator Allocator
+
+ // root is the pagetable root.
+ root *PTEs
+
+ // rootPhysical is the cached physical address of the root.
+ //
+ // This is saved only to prevent constant translation.
+ rootPhysical uintptr
+
+ // archPageTables includes architecture-specific features.
+ archPageTables
+}
+
+// New returns new PageTables.
+func New(a Allocator) *PageTables {
+ p := new(PageTables)
+ p.Init(a)
+ return p
+}
+
+// Init initializes a set of PageTables.
+//
+//go:nosplit
+func (p *PageTables) Init(allocator Allocator) {
+ p.Allocator = allocator
+ p.root = p.Allocator.NewPTEs()
+ p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+}
+
+// mapVisitor is used for map.
+type mapVisitor struct {
+ target uintptr // Input.
+ physical uintptr // Input.
+ opts MapOpts // Input.
+ prev bool // Output.
+}
+
+// visit is used for map.
+//
+//go:nosplit
+func (v *mapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
+ p := v.physical + (start - uintptr(v.target))
+ if pte.Valid() && (pte.Address() != p || pte.Opts() != v.opts) {
+ v.prev = true
+ }
+ if p&align != 0 {
+ // We will install entries at a smaller granulaity if we don't
+ // install a valid entry here, however we must zap any existing
+ // entry to ensure this happens.
+ pte.Clear()
+ return
+ }
+ pte.Set(p, v.opts)
+}
+
+//go:nosplit
+func (*mapVisitor) requiresAlloc() bool { return true }
+
+//go:nosplit
+func (*mapVisitor) requiresSplit() bool { return true }
+
+// Map installs a mapping with the given physical address.
+//
+// True is returned iff there was a previous mapping in the range.
+//
+// Precondition: addr & length must be page-aligned, their sum must not overflow.
+//
+//go:nosplit
+func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
+ if !opts.AccessType.Any() {
+ return p.Unmap(addr, length)
+ }
+ w := mapWalker{
+ pageTables: p,
+ visitor: mapVisitor{
+ target: uintptr(addr),
+ physical: physical,
+ opts: opts,
+ },
+ }
+ w.iterateRange(uintptr(addr), uintptr(addr)+length)
+ return w.visitor.prev
+}
+
+// unmapVisitor is used for unmap.
+type unmapVisitor struct {
+ count int
+}
+
+//go:nosplit
+func (*unmapVisitor) requiresAlloc() bool { return false }
+
+//go:nosplit
+func (*unmapVisitor) requiresSplit() bool { return true }
+
+// visit unmaps the given entry.
+//
+//go:nosplit
+func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
+ pte.Clear()
+ v.count++
+}
+
+// Unmap unmaps the given range.
+//
+// True is returned iff there was a previous mapping in the range.
+//
+// Precondition: addr & length must be page-aligned.
+//
+//go:nosplit
+func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
+ w := unmapWalker{
+ pageTables: p,
+ visitor: unmapVisitor{
+ count: 0,
+ },
+ }
+ w.iterateRange(uintptr(addr), uintptr(addr)+length)
+ return w.visitor.count > 0
+}
+
+// emptyVisitor is used for emptiness checks.
+type emptyVisitor struct {
+ count int
+}
+
+//go:nosplit
+func (*emptyVisitor) requiresAlloc() bool { return false }
+
+//go:nosplit
+func (*emptyVisitor) requiresSplit() bool { return false }
+
+// visit unmaps the given entry.
+//
+//go:nosplit
+func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) {
+ v.count++
+}
+
+// IsEmpty checks if the given range is empty.
+//
+// Precondition: addr & length must be page-aligned.
+//
+//go:nosplit
+func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool {
+ w := emptyWalker{
+ pageTables: p,
+ }
+ w.iterateRange(uintptr(addr), uintptr(addr)+length)
+ return w.visitor.count == 0
+}
+
+// lookupVisitor is used for lookup.
+type lookupVisitor struct {
+ target uintptr // Input.
+ physical uintptr // Output.
+ opts MapOpts // Output.
+}
+
+// visit matches the given address.
+//
+//go:nosplit
+func (v *lookupVisitor) visit(start uintptr, pte *PTE, align uintptr) {
+ if !pte.Valid() {
+ return
+ }
+ v.physical = pte.Address() + (start - uintptr(v.target))
+ v.opts = pte.Opts()
+}
+
+//go:nosplit
+func (*lookupVisitor) requiresAlloc() bool { return false }
+
+//go:nosplit
+func (*lookupVisitor) requiresSplit() bool { return false }
+
+// Lookup returns the physical address for the given virtual address.
+//
+//go:nosplit
+func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) {
+ mask := uintptr(usermem.PageSize - 1)
+ offset := uintptr(addr) & mask
+ w := lookupWalker{
+ pageTables: p,
+ visitor: lookupVisitor{
+ target: uintptr(addr &^ usermem.Addr(mask)),
+ },
+ }
+ w.iterateRange(uintptr(addr), uintptr(addr)+1)
+ return w.visitor.physical + offset, w.visitor.opts
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
new file mode 100644
index 000000000..7aa6c524e
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+// Address constraints.
+//
+// The lowerTop and upperBottom currently apply to four-level pagetables;
+// additional refactoring would be necessary to support five-level pagetables.
+const (
+ lowerTop = 0x00007fffffffffff
+ upperBottom = 0xffff800000000000
+
+ pteShift = 12
+ pmdShift = 21
+ pudShift = 30
+ pgdShift = 39
+
+ pteMask = 0x1ff << pteShift
+ pmdMask = 0x1ff << pmdShift
+ pudMask = 0x1ff << pudShift
+ pgdMask = 0x1ff << pgdShift
+
+ pteSize = 1 << pteShift
+ pmdSize = 1 << pmdShift
+ pudSize = 1 << pudShift
+ pgdSize = 1 << pgdShift
+
+ executeDisable = 1 << 63
+ entriesPerPage = 512
+)
+
+// PTEs is a collection of entries.
+type PTEs [entriesPerPage]PTE
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go b/pkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go
new file mode 100755
index 000000000..ac1ccf3d3
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package pagetables
+
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
new file mode 100644
index 000000000..ff427fbe9
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
@@ -0,0 +1,180 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build i386 amd64
+
+package pagetables
+
+import (
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// archPageTables is architecture-specific data.
+type archPageTables struct {
+ // pcid is the value assigned by PCIDs.Assign.
+ //
+ // Note that zero is a valid PCID.
+ pcid uint16
+}
+
+// CR3 returns the CR3 value for these tables.
+//
+// This may be called in interrupt contexts. A PCID of zero always implies a
+// flush and should be passed when PCIDs are not enabled. See pcids_x86.go for
+// more information.
+//
+//go:nosplit
+func (p *PageTables) CR3(noFlush bool, pcid uint16) uint64 {
+ // Bit 63 is set to avoid flushing the PCID (per SDM 4.10.4.1).
+ const noFlushBit uint64 = 0x8000000000000000
+ if noFlush && pcid != 0 {
+ return noFlushBit | uint64(p.rootPhysical) | uint64(pcid)
+ }
+ return uint64(p.rootPhysical) | uint64(pcid)
+}
+
+// Bits in page table entries.
+const (
+ present = 0x001
+ writable = 0x002
+ user = 0x004
+ writeThrough = 0x008
+ cacheDisable = 0x010
+ accessed = 0x020
+ dirty = 0x040
+ super = 0x080
+ global = 0x100
+ optionMask = executeDisable | 0xfff
+)
+
+// MapOpts are x86 options.
+type MapOpts struct {
+ // AccessType defines permissions.
+ AccessType usermem.AccessType
+
+ // Global indicates the page is globally accessible.
+ Global bool
+
+ // User indicates the page is a user page.
+ User bool
+}
+
+// PTE is a page table entry.
+type PTE uintptr
+
+// Clear clears this PTE, including super page information.
+//
+//go:nosplit
+func (p *PTE) Clear() {
+ atomic.StoreUintptr((*uintptr)(p), 0)
+}
+
+// Valid returns true iff this entry is valid.
+//
+//go:nosplit
+func (p *PTE) Valid() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&present != 0
+}
+
+// Opts returns the PTE options.
+//
+// These are all options except Valid and Super.
+//
+//go:nosplit
+func (p *PTE) Opts() MapOpts {
+ v := atomic.LoadUintptr((*uintptr)(p))
+ return MapOpts{
+ AccessType: usermem.AccessType{
+ Read: v&present != 0,
+ Write: v&writable != 0,
+ Execute: v&executeDisable == 0,
+ },
+ Global: v&global != 0,
+ User: v&user != 0,
+ }
+}
+
+// SetSuper sets this page as a super page.
+//
+// The page must not be valid or a panic will result.
+//
+//go:nosplit
+func (p *PTE) SetSuper() {
+ if p.Valid() {
+ // This is not allowed.
+ panic("SetSuper called on valid page!")
+ }
+ atomic.StoreUintptr((*uintptr)(p), super)
+}
+
+// IsSuper returns true iff this page is a super page.
+//
+//go:nosplit
+func (p *PTE) IsSuper() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&super != 0
+}
+
+// Set sets this PTE value.
+//
+// This does not change the super page property.
+//
+//go:nosplit
+func (p *PTE) Set(addr uintptr, opts MapOpts) {
+ if !opts.AccessType.Any() {
+ p.Clear()
+ return
+ }
+ v := (addr &^ optionMask) | present | accessed
+ if opts.User {
+ v |= user
+ }
+ if opts.Global {
+ v |= global
+ }
+ if !opts.AccessType.Execute {
+ v |= executeDisable
+ }
+ if opts.AccessType.Write {
+ v |= writable | dirty
+ }
+ if p.IsSuper() {
+ // Note that this is inherited from the previous instance. Set
+ // does not change the value of Super. See above.
+ v |= super
+ }
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// setPageTable sets this PTE value and forces the write bit and super bit to
+// be cleared. This is used explicitly for breaking super pages.
+//
+//go:nosplit
+func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) {
+ addr := pt.Allocator.PhysicalFor(ptes)
+ if addr&^optionMask != addr {
+ // This should never happen.
+ panic("unaligned physical address!")
+ }
+ v := addr | present | user | writable | accessed | dirty
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// Address extracts the address. This should only be used if Valid returns true.
+//
+//go:nosplit
+func (p *PTE) Address() uintptr {
+ return atomic.LoadUintptr((*uintptr)(p)) &^ optionMask
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
new file mode 100644
index 000000000..0f029f25d
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
@@ -0,0 +1,109 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build i386 amd64
+
+package pagetables
+
+import (
+ "sync"
+)
+
+// limitPCID is the number of valid PCIDs.
+const limitPCID = 4096
+
+// PCIDs is a simple PCID database.
+//
+// This is not protected by locks and is thus suitable for use only with a
+// single CPU at a time.
+type PCIDs struct {
+ // mu protects below.
+ mu sync.Mutex
+
+ // cache are the assigned page tables.
+ cache map[*PageTables]uint16
+
+ // avail are available PCIDs.
+ avail []uint16
+}
+
+// NewPCIDs returns a new PCID database.
+//
+// start is the first index to assign. Typically this will be one, as the zero
+// pcid will always be flushed on transition (see pagetables_x86.go). This may
+// be more than one if specific PCIDs are reserved.
+//
+// Nil is returned iff the start and size are out of range.
+func NewPCIDs(start, size uint16) *PCIDs {
+ if start+uint16(size) >= limitPCID {
+ return nil // See comment.
+ }
+ p := &PCIDs{
+ cache: make(map[*PageTables]uint16),
+ }
+ for pcid := start; pcid < start+size; pcid++ {
+ p.avail = append(p.avail, pcid)
+ }
+ return p
+}
+
+// Assign assigns a PCID to the given PageTables.
+//
+// This may overwrite any previous assignment provided. If this in the case,
+// true is returned to indicate that the PCID should be flushed.
+func (p *PCIDs) Assign(pt *PageTables) (uint16, bool) {
+ p.mu.Lock()
+ if pcid, ok := p.cache[pt]; ok {
+ p.mu.Unlock()
+ return pcid, false // No flush.
+ }
+
+ // Is there something available?
+ if len(p.avail) > 0 {
+ pcid := p.avail[len(p.avail)-1]
+ p.avail = p.avail[:len(p.avail)-1]
+ p.cache[pt] = pcid
+
+ // We need to flush because while this is in the available
+ // pool, it may have been used previously.
+ p.mu.Unlock()
+ return pcid, true
+ }
+
+ // Evict an existing table.
+ for old, pcid := range p.cache {
+ delete(p.cache, old)
+ p.cache[pt] = pcid
+
+ // A flush is definitely required in this case, these page
+ // tables may still be active. (They will just be assigned some
+ // other PCID if and when they hit the given CPU again.)
+ p.mu.Unlock()
+ return pcid, true
+ }
+
+ // No PCID.
+ p.mu.Unlock()
+ return 0, false
+}
+
+// Drop drops references to a set of page tables.
+func (p *PCIDs) Drop(pt *PageTables) {
+ p.mu.Lock()
+ if pcid, ok := p.cache[pt]; ok {
+ delete(p.cache, pt)
+ p.avail = append(p.avail, pcid)
+ }
+ p.mu.Unlock()
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_empty.go b/pkg/sentry/platform/ring0/pagetables/walker_empty.go
new file mode 100755
index 000000000..417784e17
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/walker_empty.go
@@ -0,0 +1,255 @@
+package pagetables
+
+// Walker walks page tables.
+type emptyWalker struct {
+ // pageTables are the tables to walk.
+ pageTables *PageTables
+
+ // Visitor is the set of arguments.
+ visitor emptyVisitor
+}
+
+// iterateRange iterates over all appropriate levels of page tables for the given range.
+//
+// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The
+// exception is super pages. If a valid super page (huge or jumbo) cannot be
+// installed, then the walk will continue to individual entries.
+//
+// This algorithm will attempt to maximize the use of super pages whenever
+// possible. Whether a super page is provided will be clear through the range
+// provided in the callback.
+//
+// Note that if requiresAlloc is true, then no gaps will be present. However,
+// if alloc is not set, then the iteration will likely be full of gaps.
+//
+// Note that this function should generally be avoided in favor of Map, Unmap,
+// etc. when not necessary.
+//
+// Precondition: start must be page-aligned.
+//
+// Precondition: start must be less than end.
+//
+// Precondition: If requiresAlloc is true, then start and end should not span
+// non-canonical ranges. If they do, a panic will result.
+//
+//go:nosplit
+func (w *emptyWalker) iterateRange(start, end uintptr) {
+ if start%pteSize != 0 {
+ panic("unaligned start")
+ }
+ if end < start {
+ panic("start > end")
+ }
+ if start < lowerTop {
+ if end <= lowerTop {
+ w.iterateRangeCanonical(start, end)
+ } else if end > lowerTop && end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else if start < upperBottom {
+ if end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else {
+ w.iterateRangeCanonical(start, end)
+ }
+}
+
+// next returns the next address quantized by the given size.
+//
+//go:nosplit
+func emptynext(start uintptr, size uintptr) uintptr {
+ start &= ^(size - 1)
+ start += size
+ return start
+}
+
+// iterateRangeCanonical walks a canonical range.
+//
+//go:nosplit
+func (w *emptyWalker) iterateRangeCanonical(start, end uintptr) {
+ for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
+ var (
+ pgdEntry = &w.pageTables.root[pgdIndex]
+ pudEntries *PTEs
+ )
+ if !pgdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ start = emptynext(start, pgdSize)
+ continue
+ }
+
+ pudEntries = w.pageTables.Allocator.NewPTEs()
+ pgdEntry.setPageTable(w.pageTables, pudEntries)
+ } else {
+ pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address())
+ }
+
+ clearPUDEntries := uint16(0)
+
+ for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ {
+ var (
+ pudEntry = &pudEntries[pudIndex]
+ pmdEntries *PTEs
+ )
+ if !pudEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ clearPUDEntries++
+ start = emptynext(start, pudSize)
+ continue
+ }
+
+ if start&(pudSize-1) == 0 && end-start >= pudSize {
+ pudEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+ if pudEntry.Valid() {
+ start = emptynext(start, pudSize)
+ continue
+ }
+ }
+
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+
+ } else if pudEntry.IsSuper() {
+
+ if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < emptynext(start, pudSize)) {
+
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pmdEntries[index].SetSuper()
+ pmdEntries[index].Set(
+ pudEntry.Address()+(pmdSize*uintptr(index)),
+ pudEntry.Opts())
+ }
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+ } else {
+
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+
+ if !pudEntry.Valid() {
+ clearPUDEntries++
+ }
+
+ start = emptynext(start, pudSize)
+ continue
+ }
+ } else {
+ pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address())
+ }
+
+ clearPMDEntries := uint16(0)
+
+ for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ {
+ var (
+ pmdEntry = &pmdEntries[pmdIndex]
+ pteEntries *PTEs
+ )
+ if !pmdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ clearPMDEntries++
+ start = emptynext(start, pmdSize)
+ continue
+ }
+
+ if start&(pmdSize-1) == 0 && end-start >= pmdSize {
+ pmdEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+ if pmdEntry.Valid() {
+ start = emptynext(start, pmdSize)
+ continue
+ }
+ }
+
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+
+ } else if pmdEntry.IsSuper() {
+
+ if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < emptynext(start, pmdSize)) {
+
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pteEntries[index].Set(
+ pmdEntry.Address()+(pteSize*uintptr(index)),
+ pmdEntry.Opts())
+ }
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+ } else {
+
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+
+ if !pmdEntry.Valid() {
+ clearPMDEntries++
+ }
+
+ start = emptynext(start, pmdSize)
+ continue
+ }
+ } else {
+ pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address())
+ }
+
+ clearPTEEntries := uint16(0)
+
+ for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ {
+ var (
+ pteEntry = &pteEntries[pteIndex]
+ )
+ if !pteEntry.Valid() && !w.visitor.requiresAlloc() {
+ clearPTEEntries++
+ start += pteSize
+ continue
+ }
+
+ w.visitor.visit(uintptr(start), pteEntry, pteSize-1)
+ if !pteEntry.Valid() {
+ if w.visitor.requiresAlloc() {
+ panic("PTE not set after iteration with requiresAlloc!")
+ }
+ clearPTEEntries++
+ }
+
+ start += pteSize
+ continue
+ }
+
+ if clearPTEEntries == entriesPerPage {
+ pmdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pteEntries)
+ clearPMDEntries++
+ }
+ }
+
+ if clearPMDEntries == entriesPerPage {
+ pudEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pmdEntries)
+ clearPUDEntries++
+ }
+ }
+
+ if clearPUDEntries == entriesPerPage {
+ pgdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pudEntries)
+ }
+ }
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_lookup.go b/pkg/sentry/platform/ring0/pagetables/walker_lookup.go
new file mode 100755
index 000000000..906c9c50f
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/walker_lookup.go
@@ -0,0 +1,255 @@
+package pagetables
+
+// Walker walks page tables.
+type lookupWalker struct {
+ // pageTables are the tables to walk.
+ pageTables *PageTables
+
+ // Visitor is the set of arguments.
+ visitor lookupVisitor
+}
+
+// iterateRange iterates over all appropriate levels of page tables for the given range.
+//
+// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The
+// exception is super pages. If a valid super page (huge or jumbo) cannot be
+// installed, then the walk will continue to individual entries.
+//
+// This algorithm will attempt to maximize the use of super pages whenever
+// possible. Whether a super page is provided will be clear through the range
+// provided in the callback.
+//
+// Note that if requiresAlloc is true, then no gaps will be present. However,
+// if alloc is not set, then the iteration will likely be full of gaps.
+//
+// Note that this function should generally be avoided in favor of Map, Unmap,
+// etc. when not necessary.
+//
+// Precondition: start must be page-aligned.
+//
+// Precondition: start must be less than end.
+//
+// Precondition: If requiresAlloc is true, then start and end should not span
+// non-canonical ranges. If they do, a panic will result.
+//
+//go:nosplit
+func (w *lookupWalker) iterateRange(start, end uintptr) {
+ if start%pteSize != 0 {
+ panic("unaligned start")
+ }
+ if end < start {
+ panic("start > end")
+ }
+ if start < lowerTop {
+ if end <= lowerTop {
+ w.iterateRangeCanonical(start, end)
+ } else if end > lowerTop && end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else if start < upperBottom {
+ if end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else {
+ w.iterateRangeCanonical(start, end)
+ }
+}
+
+// next returns the next address quantized by the given size.
+//
+//go:nosplit
+func lookupnext(start uintptr, size uintptr) uintptr {
+ start &= ^(size - 1)
+ start += size
+ return start
+}
+
+// iterateRangeCanonical walks a canonical range.
+//
+//go:nosplit
+func (w *lookupWalker) iterateRangeCanonical(start, end uintptr) {
+ for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
+ var (
+ pgdEntry = &w.pageTables.root[pgdIndex]
+ pudEntries *PTEs
+ )
+ if !pgdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ start = lookupnext(start, pgdSize)
+ continue
+ }
+
+ pudEntries = w.pageTables.Allocator.NewPTEs()
+ pgdEntry.setPageTable(w.pageTables, pudEntries)
+ } else {
+ pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address())
+ }
+
+ clearPUDEntries := uint16(0)
+
+ for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ {
+ var (
+ pudEntry = &pudEntries[pudIndex]
+ pmdEntries *PTEs
+ )
+ if !pudEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ clearPUDEntries++
+ start = lookupnext(start, pudSize)
+ continue
+ }
+
+ if start&(pudSize-1) == 0 && end-start >= pudSize {
+ pudEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+ if pudEntry.Valid() {
+ start = lookupnext(start, pudSize)
+ continue
+ }
+ }
+
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+
+ } else if pudEntry.IsSuper() {
+
+ if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < lookupnext(start, pudSize)) {
+
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pmdEntries[index].SetSuper()
+ pmdEntries[index].Set(
+ pudEntry.Address()+(pmdSize*uintptr(index)),
+ pudEntry.Opts())
+ }
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+ } else {
+
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+
+ if !pudEntry.Valid() {
+ clearPUDEntries++
+ }
+
+ start = lookupnext(start, pudSize)
+ continue
+ }
+ } else {
+ pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address())
+ }
+
+ clearPMDEntries := uint16(0)
+
+ for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ {
+ var (
+ pmdEntry = &pmdEntries[pmdIndex]
+ pteEntries *PTEs
+ )
+ if !pmdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ clearPMDEntries++
+ start = lookupnext(start, pmdSize)
+ continue
+ }
+
+ if start&(pmdSize-1) == 0 && end-start >= pmdSize {
+ pmdEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+ if pmdEntry.Valid() {
+ start = lookupnext(start, pmdSize)
+ continue
+ }
+ }
+
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+
+ } else if pmdEntry.IsSuper() {
+
+ if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < lookupnext(start, pmdSize)) {
+
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pteEntries[index].Set(
+ pmdEntry.Address()+(pteSize*uintptr(index)),
+ pmdEntry.Opts())
+ }
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+ } else {
+
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+
+ if !pmdEntry.Valid() {
+ clearPMDEntries++
+ }
+
+ start = lookupnext(start, pmdSize)
+ continue
+ }
+ } else {
+ pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address())
+ }
+
+ clearPTEEntries := uint16(0)
+
+ for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ {
+ var (
+ pteEntry = &pteEntries[pteIndex]
+ )
+ if !pteEntry.Valid() && !w.visitor.requiresAlloc() {
+ clearPTEEntries++
+ start += pteSize
+ continue
+ }
+
+ w.visitor.visit(uintptr(start), pteEntry, pteSize-1)
+ if !pteEntry.Valid() {
+ if w.visitor.requiresAlloc() {
+ panic("PTE not set after iteration with requiresAlloc!")
+ }
+ clearPTEEntries++
+ }
+
+ start += pteSize
+ continue
+ }
+
+ if clearPTEEntries == entriesPerPage {
+ pmdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pteEntries)
+ clearPMDEntries++
+ }
+ }
+
+ if clearPMDEntries == entriesPerPage {
+ pudEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pmdEntries)
+ clearPUDEntries++
+ }
+ }
+
+ if clearPUDEntries == entriesPerPage {
+ pgdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pudEntries)
+ }
+ }
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_map.go b/pkg/sentry/platform/ring0/pagetables/walker_map.go
new file mode 100755
index 000000000..61ee3c825
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/walker_map.go
@@ -0,0 +1,255 @@
+package pagetables
+
+// Walker walks page tables.
+type mapWalker struct {
+ // pageTables are the tables to walk.
+ pageTables *PageTables
+
+ // Visitor is the set of arguments.
+ visitor mapVisitor
+}
+
+// iterateRange iterates over all appropriate levels of page tables for the given range.
+//
+// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The
+// exception is super pages. If a valid super page (huge or jumbo) cannot be
+// installed, then the walk will continue to individual entries.
+//
+// This algorithm will attempt to maximize the use of super pages whenever
+// possible. Whether a super page is provided will be clear through the range
+// provided in the callback.
+//
+// Note that if requiresAlloc is true, then no gaps will be present. However,
+// if alloc is not set, then the iteration will likely be full of gaps.
+//
+// Note that this function should generally be avoided in favor of Map, Unmap,
+// etc. when not necessary.
+//
+// Precondition: start must be page-aligned.
+//
+// Precondition: start must be less than end.
+//
+// Precondition: If requiresAlloc is true, then start and end should not span
+// non-canonical ranges. If they do, a panic will result.
+//
+//go:nosplit
+func (w *mapWalker) iterateRange(start, end uintptr) {
+ if start%pteSize != 0 {
+ panic("unaligned start")
+ }
+ if end < start {
+ panic("start > end")
+ }
+ if start < lowerTop {
+ if end <= lowerTop {
+ w.iterateRangeCanonical(start, end)
+ } else if end > lowerTop && end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else if start < upperBottom {
+ if end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else {
+ w.iterateRangeCanonical(start, end)
+ }
+}
+
+// next returns the next address quantized by the given size.
+//
+//go:nosplit
+func mapnext(start uintptr, size uintptr) uintptr {
+ start &= ^(size - 1)
+ start += size
+ return start
+}
+
+// iterateRangeCanonical walks a canonical range.
+//
+//go:nosplit
+func (w *mapWalker) iterateRangeCanonical(start, end uintptr) {
+ for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
+ var (
+ pgdEntry = &w.pageTables.root[pgdIndex]
+ pudEntries *PTEs
+ )
+ if !pgdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ start = mapnext(start, pgdSize)
+ continue
+ }
+
+ pudEntries = w.pageTables.Allocator.NewPTEs()
+ pgdEntry.setPageTable(w.pageTables, pudEntries)
+ } else {
+ pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address())
+ }
+
+ clearPUDEntries := uint16(0)
+
+ for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ {
+ var (
+ pudEntry = &pudEntries[pudIndex]
+ pmdEntries *PTEs
+ )
+ if !pudEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ clearPUDEntries++
+ start = mapnext(start, pudSize)
+ continue
+ }
+
+ if start&(pudSize-1) == 0 && end-start >= pudSize {
+ pudEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+ if pudEntry.Valid() {
+ start = mapnext(start, pudSize)
+ continue
+ }
+ }
+
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+
+ } else if pudEntry.IsSuper() {
+
+ if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < mapnext(start, pudSize)) {
+
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pmdEntries[index].SetSuper()
+ pmdEntries[index].Set(
+ pudEntry.Address()+(pmdSize*uintptr(index)),
+ pudEntry.Opts())
+ }
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+ } else {
+
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+
+ if !pudEntry.Valid() {
+ clearPUDEntries++
+ }
+
+ start = mapnext(start, pudSize)
+ continue
+ }
+ } else {
+ pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address())
+ }
+
+ clearPMDEntries := uint16(0)
+
+ for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ {
+ var (
+ pmdEntry = &pmdEntries[pmdIndex]
+ pteEntries *PTEs
+ )
+ if !pmdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ clearPMDEntries++
+ start = mapnext(start, pmdSize)
+ continue
+ }
+
+ if start&(pmdSize-1) == 0 && end-start >= pmdSize {
+ pmdEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+ if pmdEntry.Valid() {
+ start = mapnext(start, pmdSize)
+ continue
+ }
+ }
+
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+
+ } else if pmdEntry.IsSuper() {
+
+ if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < mapnext(start, pmdSize)) {
+
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pteEntries[index].Set(
+ pmdEntry.Address()+(pteSize*uintptr(index)),
+ pmdEntry.Opts())
+ }
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+ } else {
+
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+
+ if !pmdEntry.Valid() {
+ clearPMDEntries++
+ }
+
+ start = mapnext(start, pmdSize)
+ continue
+ }
+ } else {
+ pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address())
+ }
+
+ clearPTEEntries := uint16(0)
+
+ for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ {
+ var (
+ pteEntry = &pteEntries[pteIndex]
+ )
+ if !pteEntry.Valid() && !w.visitor.requiresAlloc() {
+ clearPTEEntries++
+ start += pteSize
+ continue
+ }
+
+ w.visitor.visit(uintptr(start), pteEntry, pteSize-1)
+ if !pteEntry.Valid() {
+ if w.visitor.requiresAlloc() {
+ panic("PTE not set after iteration with requiresAlloc!")
+ }
+ clearPTEEntries++
+ }
+
+ start += pteSize
+ continue
+ }
+
+ if clearPTEEntries == entriesPerPage {
+ pmdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pteEntries)
+ clearPMDEntries++
+ }
+ }
+
+ if clearPMDEntries == entriesPerPage {
+ pudEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pmdEntries)
+ clearPUDEntries++
+ }
+ }
+
+ if clearPUDEntries == entriesPerPage {
+ pgdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pudEntries)
+ }
+ }
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_unmap.go b/pkg/sentry/platform/ring0/pagetables/walker_unmap.go
new file mode 100755
index 000000000..be2aa0ce4
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/walker_unmap.go
@@ -0,0 +1,255 @@
+package pagetables
+
+// Walker walks page tables.
+type unmapWalker struct {
+ // pageTables are the tables to walk.
+ pageTables *PageTables
+
+ // Visitor is the set of arguments.
+ visitor unmapVisitor
+}
+
+// iterateRange iterates over all appropriate levels of page tables for the given range.
+//
+// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The
+// exception is super pages. If a valid super page (huge or jumbo) cannot be
+// installed, then the walk will continue to individual entries.
+//
+// This algorithm will attempt to maximize the use of super pages whenever
+// possible. Whether a super page is provided will be clear through the range
+// provided in the callback.
+//
+// Note that if requiresAlloc is true, then no gaps will be present. However,
+// if alloc is not set, then the iteration will likely be full of gaps.
+//
+// Note that this function should generally be avoided in favor of Map, Unmap,
+// etc. when not necessary.
+//
+// Precondition: start must be page-aligned.
+//
+// Precondition: start must be less than end.
+//
+// Precondition: If requiresAlloc is true, then start and end should not span
+// non-canonical ranges. If they do, a panic will result.
+//
+//go:nosplit
+func (w *unmapWalker) iterateRange(start, end uintptr) {
+ if start%pteSize != 0 {
+ panic("unaligned start")
+ }
+ if end < start {
+ panic("start > end")
+ }
+ if start < lowerTop {
+ if end <= lowerTop {
+ w.iterateRangeCanonical(start, end)
+ } else if end > lowerTop && end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else if start < upperBottom {
+ if end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else {
+ w.iterateRangeCanonical(start, end)
+ }
+}
+
+// next returns the next address quantized by the given size.
+//
+//go:nosplit
+func unmapnext(start uintptr, size uintptr) uintptr {
+ start &= ^(size - 1)
+ start += size
+ return start
+}
+
+// iterateRangeCanonical walks a canonical range.
+//
+//go:nosplit
+func (w *unmapWalker) iterateRangeCanonical(start, end uintptr) {
+ for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
+ var (
+ pgdEntry = &w.pageTables.root[pgdIndex]
+ pudEntries *PTEs
+ )
+ if !pgdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ start = unmapnext(start, pgdSize)
+ continue
+ }
+
+ pudEntries = w.pageTables.Allocator.NewPTEs()
+ pgdEntry.setPageTable(w.pageTables, pudEntries)
+ } else {
+ pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address())
+ }
+
+ clearPUDEntries := uint16(0)
+
+ for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ {
+ var (
+ pudEntry = &pudEntries[pudIndex]
+ pmdEntries *PTEs
+ )
+ if !pudEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ clearPUDEntries++
+ start = unmapnext(start, pudSize)
+ continue
+ }
+
+ if start&(pudSize-1) == 0 && end-start >= pudSize {
+ pudEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+ if pudEntry.Valid() {
+ start = unmapnext(start, pudSize)
+ continue
+ }
+ }
+
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+
+ } else if pudEntry.IsSuper() {
+
+ if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < unmapnext(start, pudSize)) {
+
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pmdEntries[index].SetSuper()
+ pmdEntries[index].Set(
+ pudEntry.Address()+(pmdSize*uintptr(index)),
+ pudEntry.Opts())
+ }
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+ } else {
+
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+
+ if !pudEntry.Valid() {
+ clearPUDEntries++
+ }
+
+ start = unmapnext(start, pudSize)
+ continue
+ }
+ } else {
+ pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address())
+ }
+
+ clearPMDEntries := uint16(0)
+
+ for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ {
+ var (
+ pmdEntry = &pmdEntries[pmdIndex]
+ pteEntries *PTEs
+ )
+ if !pmdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+
+ clearPMDEntries++
+ start = unmapnext(start, pmdSize)
+ continue
+ }
+
+ if start&(pmdSize-1) == 0 && end-start >= pmdSize {
+ pmdEntry.SetSuper()
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+ if pmdEntry.Valid() {
+ start = unmapnext(start, pmdSize)
+ continue
+ }
+ }
+
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+
+ } else if pmdEntry.IsSuper() {
+
+ if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < unmapnext(start, pmdSize)) {
+
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pteEntries[index].Set(
+ pmdEntry.Address()+(pteSize*uintptr(index)),
+ pmdEntry.Opts())
+ }
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+ } else {
+
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+
+ if !pmdEntry.Valid() {
+ clearPMDEntries++
+ }
+
+ start = unmapnext(start, pmdSize)
+ continue
+ }
+ } else {
+ pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address())
+ }
+
+ clearPTEEntries := uint16(0)
+
+ for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ {
+ var (
+ pteEntry = &pteEntries[pteIndex]
+ )
+ if !pteEntry.Valid() && !w.visitor.requiresAlloc() {
+ clearPTEEntries++
+ start += pteSize
+ continue
+ }
+
+ w.visitor.visit(uintptr(start), pteEntry, pteSize-1)
+ if !pteEntry.Valid() {
+ if w.visitor.requiresAlloc() {
+ panic("PTE not set after iteration with requiresAlloc!")
+ }
+ clearPTEEntries++
+ }
+
+ start += pteSize
+ continue
+ }
+
+ if clearPTEEntries == entriesPerPage {
+ pmdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pteEntries)
+ clearPMDEntries++
+ }
+ }
+
+ if clearPMDEntries == entriesPerPage {
+ pudEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pmdEntries)
+ clearPUDEntries++
+ }
+ }
+
+ if clearPUDEntries == entriesPerPage {
+ pgdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pudEntries)
+ }
+ }
+}
diff --git a/pkg/sentry/platform/ring0/ring0.go b/pkg/sentry/platform/ring0/ring0.go
new file mode 100644
index 000000000..cdeb1b43a
--- /dev/null
+++ b/pkg/sentry/platform/ring0/ring0.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ring0 provides basic operating system-level stubs.
+package ring0
diff --git a/pkg/sentry/platform/ring0/ring0_state_autogen.go b/pkg/sentry/platform/ring0/ring0_state_autogen.go
new file mode 100755
index 000000000..462f9a446
--- /dev/null
+++ b/pkg/sentry/platform/ring0/ring0_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package ring0
+
diff --git a/pkg/sentry/platform/safecopy/atomic_amd64.s b/pkg/sentry/platform/safecopy/atomic_amd64.s
new file mode 100644
index 000000000..a0cd78f33
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/atomic_amd64.s
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// handleSwapUint32Fault returns the value stored in DI. Control is transferred
+// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in DI.
+//
+// It must have the same frame configuration as swapUint32 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24
+ MOVL DI, sig+20(FP)
+ RET
+
+// swapUint32 atomically stores new into *addr and returns (the previous *addr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: addr must be aligned to a 4-byte boundary.
+//
+//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)
+TEXT ·swapUint32(SB), NOSPLIT, $0-24
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleSwapUint32Fault will store a different value in this address.
+ MOVL $0, sig+20(FP)
+
+ MOVQ addr+0(FP), DI
+ MOVL new+8(FP), AX
+ XCHGL AX, 0(DI)
+ MOVL AX, old+16(FP)
+ RET
+
+// handleSwapUint64Fault returns the value stored in DI. Control is transferred
+// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in DI.
+//
+// It must have the same frame configuration as swapUint64 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28
+ MOVL DI, sig+24(FP)
+ RET
+
+// swapUint64 atomically stores new into *addr and returns (the previous *addr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: addr must be aligned to a 8-byte boundary.
+//
+//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)
+TEXT ·swapUint64(SB), NOSPLIT, $0-28
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleSwapUint64Fault will store a different value in this address.
+ MOVL $0, sig+24(FP)
+
+ MOVQ addr+0(FP), DI
+ MOVQ new+8(FP), AX
+ XCHGQ AX, 0(DI)
+ MOVQ AX, old+16(FP)
+ RET
+
+// handleCompareAndSwapUint32Fault returns the value stored in DI. Control is
+// transferred to it when swapUint64 below receives SIGSEGV or SIGBUS, with the
+// signal number stored in DI.
+//
+// It must have the same frame configuration as compareAndSwapUint32 so that it
+// can undo any potential call frame set up by the assembler.
+TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24
+ MOVL DI, sig+20(FP)
+ RET
+
+// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
+// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is
+// received during the operation, the value of prev is unspecified, and sig is
+// the number of the signal that was received.
+//
+// Preconditions: addr must be aligned to a 4-byte boundary.
+//
+//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)
+TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24
+ // Store 0 as the returned signal number. If we run to completion, this is
+ // the value the caller will see; if a signal is received,
+ // handleCompareAndSwapUint32Fault will store a different value in this
+ // address.
+ MOVL $0, sig+20(FP)
+
+ MOVQ addr+0(FP), DI
+ MOVL old+8(FP), AX
+ MOVL new+12(FP), DX
+ LOCK
+ CMPXCHGL DX, 0(DI)
+ MOVL AX, prev+16(FP)
+ RET
+
+// handleLoadUint32Fault returns the value stored in DI. Control is transferred
+// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in DI.
+//
+// It must have the same frame configuration as loadUint32 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16
+ MOVL DI, sig+12(FP)
+ RET
+
+// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS
+// signal is received, the value returned is unspecified, and sig is the number
+// of the signal that was received.
+//
+// Preconditions: addr must be aligned to a 4-byte boundary.
+//
+//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
+TEXT ·loadUint32(SB), NOSPLIT, $0-16
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleLoadUint32Fault will store a different value in this address.
+ MOVL $0, sig+12(FP)
+
+ MOVQ addr+0(FP), AX
+ MOVL (AX), BX
+ MOVL BX, val+8(FP)
+ RET
diff --git a/pkg/sentry/platform/safecopy/atomic_arm64.s b/pkg/sentry/platform/safecopy/atomic_arm64.s
new file mode 100644
index 000000000..d58ed71f7
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/atomic_arm64.s
@@ -0,0 +1,126 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+// handleSwapUint32Fault returns the value stored in R1. Control is transferred
+// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in R1.
+//
+// It must have the same frame configuration as swapUint32 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24
+ MOVW R1, sig+20(FP)
+ RET
+
+// See the corresponding doc in safecopy_unsafe.go
+//
+// The code is derived from Go source runtime/internal/atomic.Xchg.
+//
+//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)
+TEXT ·swapUint32(SB), NOSPLIT, $0-24
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleSwapUint32Fault will store a different value in this address.
+ MOVW $0, sig+20(FP)
+again:
+ MOVD addr+0(FP), R0
+ MOVW new+8(FP), R1
+ LDAXRW (R0), R2
+ STLXRW R1, (R0), R3
+ CBNZ R3, again
+ MOVW R2, old+16(FP)
+ RET
+
+// handleSwapUint64Fault returns the value stored in R1. Control is transferred
+// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in R1.
+//
+// It must have the same frame configuration as swapUint64 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28
+ MOVW R1, sig+24(FP)
+ RET
+
+// See the corresponding doc in safecopy_unsafe.go
+//
+// The code is derived from Go source runtime/internal/atomic.Xchg64.
+//
+//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)
+TEXT ·swapUint64(SB), NOSPLIT, $0-28
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleSwapUint64Fault will store a different value in this address.
+ MOVW $0, sig+24(FP)
+again:
+ MOVD addr+0(FP), R0
+ MOVD new+8(FP), R1
+ LDAXR (R0), R2
+ STLXR R1, (R0), R3
+ CBNZ R3, again
+ MOVD R2, old+16(FP)
+ RET
+
+// handleCompareAndSwapUint32Fault returns the value stored in R1. Control is
+// transferred to it when compareAndSwapUint32 below receives SIGSEGV or SIGBUS,
+// with the signal number stored in R1.
+//
+// It must have the same frame configuration as compareAndSwapUint32 so that it
+// can undo any potential call frame set up by the assembler.
+TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24
+ MOVW R1, sig+20(FP)
+ RET
+
+// See the corresponding doc in safecopy_unsafe.go
+//
+// The code is derived from Go source runtime/internal/atomic.Cas.
+//
+//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)
+TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24
+ // Store 0 as the returned signal number. If we run to completion, this is
+ // the value the caller will see; if a signal is received,
+ // handleCompareAndSwapUint32Fault will store a different value in this
+ // address.
+ MOVW $0, sig+20(FP)
+
+ MOVD addr+0(FP), R0
+ MOVW old+8(FP), R1
+ MOVW new+12(FP), R2
+again:
+ LDAXRW (R0), R3
+ CMPW R1, R3
+ BNE done
+ STLXRW R2, (R0), R4
+ CBNZ R4, again
+done:
+ MOVW R3, prev+16(FP)
+ RET
+
+// handleLoadUint32Fault returns the value stored in DI. Control is transferred
+// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal
+// number stored in DI.
+//
+// It must have the same frame configuration as loadUint32 so that it can undo
+// any potential call frame set up by the assembler.
+TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16
+ MOVW R1, sig+12(FP)
+ RET
+
+// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS
+// signal is received, the value returned is unspecified, and sig is the number
+// of the signal that was received.
+//
+// Preconditions: addr must be aligned to a 4-byte boundary.
+//
+//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
+TEXT ·loadUint32(SB), NOSPLIT, $0-16
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleLoadUint32Fault will store a different value in this address.
+ MOVW $0, sig+12(FP)
+
+ MOVD addr+0(FP), R0
+ LDARW (R0), R1
+ MOVW R1, val+8(FP)
+ RET
diff --git a/pkg/sentry/platform/safecopy/memclr_amd64.s b/pkg/sentry/platform/safecopy/memclr_amd64.s
new file mode 100644
index 000000000..64cf32f05
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/memclr_amd64.s
@@ -0,0 +1,147 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+// handleMemclrFault returns (the value stored in AX, the value stored in DI).
+// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS,
+// with the faulting address stored in AX and the signal number stored in DI.
+//
+// It must have the same frame configuration as memclr so that it can undo any
+// potential call frame set up by the assembler.
+TEXT handleMemclrFault(SB), NOSPLIT, $0-28
+ MOVQ AX, addr+16(FP)
+ MOVL DI, sig+24(FP)
+ RET
+
+// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS
+// signal is received during the write, it returns the address that caused the
+// fault and the number of the signal that was received. Otherwise, it returns
+// an unspecified address and a signal number of 0.
+//
+// Data is written in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully written.
+//
+// The code is derived from runtime.memclrNoHeapPointers.
+//
+// func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+TEXT ·memclr(SB), NOSPLIT, $0-28
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleMemclrFault will store a different value in this address.
+ MOVL $0, sig+24(FP)
+
+ MOVQ ptr+0(FP), DI
+ MOVQ n+8(FP), BX
+ XORQ AX, AX
+
+ // MOVOU seems always faster than REP STOSQ.
+tail:
+ TESTQ BX, BX
+ JEQ _0
+ CMPQ BX, $2
+ JBE _1or2
+ CMPQ BX, $4
+ JBE _3or4
+ CMPQ BX, $8
+ JB _5through7
+ JE _8
+ CMPQ BX, $16
+ JBE _9through16
+ PXOR X0, X0
+ CMPQ BX, $32
+ JBE _17through32
+ CMPQ BX, $64
+ JBE _33through64
+ CMPQ BX, $128
+ JBE _65through128
+ CMPQ BX, $256
+ JBE _129through256
+ // TODO: use branch table and BSR to make this just a single dispatch
+ // TODO: for really big clears, use MOVNTDQ, even without AVX2.
+
+loop:
+ MOVOU X0, 0(DI)
+ MOVOU X0, 16(DI)
+ MOVOU X0, 32(DI)
+ MOVOU X0, 48(DI)
+ MOVOU X0, 64(DI)
+ MOVOU X0, 80(DI)
+ MOVOU X0, 96(DI)
+ MOVOU X0, 112(DI)
+ MOVOU X0, 128(DI)
+ MOVOU X0, 144(DI)
+ MOVOU X0, 160(DI)
+ MOVOU X0, 176(DI)
+ MOVOU X0, 192(DI)
+ MOVOU X0, 208(DI)
+ MOVOU X0, 224(DI)
+ MOVOU X0, 240(DI)
+ SUBQ $256, BX
+ ADDQ $256, DI
+ CMPQ BX, $256
+ JAE loop
+ JMP tail
+
+_1or2:
+ MOVB AX, (DI)
+ MOVB AX, -1(DI)(BX*1)
+ RET
+_0:
+ RET
+_3or4:
+ MOVW AX, (DI)
+ MOVW AX, -2(DI)(BX*1)
+ RET
+_5through7:
+ MOVL AX, (DI)
+ MOVL AX, -4(DI)(BX*1)
+ RET
+_8:
+ // We need a separate case for 8 to make sure we clear pointers atomically.
+ MOVQ AX, (DI)
+ RET
+_9through16:
+ MOVQ AX, (DI)
+ MOVQ AX, -8(DI)(BX*1)
+ RET
+_17through32:
+ MOVOU X0, (DI)
+ MOVOU X0, -16(DI)(BX*1)
+ RET
+_33through64:
+ MOVOU X0, (DI)
+ MOVOU X0, 16(DI)
+ MOVOU X0, -32(DI)(BX*1)
+ MOVOU X0, -16(DI)(BX*1)
+ RET
+_65through128:
+ MOVOU X0, (DI)
+ MOVOU X0, 16(DI)
+ MOVOU X0, 32(DI)
+ MOVOU X0, 48(DI)
+ MOVOU X0, -64(DI)(BX*1)
+ MOVOU X0, -48(DI)(BX*1)
+ MOVOU X0, -32(DI)(BX*1)
+ MOVOU X0, -16(DI)(BX*1)
+ RET
+_129through256:
+ MOVOU X0, (DI)
+ MOVOU X0, 16(DI)
+ MOVOU X0, 32(DI)
+ MOVOU X0, 48(DI)
+ MOVOU X0, 64(DI)
+ MOVOU X0, 80(DI)
+ MOVOU X0, 96(DI)
+ MOVOU X0, 112(DI)
+ MOVOU X0, -128(DI)(BX*1)
+ MOVOU X0, -112(DI)(BX*1)
+ MOVOU X0, -96(DI)(BX*1)
+ MOVOU X0, -80(DI)(BX*1)
+ MOVOU X0, -64(DI)(BX*1)
+ MOVOU X0, -48(DI)(BX*1)
+ MOVOU X0, -32(DI)(BX*1)
+ MOVOU X0, -16(DI)(BX*1)
+ RET
diff --git a/pkg/sentry/platform/safecopy/memclr_arm64.s b/pkg/sentry/platform/safecopy/memclr_arm64.s
new file mode 100644
index 000000000..7361b9067
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/memclr_arm64.s
@@ -0,0 +1,74 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+// handleMemclrFault returns (the value stored in R0, the value stored in R1).
+// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS,
+// with the faulting address stored in R0 and the signal number stored in R1.
+//
+// It must have the same frame configuration as memclr so that it can undo any
+// potential call frame set up by the assembler.
+TEXT handleMemclrFault(SB), NOSPLIT, $0-28
+ MOVD R0, addr+16(FP)
+ MOVW R1, sig+24(FP)
+ RET
+
+// See the corresponding doc in safecopy_unsafe.go
+//
+// The code is derived from runtime.memclrNoHeapPointers.
+//
+// func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+TEXT ·memclr(SB), NOSPLIT, $0-28
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleMemclrFault will store a different value in this address.
+ MOVW $0, sig+24(FP)
+ MOVD ptr+0(FP), R0
+ MOVD n+8(FP), R1
+
+ // If size is less than 16 bytes, use tail_zero to zero what remains
+ CMP $16, R1
+ BLT tail_zero
+ // Get buffer offset into 16 byte aligned address for better performance
+ ANDS $15, R0, ZR
+ BNE unaligned_to_16
+aligned_to_16:
+ LSR $4, R1, R2
+zero_by_16:
+ STP.P (ZR, ZR), 16(R0) // Store pair with post index.
+ SUBS $1, R2, R2
+ BNE zero_by_16
+ ANDS $15, R1, R1
+ BEQ end
+
+ // Zero buffer with size=R1 < 16
+tail_zero:
+ TBZ $3, R1, tail_zero_4
+ MOVD.P ZR, 8(R0)
+tail_zero_4:
+ TBZ $2, R1, tail_zero_2
+ MOVW.P ZR, 4(R0)
+tail_zero_2:
+ TBZ $1, R1, tail_zero_1
+ MOVH.P ZR, 2(R0)
+tail_zero_1:
+ TBZ $0, R1, end
+ MOVB ZR, (R0)
+end:
+ RET
+
+unaligned_to_16:
+ MOVD R0, R2
+head_loop:
+ MOVBU.P ZR, 1(R0)
+ ANDS $15, R0, ZR
+ BNE head_loop
+ // Adjust length for what remains
+ SUB R2, R0, R3
+ SUB R3, R1
+ // If size is less than 16 bytes, use tail_zero to zero what remains
+ CMP $16, R1
+ BLT tail_zero
+ B aligned_to_16
diff --git a/pkg/sentry/platform/safecopy/memcpy_amd64.s b/pkg/sentry/platform/safecopy/memcpy_amd64.s
new file mode 100644
index 000000000..129691d68
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/memcpy_amd64.s
@@ -0,0 +1,250 @@
+// Copyright © 1994-1999 Lucent Technologies Inc. All rights reserved.
+// Revisions Copyright © 2000-2007 Vita Nuova Holdings Limited (www.vitanuova.com). All rights reserved.
+// Portions Copyright 2009 The Go Authors. All rights reserved.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+#include "textflag.h"
+
+// handleMemcpyFault returns (the value stored in AX, the value stored in DI).
+// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS,
+// with the faulting address stored in AX and the signal number stored in DI.
+//
+// It must have the same frame configuration as memcpy so that it can undo any
+// potential call frame set up by the assembler.
+TEXT handleMemcpyFault(SB), NOSPLIT, $0-36
+ MOVQ AX, addr+24(FP)
+ MOVL DI, sig+32(FP)
+ RET
+
+// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received
+// during the copy, it returns the address that caused the fault and the number
+// of the signal that was received. Otherwise, it returns an unspecified address
+// and a signal number of 0.
+//
+// Data is copied in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully copied.
+//
+// The code is derived from the forward copying part of runtime.memmove.
+//
+// func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+TEXT ·memcpy(SB), NOSPLIT, $0-36
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleMemcpyFault will store a different value in this address.
+ MOVL $0, sig+32(FP)
+
+ MOVQ to+0(FP), DI
+ MOVQ from+8(FP), SI
+ MOVQ n+16(FP), BX
+
+ // REP instructions have a high startup cost, so we handle small sizes
+ // with some straightline code. The REP MOVSQ instruction is really fast
+ // for large sizes. The cutover is approximately 2K.
+tail:
+ // move_129through256 or smaller work whether or not the source and the
+ // destination memory regions overlap because they load all data into
+ // registers before writing it back. move_256through2048 on the other
+ // hand can be used only when the memory regions don't overlap or the copy
+ // direction is forward.
+ TESTQ BX, BX
+ JEQ move_0
+ CMPQ BX, $2
+ JBE move_1or2
+ CMPQ BX, $4
+ JBE move_3or4
+ CMPQ BX, $8
+ JB move_5through7
+ JE move_8
+ CMPQ BX, $16
+ JBE move_9through16
+ CMPQ BX, $32
+ JBE move_17through32
+ CMPQ BX, $64
+ JBE move_33through64
+ CMPQ BX, $128
+ JBE move_65through128
+ CMPQ BX, $256
+ JBE move_129through256
+ // TODO: use branch table and BSR to make this just a single dispatch
+
+/*
+ * forward copy loop
+ */
+ CMPQ BX, $2048
+ JLS move_256through2048
+
+ // Check alignment
+ MOVL SI, AX
+ ORL DI, AX
+ TESTL $7, AX
+ JEQ fwdBy8
+
+ // Do 1 byte at a time
+ MOVQ BX, CX
+ REP; MOVSB
+ RET
+
+fwdBy8:
+ // Do 8 bytes at a time
+ MOVQ BX, CX
+ SHRQ $3, CX
+ ANDQ $7, BX
+ REP; MOVSQ
+ JMP tail
+
+move_1or2:
+ MOVB (SI), AX
+ MOVB AX, (DI)
+ MOVB -1(SI)(BX*1), CX
+ MOVB CX, -1(DI)(BX*1)
+ RET
+move_0:
+ RET
+move_3or4:
+ MOVW (SI), AX
+ MOVW AX, (DI)
+ MOVW -2(SI)(BX*1), CX
+ MOVW CX, -2(DI)(BX*1)
+ RET
+move_5through7:
+ MOVL (SI), AX
+ MOVL AX, (DI)
+ MOVL -4(SI)(BX*1), CX
+ MOVL CX, -4(DI)(BX*1)
+ RET
+move_8:
+ // We need a separate case for 8 to make sure we write pointers atomically.
+ MOVQ (SI), AX
+ MOVQ AX, (DI)
+ RET
+move_9through16:
+ MOVQ (SI), AX
+ MOVQ AX, (DI)
+ MOVQ -8(SI)(BX*1), CX
+ MOVQ CX, -8(DI)(BX*1)
+ RET
+move_17through32:
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU -16(SI)(BX*1), X1
+ MOVOU X1, -16(DI)(BX*1)
+ RET
+move_33through64:
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU -32(SI)(BX*1), X2
+ MOVOU X2, -32(DI)(BX*1)
+ MOVOU -16(SI)(BX*1), X3
+ MOVOU X3, -16(DI)(BX*1)
+ RET
+move_65through128:
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU 32(SI), X2
+ MOVOU X2, 32(DI)
+ MOVOU 48(SI), X3
+ MOVOU X3, 48(DI)
+ MOVOU -64(SI)(BX*1), X4
+ MOVOU X4, -64(DI)(BX*1)
+ MOVOU -48(SI)(BX*1), X5
+ MOVOU X5, -48(DI)(BX*1)
+ MOVOU -32(SI)(BX*1), X6
+ MOVOU X6, -32(DI)(BX*1)
+ MOVOU -16(SI)(BX*1), X7
+ MOVOU X7, -16(DI)(BX*1)
+ RET
+move_129through256:
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU 32(SI), X2
+ MOVOU X2, 32(DI)
+ MOVOU 48(SI), X3
+ MOVOU X3, 48(DI)
+ MOVOU 64(SI), X4
+ MOVOU X4, 64(DI)
+ MOVOU 80(SI), X5
+ MOVOU X5, 80(DI)
+ MOVOU 96(SI), X6
+ MOVOU X6, 96(DI)
+ MOVOU 112(SI), X7
+ MOVOU X7, 112(DI)
+ MOVOU -128(SI)(BX*1), X8
+ MOVOU X8, -128(DI)(BX*1)
+ MOVOU -112(SI)(BX*1), X9
+ MOVOU X9, -112(DI)(BX*1)
+ MOVOU -96(SI)(BX*1), X10
+ MOVOU X10, -96(DI)(BX*1)
+ MOVOU -80(SI)(BX*1), X11
+ MOVOU X11, -80(DI)(BX*1)
+ MOVOU -64(SI)(BX*1), X12
+ MOVOU X12, -64(DI)(BX*1)
+ MOVOU -48(SI)(BX*1), X13
+ MOVOU X13, -48(DI)(BX*1)
+ MOVOU -32(SI)(BX*1), X14
+ MOVOU X14, -32(DI)(BX*1)
+ MOVOU -16(SI)(BX*1), X15
+ MOVOU X15, -16(DI)(BX*1)
+ RET
+move_256through2048:
+ SUBQ $256, BX
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU 32(SI), X2
+ MOVOU X2, 32(DI)
+ MOVOU 48(SI), X3
+ MOVOU X3, 48(DI)
+ MOVOU 64(SI), X4
+ MOVOU X4, 64(DI)
+ MOVOU 80(SI), X5
+ MOVOU X5, 80(DI)
+ MOVOU 96(SI), X6
+ MOVOU X6, 96(DI)
+ MOVOU 112(SI), X7
+ MOVOU X7, 112(DI)
+ MOVOU 128(SI), X8
+ MOVOU X8, 128(DI)
+ MOVOU 144(SI), X9
+ MOVOU X9, 144(DI)
+ MOVOU 160(SI), X10
+ MOVOU X10, 160(DI)
+ MOVOU 176(SI), X11
+ MOVOU X11, 176(DI)
+ MOVOU 192(SI), X12
+ MOVOU X12, 192(DI)
+ MOVOU 208(SI), X13
+ MOVOU X13, 208(DI)
+ MOVOU 224(SI), X14
+ MOVOU X14, 224(DI)
+ MOVOU 240(SI), X15
+ MOVOU X15, 240(DI)
+ CMPQ BX, $256
+ LEAQ 256(SI), SI
+ LEAQ 256(DI), DI
+ JGE move_256through2048
+ JMP tail
diff --git a/pkg/sentry/platform/safecopy/memcpy_arm64.s b/pkg/sentry/platform/safecopy/memcpy_arm64.s
new file mode 100644
index 000000000..e7e541565
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/memcpy_arm64.s
@@ -0,0 +1,78 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+// handleMemcpyFault returns (the value stored in R0, the value stored in R1).
+// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS,
+// with the faulting address stored in R0 and the signal number stored in R1.
+//
+// It must have the same frame configuration as memcpy so that it can undo any
+// potential call frame set up by the assembler.
+TEXT handleMemcpyFault(SB), NOSPLIT, $0-36
+ MOVD R0, addr+24(FP)
+ MOVW R1, sig+32(FP)
+ RET
+
+// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received
+// during the copy, it returns the address that caused the fault and the number
+// of the signal that was received. Otherwise, it returns an unspecified address
+// and a signal number of 0.
+//
+// Data is copied in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully copied.
+//
+// The code is derived from the Go source runtime.memmove.
+//
+// func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+TEXT ·memcpy(SB), NOSPLIT, $-8-36
+ // Store 0 as the returned signal number. If we run to completion,
+ // this is the value the caller will see; if a signal is received,
+ // handleMemcpyFault will store a different value in this address.
+ MOVW $0, sig+32(FP)
+
+ MOVD to+0(FP), R3
+ MOVD from+8(FP), R4
+ MOVD n+16(FP), R5
+ CMP $0, R5
+ BNE check
+ RET
+
+check:
+ AND $~7, R5, R7 // R7 is N&~7.
+ SUB R7, R5, R6 // R6 is N&7.
+
+ // Copying forward proceeds by copying R7/8 words then copying R6 bytes.
+ // R3 and R4 are advanced as we copy.
+
+ // (There may be implementations of armv8 where copying by bytes until
+ // at least one of source or dest is word aligned is a worthwhile
+ // optimization, but the on the one tested so far (xgene) it did not
+ // make a significance difference.)
+
+ CMP $0, R7 // Do we need to do any word-by-word copying?
+ BEQ noforwardlarge
+ ADD R3, R7, R9 // R9 points just past where we copy by word.
+
+forwardlargeloop:
+ MOVD.P 8(R4), R8 // R8 is just a scratch register.
+ MOVD.P R8, 8(R3)
+ CMP R3, R9
+ BNE forwardlargeloop
+
+noforwardlarge:
+ CMP $0, R6 // Do we need to do any byte-by-byte copying?
+ BNE forwardtail
+ RET
+
+forwardtail:
+ ADD R3, R6, R9 // R9 points just past the destination memory.
+
+forwardtailloop:
+ MOVBU.P 1(R4), R8
+ MOVBU.P R8, 1(R3)
+ CMP R3, R9
+ BNE forwardtailloop
+ RET
diff --git a/pkg/sentry/platform/safecopy/safecopy.go b/pkg/sentry/platform/safecopy/safecopy.go
new file mode 100644
index 000000000..5126871eb
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/safecopy.go
@@ -0,0 +1,144 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package safecopy provides an efficient implementation of functions to access
+// memory that may result in SIGSEGV or SIGBUS being sent to the accessor.
+package safecopy
+
+import (
+ "fmt"
+ "reflect"
+ "runtime"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// SegvError is returned when a safecopy function receives SIGSEGV.
+type SegvError struct {
+ // Addr is the address at which the SIGSEGV occurred.
+ Addr uintptr
+}
+
+// Error implements error.Error.
+func (e SegvError) Error() string {
+ return fmt.Sprintf("SIGSEGV at %#x", e.Addr)
+}
+
+// BusError is returned when a safecopy function receives SIGBUS.
+type BusError struct {
+ // Addr is the address at which the SIGBUS occurred.
+ Addr uintptr
+}
+
+// Error implements error.Error.
+func (e BusError) Error() string {
+ return fmt.Sprintf("SIGBUS at %#x", e.Addr)
+}
+
+// AlignmentError is returned when a safecopy function is passed an address
+// that does not meet alignment requirements.
+type AlignmentError struct {
+ // Addr is the invalid address.
+ Addr uintptr
+
+ // Alignment is the required alignment.
+ Alignment uintptr
+}
+
+// Error implements error.Error.
+func (e AlignmentError) Error() string {
+ return fmt.Sprintf("address %#x is not aligned to a %d-byte boundary", e.Addr, e.Alignment)
+}
+
+var (
+ // The begin and end addresses below are for the functions that are
+ // checked by the signal handler.
+ memcpyBegin uintptr
+ memcpyEnd uintptr
+ memclrBegin uintptr
+ memclrEnd uintptr
+ swapUint32Begin uintptr
+ swapUint32End uintptr
+ swapUint64Begin uintptr
+ swapUint64End uintptr
+ compareAndSwapUint32Begin uintptr
+ compareAndSwapUint32End uintptr
+ loadUint32Begin uintptr
+ loadUint32End uintptr
+
+ // savedSigSegVHandler is a pointer to the SIGSEGV handler that was
+ // configured before we replaced it with our own. We still call into it
+ // when we get a SIGSEGV that is not interesting to us.
+ savedSigSegVHandler uintptr
+
+ // same a above, but for SIGBUS signals.
+ savedSigBusHandler uintptr
+)
+
+// signalHandler is our replacement signal handler for SIGSEGV and SIGBUS
+// signals.
+func signalHandler()
+
+// FindEndAddress returns the end address (one byte beyond the last) of the
+// function that contains the specified address (begin).
+func FindEndAddress(begin uintptr) uintptr {
+ f := runtime.FuncForPC(begin)
+ if f != nil {
+ for p := begin; ; p++ {
+ g := runtime.FuncForPC(p)
+ if f != g {
+ return p
+ }
+ }
+ }
+ return begin
+}
+
+// initializeAddresses initializes the addresses used by the signal handler.
+func initializeAddresses() {
+ // The following functions are written in assembly language, so they won't
+ // be inlined by the existing compiler/linker. Tests will fail if this
+ // assumption is violated.
+ memcpyBegin = reflect.ValueOf(memcpy).Pointer()
+ memcpyEnd = FindEndAddress(memcpyBegin)
+ memclrBegin = reflect.ValueOf(memclr).Pointer()
+ memclrEnd = FindEndAddress(memclrBegin)
+ swapUint32Begin = reflect.ValueOf(swapUint32).Pointer()
+ swapUint32End = FindEndAddress(swapUint32Begin)
+ swapUint64Begin = reflect.ValueOf(swapUint64).Pointer()
+ swapUint64End = FindEndAddress(swapUint64Begin)
+ compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer()
+ compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin)
+ loadUint32Begin = reflect.ValueOf(loadUint32).Pointer()
+ loadUint32End = FindEndAddress(loadUint32Begin)
+}
+
+func init() {
+ initializeAddresses()
+ if err := ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err))
+ }
+ if err := ReplaceSignalHandler(syscall.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err))
+ }
+ syserror.AddErrorUnwrapper(func(e error) (syscall.Errno, bool) {
+ switch e.(type) {
+ case SegvError, BusError, AlignmentError:
+ return syscall.EFAULT, true
+ default:
+ return 0, false
+ }
+ })
+}
diff --git a/pkg/sentry/platform/safecopy/safecopy_state_autogen.go b/pkg/sentry/platform/safecopy/safecopy_state_autogen.go
new file mode 100755
index 000000000..58fd8fbd0
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/safecopy_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package safecopy
+
diff --git a/pkg/sentry/platform/safecopy/safecopy_unsafe.go b/pkg/sentry/platform/safecopy/safecopy_unsafe.go
new file mode 100644
index 000000000..eef028e68
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/safecopy_unsafe.go
@@ -0,0 +1,335 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safecopy
+
+import (
+ "fmt"
+ "syscall"
+ "unsafe"
+)
+
+// maxRegisterSize is the maximum register size used in memcpy and memclr. It
+// is used to decide by how much to rewind the copy (for memcpy) or zeroing
+// (for memclr) before proceeding.
+const maxRegisterSize = 16
+
+// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received
+// during the copy, it returns the address that caused the fault and the number
+// of the signal that was received. Otherwise, it returns an unspecified address
+// and a signal number of 0.
+//
+// Data is copied in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully copied.
+//
+//go:noescape
+func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+
+// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS
+// signal is received during the write, it returns the address that caused the
+// fault and the number of the signal that was received. Otherwise, it returns
+// an unspecified address and a signal number of 0.
+//
+// Data is written in order, such that if a fault happens at address p, it is
+// safe to assume that all data before p-maxRegisterSize has already been
+// successfully written.
+//
+//go:noescape
+func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+
+// swapUint32 atomically stores new into *ptr and returns (the previous *ptr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: ptr must be aligned to a 4-byte boundary.
+//
+//go:noescape
+func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)
+
+// swapUint64 atomically stores new into *ptr and returns (the previous *ptr
+// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
+// value of old is unspecified, and sig is the number of the signal that was
+// received.
+//
+// Preconditions: ptr must be aligned to a 8-byte boundary.
+//
+//go:noescape
+func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)
+
+// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
+// (the value previously stored at ptr, 0). If a SIGSEGV or SIGBUS signal is
+// received during the operation, the value of prev is unspecified, and sig is
+// the number of the signal that was received.
+//
+// Preconditions: ptr must be aligned to a 4-byte boundary.
+//
+//go:noescape
+func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)
+
+// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It
+// may fail with SIGSEGV or SIGBUS if it is received while reading from ptr.
+//
+// Preconditions: ptr must be aligned to a 4-byte boundary.
+//
+//go:noescape
+func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
+
+// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes
+// copied and an error if SIGSEGV or SIGBUS is received while reading from src.
+func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
+ toCopy := uintptr(len(dst))
+ if len(dst) == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memcpy(unsafe.Pointer(&dst[0]), src, toCopy)
+ if sig == 0 {
+ return len(dst), nil
+ }
+
+ faultN, srcN := uintptr(fault), uintptr(src)
+ if faultN < srcN || faultN >= srcN+toCopy {
+ panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, faultN, srcN, srcN+toCopy))
+ }
+
+ // memcpy might have ended the copy up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to copy up to the fault.
+ var done int
+ if faultN-srcN > maxRegisterSize {
+ done = int(faultN - srcN - maxRegisterSize)
+ }
+ n, err := CopyIn(dst[done:int(faultN-srcN)], unsafe.Pointer(srcN+uintptr(done)))
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// CopyOut copies len(src) bytes from src to dst. If returns the number of
+// bytes done and an error if SIGSEGV or SIGBUS is received while writing to
+// dst.
+func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
+ toCopy := uintptr(len(src))
+ if toCopy == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memcpy(dst, unsafe.Pointer(&src[0]), toCopy)
+ if sig == 0 {
+ return len(src), nil
+ }
+
+ faultN, dstN := uintptr(fault), uintptr(dst)
+ if faultN < dstN || faultN >= dstN+toCopy {
+ panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toCopy))
+ }
+
+ // memcpy might have ended the copy up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to copy up to the fault.
+ var done int
+ if faultN-dstN > maxRegisterSize {
+ done = int(faultN - dstN - maxRegisterSize)
+ }
+ n, err := CopyOut(unsafe.Pointer(dstN+uintptr(done)), src[done:int(faultN-dstN)])
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// Copy copies toCopy bytes from src to dst. It returns the number of bytes
+// copied and an error if SIGSEGV or SIGBUS is received while reading from src
+// or writing to dst.
+//
+// Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap,
+// the resulting contents of dst are unspecified.
+func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
+ if toCopy == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memcpy(dst, src, toCopy)
+ if sig == 0 {
+ return toCopy, nil
+ }
+
+ // Did the fault occur while reading from src or writing to dst?
+ faultN, srcN, dstN := uintptr(fault), uintptr(src), uintptr(dst)
+ faultAfterSrc := ^uintptr(0)
+ if faultN >= srcN {
+ faultAfterSrc = faultN - srcN
+ }
+ faultAfterDst := ^uintptr(0)
+ if faultN >= dstN {
+ faultAfterDst = faultN - dstN
+ }
+ if faultAfterSrc >= toCopy && faultAfterDst >= toCopy {
+ panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, faultN, srcN, srcN+toCopy, dstN, dstN+toCopy))
+ }
+ faultedAfter := faultAfterSrc
+ if faultedAfter > faultAfterDst {
+ faultedAfter = faultAfterDst
+ }
+
+ // memcpy might have ended the copy up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to copy up to the fault.
+ var done uintptr
+ if faultedAfter > maxRegisterSize {
+ done = faultedAfter - maxRegisterSize
+ }
+ n, err := Copy(unsafe.Pointer(dstN+done), unsafe.Pointer(srcN+done), faultedAfter-done)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// ZeroOut writes toZero zero bytes to dst. It returns the number of bytes
+// written and an error if SIGSEGV or SIGBUS is received while writing to dst.
+func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
+ if toZero == 0 {
+ return 0, nil
+ }
+
+ fault, sig := memclr(dst, toZero)
+ if sig == 0 {
+ return toZero, nil
+ }
+
+ faultN, dstN := uintptr(fault), uintptr(dst)
+ if faultN < dstN || faultN >= dstN+toZero {
+ panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toZero))
+ }
+
+ // memclr might have ended the write up to maxRegisterSize bytes before
+ // fault, if an instruction caused a memory access that straddled two
+ // pages, and the second one faulted. Try to write up to the fault.
+ var done uintptr
+ if faultN-dstN > maxRegisterSize {
+ done = faultN - dstN - maxRegisterSize
+ }
+ n, err := ZeroOut(unsafe.Pointer(dstN+done), faultN-dstN-done)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ return done, errorFromFaultSignal(fault, sig)
+}
+
+// SwapUint32 is equivalent to sync/atomic.SwapUint32, except that it returns
+// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is
+// not aligned to a 4-byte boundary.
+func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) {
+ if addr := uintptr(ptr); addr&3 != 0 {
+ return 0, AlignmentError{addr, 4}
+ }
+ old, sig := swapUint32(ptr, new)
+ return old, errorFromFaultSignal(ptr, sig)
+}
+
+// SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns
+// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is
+// not aligned to an 8-byte boundary.
+func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) {
+ if addr := uintptr(ptr); addr&7 != 0 {
+ return 0, AlignmentError{addr, 8}
+ }
+ old, sig := swapUint64(ptr, new)
+ return old, errorFromFaultSignal(ptr, sig)
+}
+
+// CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32,
+// except that it returns an error if SIGSEGV or SIGBUS is received while
+// accessing ptr, or if ptr is not aligned to a 4-byte boundary.
+func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) {
+ if addr := uintptr(ptr); addr&3 != 0 {
+ return 0, AlignmentError{addr, 4}
+ }
+ prev, sig := compareAndSwapUint32(ptr, old, new)
+ return prev, errorFromFaultSignal(ptr, sig)
+}
+
+// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It
+// may fail with SIGSEGV or SIGBUS if it is received while reading from ptr.
+//
+// Preconditions: ptr must be aligned to a 4-byte boundary.
+func LoadUint32(ptr unsafe.Pointer) (uint32, error) {
+ if addr := uintptr(ptr); addr&3 != 0 {
+ return 0, AlignmentError{addr, 4}
+ }
+ val, sig := loadUint32(ptr)
+ return val, errorFromFaultSignal(ptr, sig)
+}
+
+func errorFromFaultSignal(addr unsafe.Pointer, sig int32) error {
+ switch sig {
+ case 0:
+ return nil
+ case int32(syscall.SIGSEGV):
+ return SegvError{uintptr(addr)}
+ case int32(syscall.SIGBUS):
+ return BusError{uintptr(addr)}
+ default:
+ panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
+ }
+}
+
+// ReplaceSignalHandler replaces the existing signal handler for the provided
+// signal with the one that handles faults in safecopy-protected functions.
+//
+// It stores the value of the previously set handler in previous.
+//
+// This function will be called on initialization in order to install safecopy
+// handlers for appropriate signals. These handlers will call the previous
+// handler however, and if this is function is being used externally then the
+// same courtesy is expected.
+func ReplaceSignalHandler(sig syscall.Signal, handler uintptr, previous *uintptr) error {
+ var sa struct {
+ handler uintptr
+ flags uint64
+ restorer uintptr
+ mask uint64
+ }
+ const maskLen = 8
+
+ // Get the existing signal handler information, and save the current
+ // handler. Once we replace it, we will use this pointer to fall back to
+ // it when we receive other signals.
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ // Fail if there isn't a previous handler.
+ if sa.handler == 0 {
+ return fmt.Errorf("previous handler for signal %x isn't set", sig)
+ }
+
+ *previous = sa.handler
+
+ // Install our own handler.
+ sa.handler = handler
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/platform/safecopy/sighandler_amd64.s b/pkg/sentry/platform/safecopy/sighandler_amd64.s
new file mode 100644
index 000000000..475ae48e9
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/sighandler_amd64.s
@@ -0,0 +1,133 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// The signals handled by sigHandler.
+#define SIGBUS 7
+#define SIGSEGV 11
+
+// Offsets to the registers in context->uc_mcontext.gregs[].
+#define REG_RDI 0x68
+#define REG_RAX 0x90
+#define REG_IP 0xa8
+
+// Offset to the si_addr field of siginfo.
+#define SI_CODE 0x08
+#define SI_ADDR 0x10
+
+// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must
+// not be set up as a handler to any other signals.
+//
+// If the instruction causing the signal is within a safecopy-protected
+// function, the signal is handled such that execution resumes in the
+// appropriate fault handling stub with AX containing the faulting address and
+// DI containing the signal number. Otherwise control is transferred to the
+// previously configured signal handler (savedSigSegvHandler or
+// savedSigBusHandler).
+//
+// This function cannot be written in go because it runs whenever a signal is
+// received by the thread (preempting whatever was running), which includes when
+// garbage collector has stopped or isn't expecting any interactions (like
+// barriers).
+//
+// The arguments are the following:
+// DI - The signal number.
+// SI - Pointer to siginfo_t structure.
+// DX - Pointer to ucontext structure.
+TEXT ·signalHandler(SB),NOSPLIT,$0
+ // Check if the signal is from the kernel.
+ MOVQ $0x0, CX
+ CMPL CX, SI_CODE(SI)
+ JGE original_handler
+
+ // Check if RIP is within the area we care about.
+ MOVQ REG_IP(DX), CX
+ CMPQ CX, ·memcpyBegin(SB)
+ JB not_memcpy
+ CMPQ CX, ·memcpyEnd(SB)
+ JAE not_memcpy
+
+ // Modify the context such that execution will resume in the fault
+ // handler.
+ LEAQ handleMemcpyFault(SB), CX
+ JMP handle_fault
+
+not_memcpy:
+ CMPQ CX, ·memclrBegin(SB)
+ JB not_memclr
+ CMPQ CX, ·memclrEnd(SB)
+ JAE not_memclr
+
+ LEAQ handleMemclrFault(SB), CX
+ JMP handle_fault
+
+not_memclr:
+ CMPQ CX, ·swapUint32Begin(SB)
+ JB not_swapuint32
+ CMPQ CX, ·swapUint32End(SB)
+ JAE not_swapuint32
+
+ LEAQ handleSwapUint32Fault(SB), CX
+ JMP handle_fault
+
+not_swapuint32:
+ CMPQ CX, ·swapUint64Begin(SB)
+ JB not_swapuint64
+ CMPQ CX, ·swapUint64End(SB)
+ JAE not_swapuint64
+
+ LEAQ handleSwapUint64Fault(SB), CX
+ JMP handle_fault
+
+not_swapuint64:
+ CMPQ CX, ·compareAndSwapUint32Begin(SB)
+ JB not_casuint32
+ CMPQ CX, ·compareAndSwapUint32End(SB)
+ JAE not_casuint32
+
+ LEAQ handleCompareAndSwapUint32Fault(SB), CX
+ JMP handle_fault
+
+not_casuint32:
+ CMPQ CX, ·loadUint32Begin(SB)
+ JB not_loaduint32
+ CMPQ CX, ·loadUint32End(SB)
+ JAE not_loaduint32
+
+ LEAQ handleLoadUint32Fault(SB), CX
+ JMP handle_fault
+
+not_loaduint32:
+original_handler:
+ // Jump to the previous signal handler, which is likely the golang one.
+ XORQ CX, CX
+ MOVQ ·savedSigBusHandler(SB), AX
+ CMPL DI, $SIGSEGV
+ CMOVQEQ ·savedSigSegVHandler(SB), AX
+ JMP AX
+
+handle_fault:
+ // Entered with the address of the fault handler in RCX; store it in
+ // RIP.
+ MOVQ CX, REG_IP(DX)
+
+ // Store the faulting address in RAX.
+ MOVQ SI_ADDR(SI), CX
+ MOVQ CX, REG_RAX(DX)
+
+ // Store the signal number in EDI.
+ MOVL DI, REG_RDI(DX)
+
+ RET
diff --git a/pkg/sentry/platform/safecopy/sighandler_arm64.s b/pkg/sentry/platform/safecopy/sighandler_arm64.s
new file mode 100644
index 000000000..53e4ac2c1
--- /dev/null
+++ b/pkg/sentry/platform/safecopy/sighandler_arm64.s
@@ -0,0 +1,143 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// The signals handled by sigHandler.
+#define SIGBUS 7
+#define SIGSEGV 11
+
+// Offsets to the registers in context->uc_mcontext.gregs[].
+#define REG_R0 0xB8
+#define REG_R1 0xC0
+#define REG_PC 0x1B8
+
+// Offset to the si_addr field of siginfo.
+#define SI_CODE 0x08
+#define SI_ADDR 0x10
+
+// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must
+// not be set up as a handler to any other signals.
+//
+// If the instruction causing the signal is within a safecopy-protected
+// function, the signal is handled such that execution resumes in the
+// appropriate fault handling stub with R0 containing the faulting address and
+// R1 containing the signal number. Otherwise control is transferred to the
+// previously configured signal handler (savedSigSegvHandler or
+// savedSigBusHandler).
+//
+// This function cannot be written in go because it runs whenever a signal is
+// received by the thread (preempting whatever was running), which includes when
+// garbage collector has stopped or isn't expecting any interactions (like
+// barriers).
+//
+// The arguments are the following:
+// R0 - The signal number.
+// R1 - Pointer to siginfo_t structure.
+// R2 - Pointer to ucontext structure.
+TEXT ·signalHandler(SB),NOSPLIT,$0
+ // Check if the signal is from the kernel, si_code > 0 means a kernel signal.
+ MOVD SI_CODE(R1), R7
+ CMPW $0x0, R7
+ BLE original_handler
+
+ // Check if PC is within the area we care about.
+ MOVD REG_PC(R2), R7
+ MOVD ·memcpyBegin(SB), R8
+ CMP R8, R7
+ BLO not_memcpy
+ MOVD ·memcpyEnd(SB), R8
+ CMP R8, R7
+ BHS not_memcpy
+
+ // Modify the context such that execution will resume in the fault handler.
+ MOVD $handleMemcpyFault(SB), R7
+ B handle_fault
+
+not_memcpy:
+ MOVD ·memclrBegin(SB), R8
+ CMP R8, R7
+ BLO not_memclr
+ MOVD ·memclrEnd(SB), R8
+ CMP R8, R7
+ BHS not_memclr
+
+ MOVD $handleMemclrFault(SB), R7
+ B handle_fault
+
+not_memclr:
+ MOVD ·swapUint32Begin(SB), R8
+ CMP R8, R7
+ BLO not_swapuint32
+ MOVD ·swapUint32End(SB), R8
+ CMP R8, R7
+ BHS not_swapuint32
+
+ MOVD $handleSwapUint32Fault(SB), R7
+ B handle_fault
+
+not_swapuint32:
+ MOVD ·swapUint64Begin(SB), R8
+ CMP R8, R7
+ BLO not_swapuint64
+ MOVD ·swapUint64End(SB), R8
+ CMP R8, R7
+ BHS not_swapuint64
+
+ MOVD $handleSwapUint64Fault(SB), R7
+ B handle_fault
+
+not_swapuint64:
+ MOVD ·compareAndSwapUint32Begin(SB), R8
+ CMP R8, R7
+ BLO not_casuint32
+ MOVD ·compareAndSwapUint32End(SB), R8
+ CMP R8, R7
+ BHS not_casuint32
+
+ MOVD $handleCompareAndSwapUint32Fault(SB), R7
+ B handle_fault
+
+not_casuint32:
+ MOVD ·loadUint32Begin(SB), R8
+ CMP R8, R7
+ BLO not_loaduint32
+ MOVD ·loadUint32End(SB), R8
+ CMP R8, R7
+ BHS not_loaduint32
+
+ MOVD $handleLoadUint32Fault(SB), R7
+ B handle_fault
+
+not_loaduint32:
+original_handler:
+ // Jump to the previous signal handler, which is likely the golang one.
+ MOVD ·savedSigBusHandler(SB), R7
+ MOVD ·savedSigSegVHandler(SB), R8
+ CMPW $SIGSEGV, R0
+ CSEL EQ, R8, R7, R7
+ B (R7)
+
+handle_fault:
+ // Entered with the address of the fault handler in R7; store it in PC.
+ MOVD R7, REG_PC(R2)
+
+ // Store the faulting address in R0.
+ MOVD SI_ADDR(R1), R7
+ MOVD R7, REG_R0(R2)
+
+ // Store the signal number in R1.
+ MOVW R0, REG_R1(R2)
+
+ RET
diff --git a/pkg/sentry/safemem/block_unsafe.go b/pkg/sentry/safemem/block_unsafe.go
new file mode 100644
index 000000000..1f72deb61
--- /dev/null
+++ b/pkg/sentry/safemem/block_unsafe.go
@@ -0,0 +1,279 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "fmt"
+ "reflect"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy"
+)
+
+// A Block is a range of contiguous bytes, similar to []byte but with the
+// following differences:
+//
+// - The memory represented by a Block may require the use of safecopy to
+// access.
+//
+// - Block does not carry a capacity and cannot be expanded.
+//
+// Blocks are immutable and may be copied by value. The zero value of Block
+// represents an empty range, analogous to a nil []byte.
+type Block struct {
+ // [start, start+length) is the represented memory.
+ //
+ // start is an unsafe.Pointer to ensure that Block prevents the represented
+ // memory from being garbage-collected.
+ start unsafe.Pointer
+ length int
+
+ // needSafecopy is true if accessing the represented memory requires the
+ // use of safecopy.
+ needSafecopy bool
+}
+
+// BlockFromSafeSlice returns a Block equivalent to slice, which is safe to
+// access without safecopy.
+func BlockFromSafeSlice(slice []byte) Block {
+ return blockFromSlice(slice, false)
+}
+
+// BlockFromUnsafeSlice returns a Block equivalent to bs, which is not safe to
+// access without safecopy.
+func BlockFromUnsafeSlice(slice []byte) Block {
+ return blockFromSlice(slice, true)
+}
+
+func blockFromSlice(slice []byte, needSafecopy bool) Block {
+ if len(slice) == 0 {
+ return Block{}
+ }
+ return Block{
+ start: unsafe.Pointer(&slice[0]),
+ length: len(slice),
+ needSafecopy: needSafecopy,
+ }
+}
+
+// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is
+// safe to access without safecopy.
+//
+// Preconditions: ptr+len does not overflow.
+func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block {
+ return blockFromPointer(ptr, len, false)
+}
+
+// BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which
+// is not safe to access without safecopy.
+//
+// Preconditions: ptr+len does not overflow.
+func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block {
+ return blockFromPointer(ptr, len, true)
+}
+
+func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block {
+ if uptr := uintptr(ptr); uptr+uintptr(len) < uptr {
+ panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len))
+ }
+ return Block{
+ start: ptr,
+ length: len,
+ needSafecopy: needSafecopy,
+ }
+}
+
+// DropFirst returns a Block equivalent to b, but with the first n bytes
+// omitted. It is analogous to the [n:] operation on a slice, except that if n
+// > b.Len(), DropFirst returns an empty Block instead of panicking.
+//
+// Preconditions: n >= 0.
+func (b Block) DropFirst(n int) Block {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return b.DropFirst64(uint64(n))
+}
+
+// DropFirst64 is equivalent to DropFirst but takes a uint64.
+func (b Block) DropFirst64(n uint64) Block {
+ if n >= uint64(b.length) {
+ return Block{}
+ }
+ return Block{
+ start: unsafe.Pointer(uintptr(b.start) + uintptr(n)),
+ length: b.length - int(n),
+ needSafecopy: b.needSafecopy,
+ }
+}
+
+// TakeFirst returns a Block equivalent to the first n bytes of b. It is
+// analogous to the [:n] operation on a slice, except that if n > b.Len(),
+// TakeFirst returns a copy of b instead of panicking.
+//
+// Preconditions: n >= 0.
+func (b Block) TakeFirst(n int) Block {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return b.TakeFirst64(uint64(n))
+}
+
+// TakeFirst64 is equivalent to TakeFirst but takes a uint64.
+func (b Block) TakeFirst64(n uint64) Block {
+ if n == 0 {
+ return Block{}
+ }
+ if n >= uint64(b.length) {
+ return b
+ }
+ return Block{
+ start: b.start,
+ length: int(n),
+ needSafecopy: b.needSafecopy,
+ }
+}
+
+// ToSlice returns a []byte equivalent to b.
+func (b Block) ToSlice() []byte {
+ var bs []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ hdr.Data = uintptr(b.start)
+ hdr.Len = b.length
+ hdr.Cap = b.length
+ return bs
+}
+
+// Addr returns b's start address as a uintptr. It returns uintptr instead of
+// unsafe.Pointer so that code using safemem cannot obtain unsafe.Pointers
+// without importing the unsafe package explicitly.
+//
+// Note that a uintptr is not recognized as a pointer by the garbage collector,
+// such that if there are no uses of b after a call to b.Addr() and the address
+// is to Go-managed memory, the returned uintptr does not prevent garbage
+// collection of the pointee.
+func (b Block) Addr() uintptr {
+ return uintptr(b.start)
+}
+
+// Len returns b's length in bytes.
+func (b Block) Len() int {
+ return b.length
+}
+
+// NeedSafecopy returns true if accessing b.ToSlice() requires the use of safecopy.
+func (b Block) NeedSafecopy() bool {
+ return b.needSafecopy
+}
+
+// String implements fmt.Stringer.String.
+func (b Block) String() string {
+ if uintptr(b.start) == 0 && b.length == 0 {
+ return "<nil>"
+ }
+ var suffix string
+ if b.needSafecopy {
+ suffix = "*"
+ }
+ return fmt.Sprintf("[%#x-%#x)%s", uintptr(b.start), uintptr(b.start)+uintptr(b.length), suffix)
+}
+
+// Copy copies src.Len() or dst.Len() bytes, whichever is less, from src
+// to dst and returns the number of bytes copied.
+//
+// If src and dst overlap, the data stored in dst is unspecified.
+func Copy(dst, src Block) (int, error) {
+ if !dst.needSafecopy && !src.needSafecopy {
+ return copy(dst.ToSlice(), src.ToSlice()), nil
+ }
+
+ n := dst.length
+ if n > src.length {
+ n = src.length
+ }
+ if n == 0 {
+ return 0, nil
+ }
+
+ switch {
+ case dst.needSafecopy && !src.needSafecopy:
+ return safecopy.CopyOut(dst.start, src.TakeFirst(n).ToSlice())
+ case !dst.needSafecopy && src.needSafecopy:
+ return safecopy.CopyIn(dst.TakeFirst(n).ToSlice(), src.start)
+ case dst.needSafecopy && src.needSafecopy:
+ n64, err := safecopy.Copy(dst.start, src.start, uintptr(n))
+ return int(n64), err
+ default:
+ panic("unreachable")
+ }
+}
+
+// Zero sets all bytes in dst to 0 and returns the number of bytes zeroed.
+func Zero(dst Block) (int, error) {
+ if !dst.needSafecopy {
+ bs := dst.ToSlice()
+ for i := range bs {
+ bs[i] = 0
+ }
+ return len(bs), nil
+ }
+
+ n64, err := safecopy.ZeroOut(dst.start, uintptr(dst.length))
+ return int(n64), err
+}
+
+// Safecopy atomics are no slower than non-safecopy atomics, so use the former
+// even when !b.needSafecopy to get consistent alignment checking.
+
+// SwapUint32 invokes safecopy.SwapUint32 on the first 4 bytes of b.
+//
+// Preconditions: b.Len() >= 4.
+func SwapUint32(b Block, new uint32) (uint32, error) {
+ if b.length < 4 {
+ panic(fmt.Sprintf("insufficient length: %d", b.length))
+ }
+ return safecopy.SwapUint32(b.start, new)
+}
+
+// SwapUint64 invokes safecopy.SwapUint64 on the first 8 bytes of b.
+//
+// Preconditions: b.Len() >= 8.
+func SwapUint64(b Block, new uint64) (uint64, error) {
+ if b.length < 8 {
+ panic(fmt.Sprintf("insufficient length: %d", b.length))
+ }
+ return safecopy.SwapUint64(b.start, new)
+}
+
+// CompareAndSwapUint32 invokes safecopy.CompareAndSwapUint32 on the first 4
+// bytes of b.
+//
+// Preconditions: b.Len() >= 4.
+func CompareAndSwapUint32(b Block, old, new uint32) (uint32, error) {
+ if b.length < 4 {
+ panic(fmt.Sprintf("insufficient length: %d", b.length))
+ }
+ return safecopy.CompareAndSwapUint32(b.start, old, new)
+}
+
+// LoadUint32 invokes safecopy.LoadUint32 on the first 4 bytes of b.
+//
+// Preconditions: b.Len() >= 4.
+func LoadUint32(b Block) (uint32, error) {
+ if b.length < 4 {
+ panic(fmt.Sprintf("insufficient length: %d", b.length))
+ }
+ return safecopy.LoadUint32(b.start)
+}
diff --git a/pkg/sentry/safemem/io.go b/pkg/sentry/safemem/io.go
new file mode 100644
index 000000000..5c3d73eb7
--- /dev/null
+++ b/pkg/sentry/safemem/io.go
@@ -0,0 +1,339 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "errors"
+ "io"
+ "math"
+)
+
+// ErrEndOfBlockSeq is returned by BlockSeqWriter when attempting to write
+// beyond the end of the BlockSeq.
+var ErrEndOfBlockSeq = errors.New("write beyond end of BlockSeq")
+
+// Reader represents a streaming byte source like io.Reader.
+type Reader interface {
+ // ReadToBlocks reads up to dsts.NumBytes() bytes into dsts and returns the
+ // number of bytes read. It may return a partial read without an error
+ // (i.e. (n, nil) where 0 < n < dsts.NumBytes()). It should not return a
+ // full read with an error (i.e. (dsts.NumBytes(), err) where err != nil);
+ // note that this differs from io.Reader.Read (in particular, io.EOF should
+ // not be returned if ReadToBlocks successfully reads dsts.NumBytes()
+ // bytes.)
+ ReadToBlocks(dsts BlockSeq) (uint64, error)
+}
+
+// Writer represents a streaming byte sink like io.Writer.
+type Writer interface {
+ // WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns
+ // the number of bytes written. It may return a partial write without an
+ // error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not
+ // return a full write with an error (i.e. srcs.NumBytes(), err) where err
+ // != nil).
+ WriteFromBlocks(srcs BlockSeq) (uint64, error)
+}
+
+// ReadFullToBlocks repeatedly invokes r.ReadToBlocks until dsts.NumBytes()
+// bytes have been read or ReadToBlocks returns an error.
+func ReadFullToBlocks(r Reader, dsts BlockSeq) (uint64, error) {
+ var done uint64
+ for !dsts.IsEmpty() {
+ n, err := r.ReadToBlocks(dsts)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ dsts = dsts.DropFirst64(n)
+ }
+ return done, nil
+}
+
+// WriteFullFromBlocks repeatedly invokes w.WriteFromBlocks until
+// srcs.NumBytes() bytes have been written or WriteFromBlocks returns an error.
+func WriteFullFromBlocks(w Writer, srcs BlockSeq) (uint64, error) {
+ var done uint64
+ for !srcs.IsEmpty() {
+ n, err := w.WriteFromBlocks(srcs)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ srcs = srcs.DropFirst64(n)
+ }
+ return done, nil
+}
+
+// BlockSeqReader implements Reader by reading from a BlockSeq.
+type BlockSeqReader struct {
+ Blocks BlockSeq
+}
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+func (r *BlockSeqReader) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ n, err := CopySeq(dsts, r.Blocks)
+ r.Blocks = r.Blocks.DropFirst64(n)
+ if err != nil {
+ return n, err
+ }
+ if n < dsts.NumBytes() {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// BlockSeqWriter implements Writer by writing to a BlockSeq.
+type BlockSeqWriter struct {
+ Blocks BlockSeq
+}
+
+// WriteFromBlocks implements Writer.WriteFromBlocks.
+func (w *BlockSeqWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
+ n, err := CopySeq(w.Blocks, srcs)
+ w.Blocks = w.Blocks.DropFirst64(n)
+ if err != nil {
+ return n, err
+ }
+ if n < srcs.NumBytes() {
+ return n, ErrEndOfBlockSeq
+ }
+ return n, nil
+}
+
+// ReaderFunc implements Reader for a function with the semantics of
+// Reader.ReadToBlocks.
+type ReaderFunc func(dsts BlockSeq) (uint64, error)
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+func (f ReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ return f(dsts)
+}
+
+// WriterFunc implements Writer for a function with the semantics of
+// Writer.WriteFromBlocks.
+type WriterFunc func(srcs BlockSeq) (uint64, error)
+
+// WriteFromBlocks implements Writer.WriteFromBlocks.
+func (f WriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
+ return f(srcs)
+}
+
+// ToIOReader implements io.Reader for a (safemem.)Reader.
+//
+// ToIOReader will return a successful partial read iff Reader.ReadToBlocks does
+// so.
+type ToIOReader struct {
+ Reader Reader
+}
+
+// Read implements io.Reader.Read.
+func (r ToIOReader) Read(dst []byte) (int, error) {
+ n, err := r.Reader.ReadToBlocks(BlockSeqOf(BlockFromSafeSlice(dst)))
+ return int(n), err
+}
+
+// ToIOWriter implements io.Writer for a (safemem.)Writer.
+type ToIOWriter struct {
+ Writer Writer
+}
+
+// Write implements io.Writer.Write.
+func (w ToIOWriter) Write(src []byte) (int, error) {
+ // io.Writer does not permit partial writes.
+ n, err := WriteFullFromBlocks(w.Writer, BlockSeqOf(BlockFromSafeSlice(src)))
+ return int(n), err
+}
+
+// FromIOReader implements Reader for an io.Reader by repeatedly invoking
+// io.Reader.Read until it returns an error or partial read.
+//
+// FromIOReader will return a successful partial read iff Reader.Read does so.
+type FromIOReader struct {
+ Reader io.Reader
+}
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+func (r FromIOReader) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ var buf []byte
+ var done uint64
+ for !dsts.IsEmpty() {
+ dst := dsts.Head()
+ var n int
+ var err error
+ n, buf, err = r.readToBlock(dst, buf)
+ done += uint64(n)
+ if n != dst.Len() {
+ return done, err
+ }
+ dsts = dsts.Tail()
+ if err != nil {
+ if dsts.IsEmpty() && err == io.EOF {
+ return done, nil
+ }
+ return done, err
+ }
+ }
+ return done, nil
+}
+
+func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) {
+ // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require
+ // safecopy.
+ if !dst.NeedSafecopy() {
+ n, err := r.Reader.Read(dst.ToSlice())
+ return n, buf, err
+ }
+ if len(buf) < dst.Len() {
+ buf = make([]byte, dst.Len())
+ }
+ rn, rerr := r.Reader.Read(buf[:dst.Len()])
+ wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn]))
+ if wberr != nil {
+ return wbn, buf, wberr
+ }
+ return wbn, buf, rerr
+}
+
+// FromIOWriter implements Writer for an io.Writer by repeatedly invoking
+// io.Writer.Write until it returns an error or partial write.
+//
+// FromIOWriter will tolerate implementations of io.Writer.Write that return
+// partial writes with a nil error in contravention of io.Writer's
+// requirements, since Writer is permitted to do so. FromIOWriter will return a
+// successful partial write iff Writer.Write does so.
+type FromIOWriter struct {
+ Writer io.Writer
+}
+
+// WriteFromBlocks implements Writer.WriteFromBlocks.
+func (w FromIOWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
+ var buf []byte
+ var done uint64
+ for !srcs.IsEmpty() {
+ src := srcs.Head()
+ var n int
+ var err error
+ n, buf, err = w.writeFromBlock(src, buf)
+ done += uint64(n)
+ if n != src.Len() || err != nil {
+ return done, err
+ }
+ srcs = srcs.Tail()
+ }
+ return done, nil
+}
+
+func (w FromIOWriter) writeFromBlock(src Block, buf []byte) (int, []byte, error) {
+ // io.Writer isn't safecopy-aware, so we have to buffer Blocks that require
+ // safecopy.
+ if !src.NeedSafecopy() {
+ n, err := w.Writer.Write(src.ToSlice())
+ return n, buf, err
+ }
+ if len(buf) < src.Len() {
+ buf = make([]byte, src.Len())
+ }
+ bufn, buferr := Copy(BlockFromSafeSlice(buf[:src.Len()]), src)
+ wn, werr := w.Writer.Write(buf[:bufn])
+ if werr != nil {
+ return wn, buf, werr
+ }
+ return wn, buf, buferr
+}
+
+// FromVecReaderFunc implements Reader for a function that reads data into a
+// [][]byte and returns the number of bytes read as an int64.
+type FromVecReaderFunc struct {
+ ReadVec func(dsts [][]byte) (int64, error)
+}
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+//
+// ReadToBlocks calls r.ReadVec at most once.
+func (r FromVecReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ // Ensure that we don't pass a [][]byte with a total length > MaxInt64.
+ dsts = dsts.TakeFirst64(uint64(math.MaxInt64))
+ dstSlices := make([][]byte, 0, dsts.NumBlocks())
+ // Buffer Blocks that require safecopy.
+ for tmp := dsts; !tmp.IsEmpty(); tmp = tmp.Tail() {
+ dst := tmp.Head()
+ if dst.NeedSafecopy() {
+ dstSlices = append(dstSlices, make([]byte, dst.Len()))
+ } else {
+ dstSlices = append(dstSlices, dst.ToSlice())
+ }
+ }
+ rn, rerr := r.ReadVec(dstSlices)
+ dsts = dsts.TakeFirst64(uint64(rn))
+ var done uint64
+ var i int
+ for !dsts.IsEmpty() {
+ dst := dsts.Head()
+ if dst.NeedSafecopy() {
+ n, err := Copy(dst, BlockFromSafeSlice(dstSlices[i]))
+ done += uint64(n)
+ if err != nil {
+ return done, err
+ }
+ } else {
+ done += uint64(dst.Len())
+ }
+ dsts = dsts.Tail()
+ i++
+ }
+ return done, rerr
+}
+
+// FromVecWriterFunc implements Writer for a function that writes data from a
+// [][]byte and returns the number of bytes written.
+type FromVecWriterFunc struct {
+ WriteVec func(srcs [][]byte) (int64, error)
+}
+
+// WriteFromBlocks implements Writer.WriteFromBlocks.
+//
+// WriteFromBlocks calls w.WriteVec at most once.
+func (w FromVecWriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ // Ensure that we don't pass a [][]byte with a total length > MaxInt64.
+ srcs = srcs.TakeFirst64(uint64(math.MaxInt64))
+ srcSlices := make([][]byte, 0, srcs.NumBlocks())
+ // Buffer Blocks that require safecopy.
+ var buferr error
+ for tmp := srcs; !tmp.IsEmpty(); tmp = tmp.Tail() {
+ src := tmp.Head()
+ if src.NeedSafecopy() {
+ slice := make([]byte, src.Len())
+ n, err := Copy(BlockFromSafeSlice(slice), src)
+ srcSlices = append(srcSlices, slice[:n])
+ if err != nil {
+ buferr = err
+ break
+ }
+ } else {
+ srcSlices = append(srcSlices, src.ToSlice())
+ }
+ }
+ n, err := w.WriteVec(srcSlices)
+ if err != nil {
+ return uint64(n), err
+ }
+ return uint64(n), buferr
+}
diff --git a/pkg/sentry/safemem/safemem.go b/pkg/sentry/safemem/safemem.go
new file mode 100644
index 000000000..3e70d33a2
--- /dev/null
+++ b/pkg/sentry/safemem/safemem.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package safemem provides the Block and BlockSeq types.
+package safemem
diff --git a/pkg/sentry/safemem/safemem_state_autogen.go b/pkg/sentry/safemem/safemem_state_autogen.go
new file mode 100755
index 000000000..7264df0b1
--- /dev/null
+++ b/pkg/sentry/safemem/safemem_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package safemem
+
diff --git a/pkg/sentry/safemem/seq_unsafe.go b/pkg/sentry/safemem/seq_unsafe.go
new file mode 100644
index 000000000..354a95dde
--- /dev/null
+++ b/pkg/sentry/safemem/seq_unsafe.go
@@ -0,0 +1,299 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "unsafe"
+)
+
+// A BlockSeq represents a sequence of Blocks, each of which has non-zero
+// length.
+//
+// BlockSeqs are immutable and may be copied by value. The zero value of
+// BlockSeq represents an empty sequence.
+type BlockSeq struct {
+ // If length is 0, then the BlockSeq is empty. Invariants: data == 0;
+ // offset == 0; limit == 0.
+ //
+ // If length is -1, then the BlockSeq represents the single Block{data,
+ // limit, false}. Invariants: offset == 0; limit > 0; limit does not
+ // overflow the range of an int.
+ //
+ // If length is -2, then the BlockSeq represents the single Block{data,
+ // limit, true}. Invariants: offset == 0; limit > 0; limit does not
+ // overflow the range of an int.
+ //
+ // Otherwise, length >= 2, and the BlockSeq represents the `length` Blocks
+ // in the array of Blocks starting at address `data`, starting at `offset`
+ // bytes into the first Block and limited to the following `limit` bytes.
+ // Invariants: data != 0; offset < len(data[0]); limit > 0; offset+limit <=
+ // the combined length of all Blocks in the array; the first Block in the
+ // array has non-zero length.
+ //
+ // length is never 1; sequences consisting of a single Block are always
+ // stored inline (with length < 0).
+ data unsafe.Pointer
+ length int
+ offset int
+ limit uint64
+}
+
+// BlockSeqOf returns a BlockSeq representing the single Block b.
+func BlockSeqOf(b Block) BlockSeq {
+ bs := BlockSeq{
+ data: b.start,
+ length: -1,
+ limit: uint64(b.length),
+ }
+ if b.needSafecopy {
+ bs.length = -2
+ }
+ return bs
+}
+
+// BlockSeqFromSlice returns a BlockSeq representing all Blocks in slice.
+// If slice contains Blocks with zero length, BlockSeq will skip them during
+// iteration.
+//
+// Whether the returned BlockSeq shares memory with slice is unspecified;
+// clients should avoid mutating slices passed to BlockSeqFromSlice.
+//
+// Preconditions: The combined length of all Blocks in slice <= math.MaxUint64.
+func BlockSeqFromSlice(slice []Block) BlockSeq {
+ slice = skipEmpty(slice)
+ var limit uint64
+ for _, b := range slice {
+ sum := limit + uint64(b.Len())
+ if sum < limit {
+ panic("BlockSeq length overflows uint64")
+ }
+ limit = sum
+ }
+ return blockSeqFromSliceLimited(slice, limit)
+}
+
+// Preconditions: The combined length of all Blocks in slice <= limit. If
+// len(slice) != 0, the first Block in slice has non-zero length, and limit >
+// 0.
+func blockSeqFromSliceLimited(slice []Block, limit uint64) BlockSeq {
+ switch len(slice) {
+ case 0:
+ return BlockSeq{}
+ case 1:
+ return BlockSeqOf(slice[0].TakeFirst64(limit))
+ default:
+ return BlockSeq{
+ data: unsafe.Pointer(&slice[0]),
+ length: len(slice),
+ limit: limit,
+ }
+ }
+}
+
+func skipEmpty(slice []Block) []Block {
+ for i, b := range slice {
+ if b.Len() != 0 {
+ return slice[i:]
+ }
+ }
+ return nil
+}
+
+// IsEmpty returns true if bs contains no Blocks.
+//
+// Invariants: bs.IsEmpty() == (bs.NumBlocks() == 0) == (bs.NumBytes() == 0).
+// (Of these, prefer to use bs.IsEmpty().)
+func (bs BlockSeq) IsEmpty() bool {
+ return bs.length == 0
+}
+
+// NumBlocks returns the number of Blocks in bs.
+func (bs BlockSeq) NumBlocks() int {
+ // In general, we have to count: if bs represents a windowed slice then the
+ // slice may contain Blocks with zero length, and bs.length may be larger
+ // than the actual number of Blocks due to bs.limit.
+ var n int
+ for !bs.IsEmpty() {
+ n++
+ bs = bs.Tail()
+ }
+ return n
+}
+
+// NumBytes returns the sum of Block.Len() for all Blocks in bs.
+func (bs BlockSeq) NumBytes() uint64 {
+ return bs.limit
+}
+
+// Head returns the first Block in bs.
+//
+// Preconditions: !bs.IsEmpty().
+func (bs BlockSeq) Head() Block {
+ if bs.length == 0 {
+ panic("empty BlockSeq")
+ }
+ if bs.length < 0 {
+ return bs.internalBlock()
+ }
+ return (*Block)(bs.data).DropFirst(bs.offset).TakeFirst64(bs.limit)
+}
+
+// Preconditions: bs.length < 0.
+func (bs BlockSeq) internalBlock() Block {
+ return Block{
+ start: bs.data,
+ length: int(bs.limit),
+ needSafecopy: bs.length == -2,
+ }
+}
+
+// Tail returns a BlockSeq consisting of all Blocks in bs after the first.
+//
+// Preconditions: !bs.IsEmpty().
+func (bs BlockSeq) Tail() BlockSeq {
+ if bs.length == 0 {
+ panic("empty BlockSeq")
+ }
+ if bs.length < 0 {
+ return BlockSeq{}
+ }
+ head := (*Block)(bs.data).DropFirst(bs.offset)
+ headLen := uint64(head.Len())
+ if headLen >= bs.limit {
+ // The head Block exhausts the limit, so the tail is empty.
+ return BlockSeq{}
+ }
+ var extSlice []Block
+ extSliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&extSlice))
+ extSliceHdr.Data = uintptr(bs.data)
+ extSliceHdr.Len = bs.length
+ extSliceHdr.Cap = bs.length
+ tailSlice := skipEmpty(extSlice[1:])
+ tailLimit := bs.limit - headLen
+ return blockSeqFromSliceLimited(tailSlice, tailLimit)
+}
+
+// DropFirst returns a BlockSeq equivalent to bs, but with the first n bytes
+// omitted. If n > bs.NumBytes(), DropFirst returns an empty BlockSeq.
+//
+// Preconditions: n >= 0.
+func (bs BlockSeq) DropFirst(n int) BlockSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return bs.DropFirst64(uint64(n))
+}
+
+// DropFirst64 is equivalent to DropFirst but takes an uint64.
+func (bs BlockSeq) DropFirst64(n uint64) BlockSeq {
+ if n >= bs.limit {
+ return BlockSeq{}
+ }
+ for {
+ // Calling bs.Head() here is surprisingly expensive, so inline getting
+ // the head's length.
+ var headLen uint64
+ if bs.length < 0 {
+ headLen = bs.limit
+ } else {
+ headLen = uint64((*Block)(bs.data).Len() - bs.offset)
+ }
+ if n < headLen {
+ // Dropping ends partway through the head Block.
+ if bs.length < 0 {
+ return BlockSeqOf(bs.internalBlock().DropFirst64(n))
+ }
+ bs.offset += int(n)
+ bs.limit -= n
+ return bs
+ }
+ n -= headLen
+ bs = bs.Tail()
+ }
+}
+
+// TakeFirst returns a BlockSeq equivalent to the first n bytes of bs. If n >
+// bs.NumBytes(), TakeFirst returns a BlockSeq equivalent to bs.
+//
+// Preconditions: n >= 0.
+func (bs BlockSeq) TakeFirst(n int) BlockSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return bs.TakeFirst64(uint64(n))
+}
+
+// TakeFirst64 is equivalent to TakeFirst but takes a uint64.
+func (bs BlockSeq) TakeFirst64(n uint64) BlockSeq {
+ if n == 0 {
+ return BlockSeq{}
+ }
+ if bs.limit > n {
+ bs.limit = n
+ }
+ return bs
+}
+
+// String implements fmt.Stringer.String.
+func (bs BlockSeq) String() string {
+ var buf bytes.Buffer
+ buf.WriteByte('[')
+ var sep string
+ for !bs.IsEmpty() {
+ buf.WriteString(sep)
+ sep = " "
+ buf.WriteString(bs.Head().String())
+ bs = bs.Tail()
+ }
+ buf.WriteByte(']')
+ return buf.String()
+}
+
+// CopySeq copies srcs.NumBytes() or dsts.NumBytes() bytes, whichever is less,
+// from srcs to dsts and returns the number of bytes copied.
+//
+// If srcs and dsts overlap, the data stored in dsts is unspecified.
+func CopySeq(dsts, srcs BlockSeq) (uint64, error) {
+ var done uint64
+ for !dsts.IsEmpty() && !srcs.IsEmpty() {
+ dst := dsts.Head()
+ src := srcs.Head()
+ n, err := Copy(dst, src)
+ done += uint64(n)
+ if err != nil {
+ return done, err
+ }
+ dsts = dsts.DropFirst(n)
+ srcs = srcs.DropFirst(n)
+ }
+ return done, nil
+}
+
+// ZeroSeq sets all bytes in dsts to 0 and returns the number of bytes zeroed.
+func ZeroSeq(dsts BlockSeq) (uint64, error) {
+ var done uint64
+ for !dsts.IsEmpty() {
+ n, err := Zero(dsts.Head())
+ done += uint64(n)
+ if err != nil {
+ return done, err
+ }
+ dsts = dsts.DropFirst(n)
+ }
+ return done, nil
+}
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go
new file mode 100644
index 000000000..659b43363
--- /dev/null
+++ b/pkg/sentry/sighandling/sighandling.go
@@ -0,0 +1,140 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sighandling contains helpers for handling signals to applications.
+package sighandling
+
+import (
+ "fmt"
+ "os"
+ "os/signal"
+ "reflect"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// numSignals is the number of normal (non-realtime) signals on Linux.
+const numSignals = 32
+
+// handleSignals listens for incoming signals and calls the given handler
+// function.
+//
+// It starts when the start channel is closed, stops when the stop channel
+// is closed, and closes done once it will no longer deliver signals to k.
+func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, stop, done chan struct{}) {
+ // Build a select case.
+ sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(start)}}
+ for _, sigchan := range sigchans {
+ sc = append(sc, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sigchan)})
+ }
+
+ started := false
+ for {
+ // Wait for a notification.
+ index, _, ok := reflect.Select(sc)
+
+ // Was it the start / stop channel?
+ if index == 0 {
+ if !ok {
+ if !started {
+ // start channel; start forwarding and
+ // swap this case for the stop channel
+ // to select stop requests.
+ started = true
+ sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}
+ } else {
+ // stop channel; stop forwarding and
+ // clear this case so it is never
+ // selected again.
+ started = false
+ close(done)
+ sc[0].Chan = reflect.Value{}
+ }
+ }
+ continue
+ }
+
+ // How about a different close?
+ if !ok {
+ panic("signal channel closed unexpectedly")
+ }
+
+ // Otherwise, it was a signal on channel N. Index 0 represents the stop
+ // channel, so index N represents the channel for signal N.
+ signal := linux.Signal(index)
+
+ if !started {
+ // Kernel cannot receive signals, either because it is
+ // not ready yet or is shutting down.
+ //
+ // Kill ourselves if this signal would have killed the
+ // process before PrepareForwarding was called. i.e., all
+ // _SigKill signals; see Go
+ // src/runtime/sigtab_linux_generic.go.
+ //
+ // Otherwise ignore the signal.
+ //
+ // TODO(b/114489875): Drop in Go 1.12, which uses tgkill
+ // in runtime.raise.
+ switch signal {
+ case linux.SIGHUP, linux.SIGINT, linux.SIGTERM:
+ dieFromSignal(signal)
+ panic(fmt.Sprintf("Failed to die from signal %d", signal))
+ default:
+ continue
+ }
+ }
+
+ // Pass the signal to the handler.
+ handler(signal)
+ }
+}
+
+// PrepareHandler ensures that synchronous signals are passed to the given
+// handler function and returns a callback that starts signal delivery, which
+// itself returns a callback that stops signal handling.
+//
+// Note that this function permanently takes over signal handling. After the
+// stop callback, signals revert to the default Go runtime behavior, which
+// cannot be overridden with external calls to signal.Notify.
+func PrepareHandler(handler func(linux.Signal)) func() func() {
+ start := make(chan struct{})
+ stop := make(chan struct{})
+ done := make(chan struct{})
+
+ // Register individual channels. One channel per standard signal is
+ // required as os.Notify() is non-blocking and may drop signals. To avoid
+ // this, standard signals have to be queued separately. Channel size 1 is
+ // enough for standard signals as their semantics allow de-duplication.
+ //
+ // External real-time signals are not supported. We rely on the go-runtime
+ // for their handling.
+ var sigchans []chan os.Signal
+ for sig := 1; sig <= numSignals+1; sig++ {
+ sigchan := make(chan os.Signal, 1)
+ sigchans = append(sigchans, sigchan)
+ signal.Notify(sigchan, syscall.Signal(sig))
+ }
+ // Start up our listener.
+ go handleSignals(sigchans, handler, start, stop, done) // S/R-SAFE: synchronized by Kernel.extMu.
+
+ return func() func() {
+ close(start)
+ return func() {
+ close(stop)
+ <-done
+ }
+ }
+}
diff --git a/pkg/sentry/sighandling/sighandling_state_autogen.go b/pkg/sentry/sighandling/sighandling_state_autogen.go
new file mode 100755
index 000000000..dad4bdda2
--- /dev/null
+++ b/pkg/sentry/sighandling/sighandling_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package sighandling
+
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
new file mode 100644
index 000000000..aca77888a
--- /dev/null
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -0,0 +1,74 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sighandling
+
+import (
+ "fmt"
+ "runtime"
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// TODO(b/34161764): Move to pkg/abi/linux along with definitions in
+// pkg/sentry/arch.
+type sigaction struct {
+ handler uintptr
+ flags uint64
+ restorer uintptr
+ mask uint64
+}
+
+// IgnoreChildStop sets the SA_NOCLDSTOP flag, causing child processes to not
+// generate SIGCHLD when they stop.
+func IgnoreChildStop() error {
+ var sa sigaction
+
+ // Get the existing signal handler information, and set the flag.
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(syscall.SIGCHLD), 0, uintptr(unsafe.Pointer(&sa)), linux.SignalSetSize, 0, 0); e != 0 {
+ return e
+ }
+ sa.flags |= linux.SA_NOCLDSTOP
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(syscall.SIGCHLD), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 {
+ return e
+ }
+
+ return nil
+}
+
+// dieFromSignal kills the current process with sig.
+//
+// Preconditions: The default action of sig is termination.
+func dieFromSignal(sig linux.Signal) {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ sa := sigaction{handler: linux.SIG_DFL}
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 {
+ panic(fmt.Sprintf("rt_sigaction failed: %v", e))
+ }
+
+ set := linux.MakeSignalSet(sig)
+ if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_UNBLOCK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0); e != 0 {
+ panic(fmt.Sprintf("rt_sigprocmask failed: %v", e))
+ }
+
+ if err := syscall.Tgkill(syscall.Getpid(), syscall.Gettid(), syscall.Signal(sig)); err != nil {
+ panic(fmt.Sprintf("tgkill failed: %v", err))
+ }
+
+ panic("failed to die")
+}
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
new file mode 100644
index 000000000..c0238691d
--- /dev/null
+++ b/pkg/sentry/socket/control/control.go
@@ -0,0 +1,421 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package control provides internal representations of socket control
+// messages.
+package control
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const maxInt = int(^uint(0) >> 1)
+
+// SCMCredentials represents a SCM_CREDENTIALS socket control message.
+type SCMCredentials interface {
+ transport.CredentialsControlMessage
+
+ // Credentials returns properly namespaced values for the sender's pid, uid
+ // and gid.
+ Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID)
+}
+
+// SCMRights represents a SCM_RIGHTS socket control message.
+type SCMRights interface {
+ transport.RightsControlMessage
+
+ // Files returns up to max RightsFiles.
+ //
+ // Returned files are consumed and ownership is transferred to the caller.
+ // Subsequent calls to Files will return the next files.
+ Files(ctx context.Context, max int) (rf RightsFiles, truncated bool)
+}
+
+// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
+// maintained for each fs.File and is release either when an FD is created or
+// when the Release method is called.
+//
+// +stateify savable
+type RightsFiles []*fs.File
+
+// NewSCMRights creates a new SCM_RIGHTS socket control message representation
+// using local sentry FDs.
+func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
+ files := make(RightsFiles, 0, len(fds))
+ for _, fd := range fds {
+ file, _ := t.FDMap().GetDescriptor(kdefs.FD(fd))
+ if file == nil {
+ files.Release()
+ return nil, syserror.EBADF
+ }
+ files = append(files, file)
+ }
+ return &files, nil
+}
+
+// Files implements SCMRights.Files.
+func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) {
+ n := max
+ var trunc bool
+ if l := len(*fs); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+ rf := (*fs)[:n]
+ *fs = (*fs)[n:]
+ return rf, trunc
+}
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (fs *RightsFiles) Clone() transport.RightsControlMessage {
+ nfs := append(RightsFiles(nil), *fs...)
+ for _, nf := range nfs {
+ nf.IncRef()
+ }
+ return &nfs
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (fs *RightsFiles) Release() {
+ for _, f := range *fs {
+ f.DecRef()
+ }
+ *fs = nil
+}
+
+// rightsFDs gets up to the specified maximum number of FDs.
+func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32, bool) {
+ files, trunc := rights.Files(t, max)
+ fds := make([]int32, 0, len(files))
+ for i := 0; i < max && len(files) > 0; i++ {
+ fd, err := t.FDMap().NewFDFrom(0, files[0], kernel.FDFlags{cloexec}, t.ThreadGroup().Limits())
+ files[0].DecRef()
+ files = files[1:]
+ if err != nil {
+ t.Warningf("Error inserting FD: %v", err)
+ // This is what Linux does.
+ break
+ }
+
+ fds = append(fds, int32(fd))
+ }
+ return fds, trunc
+}
+
+// PackRights packs as many FDs as will fit into the unused capacity of buf.
+func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flags int) ([]byte, int) {
+ maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4
+ // Linux does not return any FDs if none fit.
+ if maxFDs <= 0 {
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+ fds, trunc := rightsFDs(t, rights, cloexec, maxFDs)
+ if trunc {
+ flags |= linux.MSG_CTRUNC
+ }
+ align := t.Arch().Width()
+ return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
+}
+
+// scmCredentials represents an SCM_CREDENTIALS socket control message.
+//
+// +stateify savable
+type scmCredentials struct {
+ t *kernel.Task
+ kuid auth.KUID
+ kgid auth.KGID
+}
+
+// NewSCMCredentials creates a new SCM_CREDENTIALS socket control message
+// representation.
+func NewSCMCredentials(t *kernel.Task, cred linux.ControlMessageCredentials) (SCMCredentials, error) {
+ tcred := t.Credentials()
+ kuid, err := tcred.UseUID(auth.UID(cred.UID))
+ if err != nil {
+ return nil, err
+ }
+ kgid, err := tcred.UseGID(auth.GID(cred.GID))
+ if err != nil {
+ return nil, err
+ }
+ if kernel.ThreadID(cred.PID) != t.ThreadGroup().ID() && !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.PIDNamespace().UserNamespace()) {
+ return nil, syserror.EPERM
+ }
+ return &scmCredentials{t, kuid, kgid}, nil
+}
+
+// Equals implements transport.CredentialsControlMessage.Equals.
+func (c *scmCredentials) Equals(oc transport.CredentialsControlMessage) bool {
+ if oc, _ := oc.(*scmCredentials); oc != nil && *c == *oc {
+ return true
+ }
+ return false
+}
+
+func putUint64(buf []byte, n uint64) []byte {
+ usermem.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n)
+ return buf[:len(buf)+8]
+}
+
+func putUint32(buf []byte, n uint32) []byte {
+ usermem.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n)
+ return buf[:len(buf)+4]
+}
+
+// putCmsg writes a control message header and as much data as will fit into
+// the unused capacity of a buffer.
+func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
+ space := AlignDown(cap(buf)-len(buf), 4)
+
+ // We can't write to space that doesn't exist, so if we are going to align
+ // the available space, we must align down.
+ //
+ // align must be >= 4 and each data int32 is 4 bytes. The length of the
+ // header is already aligned, so if we align to the with of the data there
+ // are two cases:
+ // 1. The aligned length is less than the length of the header. The
+ // unaligned length was also less than the length of the header, so we
+ // can't write anything.
+ // 2. The aligned length is greater than or equal to the length of the
+ // header. We can write the header plus zero or more datas. We can't write
+ // a partial int32, so the length of the message will be
+ // min(aligned length, header + datas).
+ if space < linux.SizeOfControlMessageHeader {
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+
+ length := 4*len(data) + linux.SizeOfControlMessageHeader
+ if length > space {
+ length = space
+ }
+ buf = putUint64(buf, uint64(length))
+ buf = putUint32(buf, linux.SOL_SOCKET)
+ buf = putUint32(buf, msgType)
+ for _, d := range data {
+ if len(buf)+4 > cap(buf) {
+ flags |= linux.MSG_CTRUNC
+ break
+ }
+ buf = putUint32(buf, uint32(d))
+ }
+ return alignSlice(buf, align), flags
+}
+
+func putCmsgStruct(buf []byte, msgType uint32, align uint, data interface{}) []byte {
+ if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader {
+ return buf
+ }
+ ob := buf
+
+ buf = putUint64(buf, uint64(linux.SizeOfControlMessageHeader))
+ buf = putUint32(buf, linux.SOL_SOCKET)
+ buf = putUint32(buf, msgType)
+
+ hdrBuf := buf
+
+ buf = binary.Marshal(buf, usermem.ByteOrder, data)
+
+ // Check if we went over.
+ if cap(buf) != cap(ob) {
+ return hdrBuf
+ }
+
+ // Fix up length.
+ putUint64(ob, uint64(len(buf)-len(ob)))
+
+ return alignSlice(buf, align)
+}
+
+// Credentials implements SCMCredentials.Credentials.
+func (c *scmCredentials) Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ // "When a process's user and group IDs are passed over a UNIX domain
+ // socket to a process in a different user namespace (see the description
+ // of SCM_CREDENTIALS in unix(7)), they are translated into the
+ // corresponding values as per the receiving process's user and group ID
+ // mappings." - user_namespaces(7)
+ pid := t.PIDNamespace().IDOfTask(c.t)
+ uid := c.kuid.In(t.UserNamespace()).OrOverflow()
+ gid := c.kgid.In(t.UserNamespace()).OrOverflow()
+
+ return pid, uid, gid
+}
+
+// PackCredentials packs the credentials in the control message (or default
+// credentials if none) into a buffer.
+func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int) ([]byte, int) {
+ align := t.Arch().Width()
+
+ // Default credentials if none are available.
+ pid := kernel.ThreadID(0)
+ uid := auth.UID(auth.NobodyKUID)
+ gid := auth.GID(auth.NobodyKGID)
+
+ if creds != nil {
+ pid, uid, gid = creds.Credentials(t)
+ }
+ c := []int32{int32(pid), int32(uid), int32(gid)}
+ return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
+}
+
+// AlignUp rounds a length up to an alignment. align must be a power of 2.
+func AlignUp(length int, align uint) int {
+ return (length + int(align) - 1) & ^(int(align) - 1)
+}
+
+// AlignDown rounds a down to an alignment. align must be a power of 2.
+func AlignDown(length int, align uint) int {
+ return length & ^(int(align) - 1)
+}
+
+// alignSlice extends a slice's length (up to the capacity) to align it.
+func alignSlice(buf []byte, align uint) []byte {
+ aligned := AlignUp(len(buf), align)
+ if aligned > cap(buf) {
+ // Linux allows unaligned data if there isn't room for alignment.
+ // Since there isn't room for alignment, there isn't room for any
+ // additional messages either.
+ return buf
+ }
+ return buf[:aligned]
+}
+
+// PackTimestamp packs a SO_TIMESTAMP socket control message.
+func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SO_TIMESTAMP,
+ t.Arch().Width(),
+ linux.NsecToTimeval(timestamp),
+ )
+}
+
+// Parse parses a raw socket control message into portable objects.
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) {
+ var (
+ fds linux.ControlMessageRights
+
+ haveCreds bool
+ creds linux.ControlMessageCredentials
+ )
+
+ for i := 0; i < len(buf); {
+ if i+linux.SizeOfControlMessageHeader > len(buf) {
+ return transport.ControlMessages{}, syserror.EINVAL
+ }
+
+ var h linux.ControlMessageHeader
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
+
+ if h.Length < uint64(linux.SizeOfControlMessageHeader) {
+ return transport.ControlMessages{}, syserror.EINVAL
+ }
+ if h.Length > uint64(len(buf)-i) {
+ return transport.ControlMessages{}, syserror.EINVAL
+ }
+ if h.Level != linux.SOL_SOCKET {
+ return transport.ControlMessages{}, syserror.EINVAL
+ }
+
+ i += linux.SizeOfControlMessageHeader
+ length := int(h.Length) - linux.SizeOfControlMessageHeader
+
+ // The use of t.Arch().Width() is analogous to Linux's use of
+ // sizeof(long) in CMSG_ALIGN.
+ width := t.Arch().Width()
+
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+
+ if len(fds)+numRights > linux.SCM_MAX_FD {
+ return transport.ControlMessages{}, syserror.EINVAL
+ }
+
+ for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
+ fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ }
+
+ i += AlignUp(length, width)
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ return transport.ControlMessages{}, syserror.EINVAL
+ }
+
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ haveCreds = true
+ i += AlignUp(length, width)
+
+ default:
+ // Unknown message type.
+ return transport.ControlMessages{}, syserror.EINVAL
+ }
+ }
+
+ var credentials SCMCredentials
+ if haveCreds {
+ var err error
+ if credentials, err = NewSCMCredentials(t, creds); err != nil {
+ return transport.ControlMessages{}, err
+ }
+ } else {
+ credentials = makeCreds(t, socketOrEndpoint)
+ }
+
+ var rights SCMRights
+ if len(fds) > 0 {
+ var err error
+ if rights, err = NewSCMRights(t, fds); err != nil {
+ return transport.ControlMessages{}, err
+ }
+ }
+
+ if credentials == nil && rights == nil {
+ return transport.ControlMessages{}, nil
+ }
+
+ return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil
+}
+
+func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials {
+ if t == nil || socketOrEndpoint == nil {
+ return nil
+ }
+ if cr, ok := socketOrEndpoint.(transport.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) {
+ tcred := t.Credentials()
+ return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID}
+ }
+ return nil
+}
+
+// New creates default control messages if needed.
+func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transport.ControlMessages {
+ return transport.ControlMessages{
+ Credentials: makeCreds(t, socketOrEndpoint),
+ Rights: rights,
+ }
+}
diff --git a/pkg/sentry/socket/control/control_state_autogen.go b/pkg/sentry/socket/control/control_state_autogen.go
new file mode 100755
index 000000000..4554692a7
--- /dev/null
+++ b/pkg/sentry/socket/control/control_state_autogen.go
@@ -0,0 +1,36 @@
+// automatically generated by stateify.
+
+package control
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+func (x *RightsFiles) save(m state.Map) {
+ m.SaveValue("", ([]*fs.File)(*x))
+}
+
+func (x *RightsFiles) load(m state.Map) {
+ m.LoadValue("", new([]*fs.File), func(y interface{}) { *x = (RightsFiles)(y.([]*fs.File)) })
+}
+
+func (x *scmCredentials) beforeSave() {}
+func (x *scmCredentials) save(m state.Map) {
+ x.beforeSave()
+ m.Save("t", &x.t)
+ m.Save("kuid", &x.kuid)
+ m.Save("kgid", &x.kgid)
+}
+
+func (x *scmCredentials) afterLoad() {}
+func (x *scmCredentials) load(m state.Map) {
+ m.Load("t", &x.t)
+ m.Load("kuid", &x.kuid)
+ m.Load("kgid", &x.kgid)
+}
+
+func init() {
+ state.Register("control.RightsFiles", (*RightsFiles)(nil), state.Fns{Save: (*RightsFiles).save, Load: (*RightsFiles).load})
+ state.Register("control.scmCredentials", (*scmCredentials)(nil), state.Fns{Save: (*scmCredentials).save, Load: (*scmCredentials).load})
+}
diff --git a/pkg/sentry/socket/epsocket/device.go b/pkg/sentry/socket/epsocket/device.go
new file mode 100644
index 000000000..ab4083efe
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epsocket
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// epsocketDevice is the endpoint socket virtual device.
+var epsocketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
new file mode 100644
index 000000000..de4b963da
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -0,0 +1,2283 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package epsocket provides an implementation of the socket.Socket interface
+// that is backed by a tcpip.Endpoint.
+//
+// It does not depend on any particular endpoint implementation, and thus can
+// be used to expose certain endpoints to the sentry while leaving others out,
+// for example, TCP endpoints and Unix-domain endpoints.
+//
+// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside
+// tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during
+// this operation.
+package epsocket
+
+import (
+ "bytes"
+ "math"
+ "sync"
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/metric"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+func mustCreateMetric(name, description string) *tcpip.StatCounter {
+ var cm tcpip.StatCounter
+ metric.MustRegisterCustomUint64Metric(name, false /* sync */, description, cm.Value)
+ return &cm
+}
+
+// Metrics contains metrics exported by netstack.
+var Metrics = tcpip.Stats{
+ UnknownProtocolRcvdPackets: mustCreateMetric("/netstack/unknown_protocol_received_packets", "Number of packets received by netstack that were for an unknown or unsupported protocol."),
+ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
+ DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
+ ICMP: tcpip.ICMPStats{
+ V4PacketsSent: tcpip.ICMPv4SentPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ },
+ V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
+ },
+ V6PacketsSent: tcpip.ICMPv6SentPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ },
+ V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
+ },
+ },
+ IP: tcpip.IPStats{
+ PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
+ InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
+ PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
+ PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
+ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
+ },
+ TCP: tcpip.TCPStats{
+ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
+ PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
+ ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."),
+ ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."),
+ ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."),
+ ListenOverflowSynCookieRcvd: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_rcvd", "Number of times a SYN cookie was received."),
+ ListenOverflowInvalidSynCookieRcvd: mustCreateMetric("/netstack/tcp/listen_overflow_invalid_syn_cookie_rcvd", "Number of times an invalid SYN cookie was received."),
+ FailedConnectionAttempts: mustCreateMetric("/netstack/tcp/failed_connection_attempts", "Number of calls to Connect or Listen (active and passive openings, respectively) that end in an error."),
+ ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."),
+ InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."),
+ SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."),
+ ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."),
+ ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."),
+ Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."),
+ FastRecovery: mustCreateMetric("/netstack/tcp/fast_recovery", "Number of times fast recovery was used to recover from packet loss."),
+ SACKRecovery: mustCreateMetric("/netstack/tcp/sack_recovery", "Number of times SACK recovery was used to recover from packet loss."),
+ SlowStartRetransmits: mustCreateMetric("/netstack/tcp/slow_start_retransmits", "Number of segments retransmitted in slow start mode."),
+ FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."),
+ Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."),
+ ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
+ },
+ UDP: tcpip.UDPStats{
+ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
+ UnknownPortErrors: mustCreateMetric("/netstack/udp/unknown_port_errors", "Number of incoming UDP datagrams dropped because they did not have a known destination port."),
+ ReceiveBufferErrors: mustCreateMetric("/netstack/udp/receive_buffer_errors", "Number of incoming UDP datagrams dropped due to the receiving buffer being in an invalid state."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."),
+ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent via sendUDP."),
+ },
+}
+
+const sizeOfInt32 int = 4
+
+var errStackType = syserr.New("expected but did not receive an epsocket.Stack", linux.EINVAL)
+
+// ntohs converts a 16-bit number from network byte order to host byte order. It
+// assumes that the host is little endian.
+func ntohs(v uint16) uint16 {
+ return v<<8 | v>>8
+}
+
+// htons converts a 16-bit number from host byte order to network byte order. It
+// assumes that the host is little endian.
+func htons(v uint16) uint16 {
+ return ntohs(v)
+}
+
+// commonEndpoint represents the intersection of a tcpip.Endpoint and a
+// transport.Endpoint.
+type commonEndpoint interface {
+ // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress and
+ // transport.Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress and
+ // transport.Endpoint.GetRemoteAddress.
+ GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Readiness implements tcpip.Endpoint.Readiness and
+ // transport.Endpoint.Readiness.
+ Readiness(mask waiter.EventMask) waiter.EventMask
+
+ // SetSockOpt implements tcpip.Endpoint.SetSockOpt and
+ // transport.Endpoint.SetSockOpt.
+ SetSockOpt(interface{}) *tcpip.Error
+
+ // GetSockOpt implements tcpip.Endpoint.GetSockOpt and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOpt(interface{}) *tcpip.Error
+}
+
+// SocketOperations encapsulates all the state needed to represent a network stack
+// endpoint in the kernel context.
+//
+// +stateify savable
+type SocketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ socket.SendReceiveTimeout
+ *waiter.Queue
+
+ family int
+ Endpoint tcpip.Endpoint
+ skType transport.SockType
+
+ // readMu protects access to the below fields.
+ readMu sync.Mutex `state:"nosave"`
+ // readView contains the remaining payload from the last packet.
+ readView buffer.View
+ // readCM holds control message information for the last packet read
+ // from Endpoint.
+ readCM tcpip.ControlMessages
+ sender tcpip.FullAddress
+
+ // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
+ // of returned messages can be returned via control messages. When
+ // false, the same timestamp is instead stored and can be read via the
+ // SIOCGSTAMP ioctl. It is protected by readMu. See socket(7).
+ sockOptTimestamp bool
+ // timestampValid indicates whether timestamp for SIOCGSTAMP has been
+ // set. It is protected by readMu.
+ timestampValid bool
+ // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only
+ // valid when timestampValid is true. It is protected by readMu.
+ timestampNS int64
+}
+
+// New creates a new endpoint socket.
+func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
+ if skType == transport.SockStream {
+ if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ }
+
+ dirent := socket.NewDirent(t, epsocketDevice)
+ defer dirent.DecRef()
+ return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true}, &SocketOperations{
+ Queue: queue,
+ family: family,
+ Endpoint: endpoint,
+ skType: skType,
+ }), nil
+}
+
+var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
+var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
+
+// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
+// netstack representation taking any addresses into account.
+func bytesToIPAddress(addr []byte) tcpip.Address {
+ if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
+ return ""
+ }
+ return tcpip.Address(addr)
+}
+
+// GetAddress reads an sockaddr struct from the given address and converts it
+// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
+// addresses.
+func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
+ // Make sure we have at least 2 bytes for the address family.
+ if len(addr) < 2 {
+ return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ }
+
+ family := usermem.ByteOrder.Uint16(addr)
+ if family != uint16(sfamily) {
+ return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ }
+
+ // Get the rest of the fields based on the address family.
+ switch family {
+ case linux.AF_UNIX:
+ path := addr[2:]
+ if len(path) > linux.UnixPathMax {
+ return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ }
+ // Drop the terminating NUL (if one exists) and everything after
+ // it for filesystem (non-abstract) addresses.
+ if len(path) > 0 && path[0] != 0 {
+ if n := bytes.IndexByte(path[1:], 0); n >= 0 {
+ path = path[:n+1]
+ }
+ }
+ return tcpip.FullAddress{
+ Addr: tcpip.Address(path),
+ }, nil
+
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if len(addr) < sockAddrInetSize {
+ return tcpip.FullAddress{}, syserr.ErrBadAddress
+ }
+ binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: bytesToIPAddress(a.Addr[:]),
+ Port: ntohs(a.Port),
+ }
+ return out, nil
+
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if len(addr) < sockAddrInet6Size {
+ return tcpip.FullAddress{}, syserr.ErrBadAddress
+ }
+ binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: bytesToIPAddress(a.Addr[:]),
+ Port: ntohs(a.Port),
+ }
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(a.Scope_id)
+ }
+ return out, nil
+
+ default:
+ return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ }
+}
+
+func (s *SocketOperations) isPacketBased() bool {
+ return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
+}
+
+// fetchReadView updates the readView field of the socket if it's currently
+// empty. It assumes that the socket is locked.
+func (s *SocketOperations) fetchReadView() *syserr.Error {
+ if len(s.readView) > 0 {
+ return nil
+ }
+
+ s.readView = nil
+ s.sender = tcpip.FullAddress{}
+
+ v, cms, err := s.Endpoint.Read(&s.sender)
+ if err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+
+ s.readView = v
+ s.readCM = cms
+
+ return nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *SocketOperations) Release() {
+ s.Endpoint.Close()
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
+ if err == syserr.ErrWouldBlock {
+ return int64(n), syserror.ErrWouldBlock
+ }
+ if err != nil {
+ return 0, err.ToError()
+ }
+ return int64(n), nil
+}
+
+// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand
+// based on the requested size.
+type ioSequencePayload struct {
+ ctx context.Context
+ src usermem.IOSequence
+}
+
+// Get implements tcpip.Payload.
+func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
+ if size > i.Size() {
+ size = i.Size()
+ }
+ v := buffer.NewView(size)
+ if _, err := i.src.CopyIn(i.ctx, v); err != nil {
+ return nil, tcpip.ErrBadAddress
+ }
+ return v, nil
+}
+
+// Size implements tcpip.Payload.
+func (i *ioSequencePayload) Size() int {
+ return int(i.src.NumBytes())
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ f := &ioSequencePayload{ctx: ctx, src: src}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ t := ctx.(*kernel.Task)
+ if err := t.Block(resCh); err != nil {
+ return 0, syserr.FromError(err).ToError()
+ }
+
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
+ }
+
+ if err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if int64(n) < src.NumBytes() {
+ return int64(n), syserror.ErrWouldBlock
+ }
+
+ return int64(n), nil
+}
+
+// Readiness returns a mask of ready events for socket s.
+func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ r := s.Endpoint.Readiness(mask)
+
+ // Check our cached value iff the caller asked for readability and the
+ // endpoint itself is currently not readable.
+ if (mask & ^r & waiter.EventIn) != 0 {
+ s.readMu.Lock()
+ if len(s.readView) > 0 {
+ r |= waiter.EventIn
+ }
+ s.readMu.Unlock()
+ }
+
+ return r
+}
+
+// Connect implements the linux syscall connect(2) for sockets backed by
+// tpcip.Endpoint.
+func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ addr, err := GetAddress(s.family, sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // Always return right away in the non-blocking case.
+ if !blocking {
+ return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
+ }
+
+ // Register for notification when the endpoint becomes writable, then
+ // initiate the connection.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ return syserr.TranslateNetstackError(err)
+ }
+
+ // It's pending, so we have to wait for a notification, and fetch the
+ // result once the wait completes.
+ if err := t.Block(ch); err != nil {
+ return syserr.FromError(err)
+ }
+
+ // Call Connect() again after blocking to find connect's result.
+ return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
+}
+
+// Bind implements the linux syscall bind(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ addr, err := GetAddress(s.family, sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // Issue the bind request to the endpoint.
+ return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
+}
+
+// Listen implements the linux syscall listen(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog))
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ // Try to accept the connection again; if it fails, then wait until we
+ // get a notification.
+ for {
+ if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock {
+ return ep, wq, syserr.TranslateNetstackError(err)
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, wq, terr := s.Endpoint.Accept()
+ if terr != nil {
+ if terr != tcpip.ErrWouldBlock || !blocking {
+ return 0, nil, 0, syserr.TranslateNetstackError(terr)
+ }
+
+ var err *syserr.Error
+ ep, wq, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns, err := New(t, s.family, s.skType, wq, ep)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ flags := ns.Flags()
+ flags.NonBlocking = true
+ ns.SetFlags(flags.Settable())
+ }
+
+ var addr interface{}
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer and write it to peer slice.
+ var err *syserr.Error
+ addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fdFlags := kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ }
+ fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+
+ t.Kernel().RecordSocket(ns, s.family)
+
+ return fd, addr, addrLen, syserr.FromError(e)
+}
+
+// ConvertShutdown converts Linux shutdown flags into tcpip shutdown flags.
+func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) {
+ var f tcpip.ShutdownFlags
+ switch how {
+ case linux.SHUT_RD:
+ f = tcpip.ShutdownRead
+ case linux.SHUT_WR:
+ f = tcpip.ShutdownWrite
+ case linux.SHUT_RDWR:
+ f = tcpip.ShutdownRead | tcpip.ShutdownWrite
+ default:
+ return 0, syserr.ErrInvalidArgument
+ }
+ return f, nil
+}
+
+// Shutdown implements the linux syscall shutdown(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ f, err := ConvertShutdown(how)
+ if err != nil {
+ return err
+ }
+
+ // Issue shutdown request.
+ return syserr.TranslateNetstackError(s.Endpoint.Shutdown(f))
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for epsocket.SocketOperations rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptTimestamp {
+ val = 1
+ }
+ return val, nil
+ }
+
+ return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
+}
+
+// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
+// sockets backed by a commonEndpoint.
+func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+ switch level {
+ case linux.SOL_SOCKET:
+ return getSockOptSocket(t, s, ep, family, skType, name, outLen)
+
+ case linux.SOL_TCP:
+ return getSockOptTCP(t, ep, name, outLen)
+
+ case linux.SOL_IPV6:
+ return getSockOptIPv6(t, ep, name, outLen)
+
+ case linux.SOL_IP:
+ return getSockOptIP(t, ep, name, outLen)
+
+ case linux.SOL_UDP,
+ linux.SOL_ICMPV6,
+ linux.SOL_RAW,
+ linux.SOL_PACKET:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
+func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, name, outLen int) (interface{}, *syserr.Error) {
+ // TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
+ switch name {
+ case linux.SO_TYPE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return int32(skType), nil
+
+ case linux.SO_ERROR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Get the last error and convert it.
+ err := ep.GetSockOpt(tcpip.ErrorOption{})
+ if err == nil {
+ return int32(0), nil
+ }
+ return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ case linux.SO_PEERCRED:
+ if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ tcred := t.Credentials()
+ return syscall.Ucred{
+ Pid: int32(t.ThreadGroup().ID()),
+ Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }, nil
+
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.PasscredOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.SO_SNDBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var size tcpip.SendBufferSizeOption
+ if err := ep.GetSockOpt(&size); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ if size > math.MaxInt32 {
+ size = math.MaxInt32
+ }
+
+ return int32(size), nil
+
+ case linux.SO_RCVBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var size tcpip.ReceiveBufferSizeOption
+ if err := ep.GetSockOpt(&size); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ if size > math.MaxInt32 {
+ size = math.MaxInt32
+ }
+
+ return int32(size), nil
+
+ case linux.SO_REUSEADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.ReuseAddressOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.SO_REUSEPORT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.ReusePortOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.SO_BROADCAST:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.BroadcastOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.SO_KEEPALIVE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.KeepaliveEnabledOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.SO_LINGER:
+ if outLen < linux.SizeOfLinger {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return linux.Linger{}, nil
+
+ case linux.SO_SNDTIMEO:
+ // TODO(igudger): Linux allows shorter lengths for partial results.
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.SendTimeout()), nil
+
+ case linux.SO_RCVTIMEO:
+ // TODO(igudger): Linux allows shorter lengths for partial results.
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.RecvTimeout()), nil
+
+ case linux.SO_OOBINLINE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OutOfBandInlineOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ default:
+ socket.GetSockOptEmitUnimplementedEvent(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+ switch name {
+ case linux.TCP_NODELAY:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.DelayOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ if v == 0 {
+ return int32(1), nil
+ }
+ return int32(0), nil
+
+ case linux.TCP_CORK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.CorkOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.TCP_QUICKACK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.QuickAckOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.TCP_KEEPIDLE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.KeepaliveIdleOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_KEEPINTVL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.KeepaliveIntervalOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_INFO:
+ var v tcpip.TCPInfoOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // TODO(b/64800844): Translate fields once they are added to
+ // tcpip.TCPInfoOption.
+ info := linux.TCPInfo{}
+
+ // Linux truncates the output binary to outLen.
+ ib := binary.Marshal(nil, usermem.ByteOrder, &info)
+ if len(ib) > outLen {
+ ib = ib[:outLen]
+ }
+
+ return ib, nil
+
+ case linux.TCP_CC_INFO,
+ linux.TCP_NOTSENT_LOWAT,
+ linux.TCP_ZEROCOPY_RECEIVE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEventTCP(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+ switch name {
+ case linux.IPV6_V6ONLY:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.V6OnlyOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.IPV6_PATHMTU:
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEventIPv6(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// getSockOptIP implements GetSockOpt when level is SOL_IP.
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+ switch name {
+ case linux.IP_MULTICAST_TTL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.MulticastTTLOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.IP_MULTICAST_IF:
+ if outLen < len(linux.InetAddr{}) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.MulticastInterfaceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
+
+ return a.(linux.SockAddrInet).Addr, nil
+
+ case linux.IP_MULTICAST_LOOP:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.MulticastLoopOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ if v {
+ return int32(1), nil
+ }
+ return int32(0), nil
+
+ default:
+ emitUnimplementedEventIP(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for epsocket.SocketOperations rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+
+ return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
+}
+
+// SetSockOpt can be used to implement the linux syscall setsockopt(2) for
+// sockets backed by a commonEndpoint.
+func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
+ switch level {
+ case linux.SOL_SOCKET:
+ return setSockOptSocket(t, s, ep, name, optVal)
+
+ case linux.SOL_TCP:
+ return setSockOptTCP(t, ep, name, optVal)
+
+ case linux.SOL_IPV6:
+ return setSockOptIPv6(t, ep, name, optVal)
+
+ case linux.SOL_IP:
+ return setSockOptIP(t, ep, name, optVal)
+
+ case linux.SOL_UDP,
+ linux.SOL_ICMPV6,
+ linux.SOL_RAW,
+ linux.SOL_PACKET:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// setSockOptSocket implements SetSockOpt when level is SOL_SOCKET.
+func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.SO_SNDBUF:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v)))
+
+ case linux.SO_RCVBUF:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v)))
+
+ case linux.SO_REUSEADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v)))
+
+ case linux.SO_REUSEPORT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+
+ case linux.SO_BROADCAST:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v)))
+
+ case linux.SO_PASSCRED:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v)))
+
+ case linux.SO_KEEPALIVE:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v)))
+
+ case linux.SO_SNDTIMEO:
+ if len(optVal) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetSendTimeout(v.ToNsecCapped())
+ return nil
+
+ case linux.SO_RCVTIMEO:
+ if len(optVal) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetRecvTimeout(v.ToNsecCapped())
+ return nil
+
+ case linux.SO_OOBINLINE:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ if v == 0 {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v)))
+
+ case linux.SO_LINGER:
+ if len(optVal) < linux.SizeOfLinger {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Linger
+ binary.Unmarshal(optVal[:linux.SizeOfLinger], usermem.ByteOrder, &v)
+
+ if v != (linux.Linger{}) {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ return nil
+
+ default:
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// setSockOptTCP implements SetSockOpt when level is SOL_TCP.
+func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.TCP_NODELAY:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ var o tcpip.DelayOption
+ if v == 0 {
+ o = 1
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(o))
+
+ case linux.TCP_CORK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v)))
+
+ case linux.TCP_QUICKACK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v)))
+
+ case linux.TCP_KEEPIDLE:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPIDLE {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIdleOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_KEEPINTVL:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPINTVL {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_REPAIR_OPTIONS:
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEventTCP(t, name)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
+func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.IPV6_V6ONLY:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v)))
+
+ case linux.IPV6_ADD_MEMBERSHIP,
+ linux.IPV6_DROP_MEMBERSHIP,
+ linux.IPV6_IPSEC_POLICY,
+ linux.IPV6_JOIN_ANYCAST,
+ linux.IPV6_LEAVE_ANYCAST,
+ linux.IPV6_PKTINFO,
+ linux.IPV6_ROUTER_ALERT,
+ linux.IPV6_XFRM_POLICY,
+ linux.MCAST_BLOCK_SOURCE,
+ linux.MCAST_JOIN_GROUP,
+ linux.MCAST_JOIN_SOURCE_GROUP,
+ linux.MCAST_LEAVE_GROUP,
+ linux.MCAST_LEAVE_SOURCE_GROUP,
+ linux.MCAST_UNBLOCK_SOURCE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEventIPv6(t, name)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+var (
+ inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{}))
+ inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{}))
+)
+
+// copyInMulticastRequest copies in a variable-size multicast request. The
+// kernel determines which structure was passed by its length. IP_MULTICAST_IF
+// supports ip_mreqn, ip_mreq and in_addr, while IP_ADD_MEMBERSHIP and
+// IP_DROP_MEMBERSHIP only support ip_mreqn and ip_mreq. To handle this,
+// allowAddr controls whether in_addr is accepted or rejected.
+func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastRequestWithNIC, *syserr.Error) {
+ if len(optVal) < len(linux.InetAddr{}) {
+ return linux.InetMulticastRequestWithNIC{}, syserr.ErrInvalidArgument
+ }
+
+ if len(optVal) < inetMulticastRequestSize {
+ if !allowAddr {
+ return linux.InetMulticastRequestWithNIC{}, syserr.ErrInvalidArgument
+ }
+
+ var req linux.InetMulticastRequestWithNIC
+ copy(req.InterfaceAddr[:], optVal)
+ return req, nil
+ }
+
+ if len(optVal) >= inetMulticastRequestWithNICSize {
+ var req linux.InetMulticastRequestWithNIC
+ binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], usermem.ByteOrder, &req)
+ return req, nil
+ }
+
+ var req linux.InetMulticastRequestWithNIC
+ binary.Unmarshal(optVal[:inetMulticastRequestSize], usermem.ByteOrder, &req.InetMulticastRequest)
+ return req, nil
+}
+
+// parseIntOrChar copies either a 32-bit int or an 8-bit uint out of buf.
+//
+// net/ipv4/ip_sockglue.c:do_ip_setsockopt does this for its socket options.
+func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
+ if len(buf) == 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ if len(buf) >= sizeOfInt32 {
+ return int32(usermem.ByteOrder.Uint32(buf)), nil
+ }
+
+ return int32(buf[0]), nil
+}
+
+// setSockOptIP implements SetSockOpt when level is SOL_IP.
+func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.IP_MULTICAST_TTL:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ if v == -1 {
+ // Linux translates -1 to 1.
+ v = 1
+ }
+ if v < 0 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v)))
+
+ case linux.IP_ADD_MEMBERSHIP:
+ req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.AddMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ // TODO(igudger): Change AddMembership to use the standard
+ // any address representation.
+ InterfaceAddr: tcpip.Address(req.InterfaceAddr[:]),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IP_DROP_MEMBERSHIP:
+ req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.RemoveMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ // TODO(igudger): Change DropMembership to use the standard
+ // any address representation.
+ InterfaceAddr: tcpip.Address(req.InterfaceAddr[:]),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IP_MULTICAST_IF:
+ req, err := copyInMulticastRequest(optVal, true /* allowAddr */)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastInterfaceOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]),
+ }))
+
+ case linux.IP_MULTICAST_LOOP:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(
+ tcpip.MulticastLoopOption(v != 0),
+ ))
+
+ case linux.MCAST_JOIN_GROUP:
+ // FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return syserr.ErrInvalidArgument
+
+ case linux.IP_ADD_SOURCE_MEMBERSHIP,
+ linux.IP_BIND_ADDRESS_NO_PORT,
+ linux.IP_BLOCK_SOURCE,
+ linux.IP_CHECKSUM,
+ linux.IP_DROP_SOURCE_MEMBERSHIP,
+ linux.IP_FREEBIND,
+ linux.IP_HDRINCL,
+ linux.IP_IPSEC_POLICY,
+ linux.IP_MINTTL,
+ linux.IP_MSFILTER,
+ linux.IP_MTU_DISCOVER,
+ linux.IP_MULTICAST_ALL,
+ linux.IP_NODEFRAG,
+ linux.IP_OPTIONS,
+ linux.IP_PASSSEC,
+ linux.IP_PKTINFO,
+ linux.IP_RECVERR,
+ linux.IP_RECVFRAGSIZE,
+ linux.IP_RECVOPTS,
+ linux.IP_RECVORIGDSTADDR,
+ linux.IP_RECVTOS,
+ linux.IP_RECVTTL,
+ linux.IP_RETOPTS,
+ linux.IP_TOS,
+ linux.IP_TRANSPARENT,
+ linux.IP_TTL,
+ linux.IP_UNBLOCK_SOURCE,
+ linux.IP_UNICAST_IF,
+ linux.IP_XFRM_POLICY,
+ linux.MCAST_BLOCK_SOURCE,
+ linux.MCAST_JOIN_SOURCE_GROUP,
+ linux.MCAST_LEAVE_GROUP,
+ linux.MCAST_LEAVE_SOURCE_GROUP,
+ linux.MCAST_MSFILTER,
+ linux.MCAST_UNBLOCK_SOURCE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// emitUnimplementedEventTCP emits unimplemented event if name is valid. This
+// function contains names that are common between Get and SetSockOpt when
+// level is SOL_TCP.
+func emitUnimplementedEventTCP(t *kernel.Task, name int) {
+ switch name {
+ case linux.TCP_CONGESTION,
+ linux.TCP_CORK,
+ linux.TCP_DEFER_ACCEPT,
+ linux.TCP_FASTOPEN,
+ linux.TCP_FASTOPEN_CONNECT,
+ linux.TCP_FASTOPEN_KEY,
+ linux.TCP_FASTOPEN_NO_COOKIE,
+ linux.TCP_INQ,
+ linux.TCP_KEEPCNT,
+ linux.TCP_KEEPIDLE,
+ linux.TCP_KEEPINTVL,
+ linux.TCP_LINGER2,
+ linux.TCP_MAXSEG,
+ linux.TCP_QUEUE_SEQ,
+ linux.TCP_QUICKACK,
+ linux.TCP_REPAIR,
+ linux.TCP_REPAIR_QUEUE,
+ linux.TCP_REPAIR_WINDOW,
+ linux.TCP_SAVED_SYN,
+ linux.TCP_SAVE_SYN,
+ linux.TCP_SYNCNT,
+ linux.TCP_THIN_DUPACK,
+ linux.TCP_THIN_LINEAR_TIMEOUTS,
+ linux.TCP_TIMESTAMP,
+ linux.TCP_ULP,
+ linux.TCP_USER_TIMEOUT,
+ linux.TCP_WINDOW_CLAMP:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// emitUnimplementedEventIPv6 emits unimplemented event if name is valid. It
+// contains names that are common between Get and SetSockOpt when level is
+// SOL_IPV6.
+func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
+ switch name {
+ case linux.IPV6_2292DSTOPTS,
+ linux.IPV6_2292HOPLIMIT,
+ linux.IPV6_2292HOPOPTS,
+ linux.IPV6_2292PKTINFO,
+ linux.IPV6_2292PKTOPTIONS,
+ linux.IPV6_2292RTHDR,
+ linux.IPV6_ADDR_PREFERENCES,
+ linux.IPV6_AUTOFLOWLABEL,
+ linux.IPV6_DONTFRAG,
+ linux.IPV6_DSTOPTS,
+ linux.IPV6_FLOWINFO,
+ linux.IPV6_FLOWINFO_SEND,
+ linux.IPV6_FLOWLABEL_MGR,
+ linux.IPV6_FREEBIND,
+ linux.IPV6_HOPOPTS,
+ linux.IPV6_MINHOPCOUNT,
+ linux.IPV6_MTU,
+ linux.IPV6_MTU_DISCOVER,
+ linux.IPV6_MULTICAST_ALL,
+ linux.IPV6_MULTICAST_HOPS,
+ linux.IPV6_MULTICAST_IF,
+ linux.IPV6_MULTICAST_LOOP,
+ linux.IPV6_RECVDSTOPTS,
+ linux.IPV6_RECVERR,
+ linux.IPV6_RECVFRAGSIZE,
+ linux.IPV6_RECVHOPLIMIT,
+ linux.IPV6_RECVHOPOPTS,
+ linux.IPV6_RECVORIGDSTADDR,
+ linux.IPV6_RECVPATHMTU,
+ linux.IPV6_RECVPKTINFO,
+ linux.IPV6_RECVRTHDR,
+ linux.IPV6_RECVTCLASS,
+ linux.IPV6_RTHDR,
+ linux.IPV6_RTHDRDSTOPTS,
+ linux.IPV6_TCLASS,
+ linux.IPV6_TRANSPARENT,
+ linux.IPV6_UNICAST_HOPS,
+ linux.IPV6_UNICAST_IF,
+ linux.MCAST_MSFILTER,
+ linux.IPV6_ADDRFORM:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// emitUnimplementedEventIP emits unimplemented event if name is valid. It
+// contains names that are common between Get and SetSockOpt when level is
+// SOL_IP.
+func emitUnimplementedEventIP(t *kernel.Task, name int) {
+ switch name {
+ case linux.IP_TOS,
+ linux.IP_TTL,
+ linux.IP_HDRINCL,
+ linux.IP_OPTIONS,
+ linux.IP_ROUTER_ALERT,
+ linux.IP_RECVOPTS,
+ linux.IP_RETOPTS,
+ linux.IP_PKTINFO,
+ linux.IP_PKTOPTIONS,
+ linux.IP_MTU_DISCOVER,
+ linux.IP_RECVERR,
+ linux.IP_RECVTTL,
+ linux.IP_RECVTOS,
+ linux.IP_MTU,
+ linux.IP_FREEBIND,
+ linux.IP_IPSEC_POLICY,
+ linux.IP_XFRM_POLICY,
+ linux.IP_PASSSEC,
+ linux.IP_TRANSPARENT,
+ linux.IP_ORIGDSTADDR,
+ linux.IP_MINTTL,
+ linux.IP_NODEFRAG,
+ linux.IP_CHECKSUM,
+ linux.IP_BIND_ADDRESS_NO_PORT,
+ linux.IP_RECVFRAGSIZE,
+ linux.IP_MULTICAST_IF,
+ linux.IP_MULTICAST_TTL,
+ linux.IP_MULTICAST_LOOP,
+ linux.IP_ADD_MEMBERSHIP,
+ linux.IP_DROP_MEMBERSHIP,
+ linux.IP_UNBLOCK_SOURCE,
+ linux.IP_BLOCK_SOURCE,
+ linux.IP_ADD_SOURCE_MEMBERSHIP,
+ linux.IP_DROP_SOURCE_MEMBERSHIP,
+ linux.IP_MSFILTER,
+ linux.MCAST_JOIN_GROUP,
+ linux.MCAST_BLOCK_SOURCE,
+ linux.MCAST_UNBLOCK_SOURCE,
+ linux.MCAST_LEAVE_GROUP,
+ linux.MCAST_JOIN_SOURCE_GROUP,
+ linux.MCAST_LEAVE_SOURCE_GROUP,
+ linux.MCAST_MSFILTER,
+ linux.IP_MULTICAST_ALL,
+ linux.IP_UNICAST_IF:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
+
+// ConvertAddress converts the given address to a native format.
+func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
+ switch family {
+ case linux.AF_UNIX:
+ var out linux.SockAddrUnix
+ out.Family = linux.AF_UNIX
+ l := len([]byte(addr.Addr))
+ for i := 0; i < l; i++ {
+ out.Path[i] = int8(addr.Addr[i])
+ }
+
+ // Linux returns the used length of the address struct (including the
+ // null terminator) for filesystem paths. The Family field is 2 bytes.
+ // It is sometimes allowed to exclude the null terminator if the
+ // address length is the max. Abstract and empty paths always return
+ // the full exact length.
+ if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
+ return out, uint32(2 + l)
+ }
+ return out, uint32(3 + l)
+ case linux.AF_INET:
+ var out linux.SockAddrInet
+ copy(out.Addr[:], addr.Addr)
+ out.Family = linux.AF_INET
+ out.Port = htons(addr.Port)
+ return out, uint32(binary.Size(out))
+ case linux.AF_INET6:
+ var out linux.SockAddrInet6
+ if len(addr.Addr) == 4 {
+ // Copy address is v4-mapped format.
+ copy(out.Addr[12:], addr.Addr)
+ out.Addr[10] = 0xff
+ out.Addr[11] = 0xff
+ } else {
+ copy(out.Addr[:], addr.Addr)
+ }
+ out.Family = linux.AF_INET6
+ out.Port = htons(addr.Port)
+ if isLinkLocal(addr.Addr) {
+ out.Scope_id = uint32(addr.NIC)
+ }
+ return out, uint32(binary.Size(out))
+ default:
+ return nil, 0
+ }
+}
+
+// GetSockName implements the linux syscall getsockname(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr, err := s.Endpoint.GetLocalAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := ConvertAddress(s.family, addr)
+ return a, l, nil
+}
+
+// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr, err := s.Endpoint.GetRemoteAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := ConvertAddress(s.family, addr)
+ return a, l, nil
+}
+
+// coalescingRead is the fast path for non-blocking, non-peek, stream-based
+// case. It coalesces as many packets as possible before returning to the
+// caller.
+//
+// Precondition: s.readMu must be locked.
+func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
+ var err *syserr.Error
+ var copied int
+
+ // Copy as many views as possible into the user-provided buffer.
+ for dst.NumBytes() != 0 {
+ err = s.fetchReadView()
+ if err != nil {
+ break
+ }
+
+ var n int
+ var e error
+ if discard {
+ n = len(s.readView)
+ if int64(n) > dst.NumBytes() {
+ n = int(dst.NumBytes())
+ }
+ } else {
+ n, e = dst.CopyOut(ctx, s.readView)
+ // Set the control message, even if 0 bytes were read.
+ if e == nil {
+ s.updateTimestamp()
+ }
+ }
+ copied += n
+ s.readView.TrimFront(n)
+ dst = dst.DropFirst(n)
+ if e != nil {
+ err = syserr.FromError(e)
+ break
+ }
+ }
+
+ // If we managed to copy something, we must deliver it.
+ if copied > 0 {
+ return copied, nil
+ }
+
+ return 0, err
+}
+
+// nonBlockingRead issues a non-blocking read.
+//
+// TODO(b/78348848): Support timestamps for stream sockets.
+func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+ isPacket := s.isPacketBased()
+
+ // Fast path for regular reads from stream (e.g., TCP) endpoints. Note
+ // that senderRequested is ignored for stream sockets.
+ if !peek && !isPacket {
+ // TCP sockets discard the data if MSG_TRUNC is set.
+ //
+ // This behavior is documented in man 7 tcp:
+ // Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
+ // argument of recv(2) (and recvmsg(2)). This flag causes the received
+ // bytes of data to be discarded, rather than passed back in a
+ // caller-supplied buffer.
+ s.readMu.Lock()
+ n, err := s.coalescingRead(ctx, dst, trunc)
+ s.readMu.Unlock()
+ return n, 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+
+ if err := s.fetchReadView(); err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ if !isPacket && peek && trunc {
+ // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
+ // amount that could be read.
+ var rql tcpip.ReceiveQueueSizeOption
+ if err := s.Endpoint.GetSockOpt(&rql); err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
+ }
+ available := len(s.readView) + int(rql)
+ bufLen := int(dst.NumBytes())
+ if available < bufLen {
+ return available, 0, nil, 0, socket.ControlMessages{}, nil
+ }
+ return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
+ }
+
+ n, err := dst.CopyOut(ctx, s.readView)
+ // Set the control message, even if 0 bytes were read.
+ if err == nil {
+ s.updateTimestamp()
+ }
+ var addr interface{}
+ var addrLen uint32
+ if isPacket && senderRequested {
+ addr, addrLen = ConvertAddress(s.family, s.sender)
+ }
+
+ if peek {
+ if l := len(s.readView); trunc && l > n {
+ // isPacket must be true.
+ return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ }
+
+ if isPacket || err != nil {
+ return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ }
+
+ // We need to peek beyond the first message.
+ dst = dst.DropFirst(n)
+ num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
+ n, _, err := s.Endpoint.Peek(dsts)
+ // TODO(b/78348848): Handle peek timestamp.
+ if err != nil {
+ return int64(n), syserr.TranslateNetstackError(err).ToError()
+ }
+ return int64(n), nil
+ }})
+ n += int(num)
+ if err == syserror.ErrWouldBlock && n > 0 {
+ // We got some data, so no need to return an error.
+ err = nil
+ }
+ return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err)
+ }
+
+ var msgLen int
+ if isPacket {
+ msgLen = len(s.readView)
+ s.readView = nil
+ } else {
+ msgLen = int(n)
+ s.readView.TrimFront(int(n))
+ }
+
+ var flags int
+ if msgLen > int(n) {
+ flags |= linux.MSG_TRUNC
+ }
+
+ if trunc {
+ n = msgLen
+ }
+
+ return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+}
+
+func (s *SocketOperations) controlMessages() socket.ControlMessages {
+ return socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp}}
+}
+
+// updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after
+// successfully writing packet data out to userspace.
+//
+// Precondition: s.readMu must be locked.
+func (s *SocketOperations) updateTimestamp() {
+ // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
+ if !s.sockOptTimestamp {
+ s.timestampValid = true
+ s.timestampNS = s.readCM.Timestamp
+ }
+}
+
+// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ trunc := flags&linux.MSG_TRUNC != 0
+ peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
+ if senderRequested && !s.isPacketBased() {
+ // Stream sockets ignore the sender address.
+ senderRequested = false
+ }
+ n, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+
+ if s.isPacketBased() && err == syserr.ErrClosedForReceive && flags&linux.MSG_DONTWAIT != 0 {
+ // In this situation we should return EAGAIN.
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+
+ if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
+ // Read failed and we should not retry.
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
+ // We got all the data we need.
+ return
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst(n)
+
+ // We'll have to block. Register for notifications and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ var rn int
+ rn, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ n += rn
+ if err != nil && err != syserr.ErrWouldBlock {
+ // Always stop on errors other than would block as we generally
+ // won't be able to get any more data. Eat the error if we got
+ // any data.
+ if n > 0 {
+ err = nil
+ }
+ return
+ }
+ if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) {
+ // We got all the data we need.
+ return
+ }
+ dst = dst.DropFirst(rn)
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if n > 0 {
+ return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil
+ }
+ if err == syserror.ETIMEDOUT {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// SendMsg implements the linux syscall sendmsg(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ // Reject Unix control messages.
+ if !controlMessages.Unix.Empty() {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ var addr *tcpip.FullAddress
+ if len(to) > 0 {
+ addrBuf, err := GetAddress(s.family, to)
+ if err != nil {
+ return 0, err
+ }
+
+ addr = &addrBuf
+ }
+
+ v := buffer.NewView(int(src.NumBytes()))
+
+ // Copy all the data into the buffer.
+ if _, err := src.CopyIn(t, v); err != nil {
+ return 0, syserr.FromError(err)
+ }
+
+ opts := tcpip.WriteOptions{
+ To: addr,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }
+
+ n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ if resCh != nil {
+ if err := t.Block(resCh); err != nil {
+ return 0, syserr.FromError(err)
+ }
+ n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ }
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ if err == nil && (n >= uintptr(len(v)) || dontWait) {
+ // Complete write.
+ return int(n), nil
+ }
+ if err != nil && (err != tcpip.ErrWouldBlock || dontWait) {
+ return int(n), syserr.TranslateNetstackError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ v.TrimFront(int(n))
+ total := n
+ for {
+ n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ v.TrimFront(int(n))
+ total += n
+
+ if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
+ return 0, syserr.TranslateNetstackError(err)
+ }
+
+ if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ return int(total), nil
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(total), syserr.ErrTryAgain
+ }
+ // handleIOError will consume errors from t.Block if needed.
+ return int(total), syserr.FromError(err)
+ }
+ }
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint
+ // sockets.
+ // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
+ if int(args[1].Int()) == syscall.SIOCGSTAMP {
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if !s.timestampValid {
+ return 0, syserror.ENOENT
+ }
+
+ tv := linux.NsecToTimeval(s.timestampNS)
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ }
+
+ return Ioctl(ctx, s.Endpoint, io, args)
+}
+
+// Ioctl performs a socket ioctl.
+func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch arg := int(args[1].Int()); arg {
+ case syscall.SIOCGIFFLAGS,
+ syscall.SIOCGIFADDR,
+ syscall.SIOCGIFBRDADDR,
+ syscall.SIOCGIFDSTADDR,
+ syscall.SIOCGIFHWADDR,
+ syscall.SIOCGIFINDEX,
+ syscall.SIOCGIFMAP,
+ syscall.SIOCGIFMETRIC,
+ syscall.SIOCGIFMTU,
+ syscall.SIOCGIFNAME,
+ syscall.SIOCGIFNETMASK,
+ syscall.SIOCGIFTXQLEN:
+
+ var ifr linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil {
+ return 0, err.ToError()
+ }
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case syscall.SIOCGIFCONF:
+ // Return a list of interface addresses or the buffer size
+ // necessary to hold the list.
+ var ifc linux.IFConf
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ if err := ifconfIoctl(ctx, io, &ifc); err != nil {
+ return 0, err
+ }
+
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+
+ return 0, err
+
+ case linux.TIOCINQ:
+ var v tcpip.ReceiveQueueSizeOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCOUTQ:
+ var v tcpip.SendQueueSizeOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
+ unimpl.EmitUnimplementedEvent(ctx)
+ }
+
+ return 0, syserror.ENOTTY
+}
+
+// interfaceIoctl implements interface requests.
+func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
+ var (
+ iface inet.Interface
+ index int32
+ found bool
+ )
+
+ // Find the relevant device.
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+
+ // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to
+ // identify a device.
+ if arg == syscall.SIOCGIFNAME {
+ // Gets the name of the interface given the interface index
+ // stored in ifr_ifindex.
+ index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4]))
+ if iface, ok := stack.Interfaces()[index]; ok {
+ ifr.SetName(iface.Name)
+ return nil
+ }
+ return syserr.ErrNoDevice
+ }
+
+ // Find the relevant device.
+ for index, iface = range stack.Interfaces() {
+ if iface.Name == ifr.Name() {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return syserr.ErrNoDevice
+ }
+
+ switch arg {
+ case syscall.SIOCGIFINDEX:
+ // Copy out the index to the data.
+ usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
+
+ case syscall.SIOCGIFHWADDR:
+ // Copy the hardware address out.
+ ifr.Data[0] = 6 // IEEE802.2 arp type.
+ ifr.Data[1] = 0
+ n := copy(ifr.Data[2:], iface.Addr)
+ for i := 2 + n; i < len(ifr.Data); i++ {
+ ifr.Data[i] = 0 // Clear padding.
+ }
+ usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n))
+
+ case syscall.SIOCGIFFLAGS:
+ f, err := interfaceStatusFlags(stack, iface.Name)
+ if err != nil {
+ return err
+ }
+ // Drop the flags that don't fit in the size that we need to return. This
+ // matches Linux behavior.
+ usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f))
+
+ case syscall.SIOCGIFADDR:
+ // Copy the IPv4 address out.
+ for _, addr := range stack.InterfaceAddrs()[index] {
+ // This ioctl is only compatible with AF_INET addresses.
+ if addr.Family != linux.AF_INET {
+ continue
+ }
+ copy(ifr.Data[4:8], addr.Addr)
+ break
+ }
+
+ case syscall.SIOCGIFMETRIC:
+ // Gets the metric of the device. As per netdevice(7), this
+ // always just sets ifr_metric to 0.
+ usermem.ByteOrder.PutUint32(ifr.Data[:4], 0)
+
+ case syscall.SIOCGIFMTU:
+ // Gets the MTU of the device.
+ usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU)
+
+ case syscall.SIOCGIFMAP:
+ // Gets the hardware parameters of the device.
+ // TODO(b/71872867): Implement.
+
+ case syscall.SIOCGIFTXQLEN:
+ // Gets the transmit queue length of the device.
+ // TODO(b/71872867): Implement.
+
+ case syscall.SIOCGIFDSTADDR:
+ // Gets the destination address of a point-to-point device.
+ // TODO(b/71872867): Implement.
+
+ case syscall.SIOCGIFBRDADDR:
+ // Gets the broadcast address of a device.
+ // TODO(b/71872867): Implement.
+
+ case syscall.SIOCGIFNETMASK:
+ // Gets the network mask of a device.
+ for _, addr := range stack.InterfaceAddrs()[index] {
+ // This ioctl is only compatible with AF_INET addresses.
+ if addr.Family != linux.AF_INET {
+ continue
+ }
+ // Populate ifr.ifr_netmask (type sockaddr).
+ usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET))
+ usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ var mask uint32 = 0xffffffff << (32 - addr.PrefixLen)
+ // Netmask is expected to be returned as a big endian
+ // value.
+ binary.BigEndian.PutUint32(ifr.Data[4:8], mask)
+ break
+ }
+
+ default:
+ // Not a valid call.
+ return syserr.ErrInvalidArgument
+ }
+
+ return nil
+}
+
+// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
+func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
+ // If Ptr is NULL, return the necessary buffer size via Len.
+ // Otherwise, write up to Len bytes starting at Ptr containing ifreq
+ // structs.
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ return syserr.ErrNoDevice.ToError()
+ }
+
+ if ifc.Ptr == 0 {
+ ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq)
+ return nil
+ }
+
+ max := ifc.Len
+ ifc.Len = 0
+ for key, ifaceAddrs := range stack.InterfaceAddrs() {
+ iface := stack.Interfaces()[key]
+ for _, ifaceAddr := range ifaceAddrs {
+ // Don't write past the end of the buffer.
+ if ifc.Len+int32(linux.SizeOfIFReq) > max {
+ break
+ }
+ if ifaceAddr.Family != linux.AF_INET {
+ continue
+ }
+
+ // Populate ifr.ifr_addr.
+ ifr := linux.IFReq{}
+ ifr.SetName(iface.Name)
+ usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
+ usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ copy(ifr.Data[4:8], ifaceAddr.Addr[:4])
+
+ // Copy the ifr to userspace.
+ dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
+ ifc.Len += int32(linux.SizeOfIFReq)
+ if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// interfaceStatusFlags returns status flags for an interface in the stack.
+// Flag values and meanings are described in greater detail in netdevice(7) in
+// the SIOCGIFFLAGS section.
+func interfaceStatusFlags(stack inet.Stack, name string) (uint32, *syserr.Error) {
+ // epsocket should only ever be passed an epsocket.Stack.
+ epstack, ok := stack.(*Stack)
+ if !ok {
+ return 0, errStackType
+ }
+
+ // Find the NIC corresponding to this interface.
+ for _, info := range epstack.Stack.NICInfo() {
+ if info.Name == name {
+ return nicStateFlagsToLinux(info.Flags), nil
+ }
+ }
+ return 0, syserr.ErrNoDevice
+}
+
+func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
+ var rv uint32
+ if f.Up {
+ rv |= linux.IFF_UP | linux.IFF_LOWER_UP
+ }
+ if f.Running {
+ rv |= linux.IFF_RUNNING
+ }
+ if f.Promiscuous {
+ rv |= linux.IFF_PROMISC
+ }
+ if f.Loopback {
+ rv |= linux.IFF_LOOPBACK
+ }
+ return rv
+}
diff --git a/pkg/sentry/socket/epsocket/epsocket_state_autogen.go b/pkg/sentry/socket/epsocket/epsocket_state_autogen.go
new file mode 100755
index 000000000..4b407b796
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/epsocket_state_autogen.go
@@ -0,0 +1,52 @@
+// automatically generated by stateify.
+
+package epsocket
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *SocketOperations) beforeSave() {}
+func (x *SocketOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SendReceiveTimeout", &x.SendReceiveTimeout)
+ m.Save("Queue", &x.Queue)
+ m.Save("family", &x.family)
+ m.Save("Endpoint", &x.Endpoint)
+ m.Save("skType", &x.skType)
+ m.Save("readView", &x.readView)
+ m.Save("readCM", &x.readCM)
+ m.Save("sender", &x.sender)
+ m.Save("sockOptTimestamp", &x.sockOptTimestamp)
+ m.Save("timestampValid", &x.timestampValid)
+ m.Save("timestampNS", &x.timestampNS)
+}
+
+func (x *SocketOperations) afterLoad() {}
+func (x *SocketOperations) load(m state.Map) {
+ m.Load("SendReceiveTimeout", &x.SendReceiveTimeout)
+ m.Load("Queue", &x.Queue)
+ m.Load("family", &x.family)
+ m.Load("Endpoint", &x.Endpoint)
+ m.Load("skType", &x.skType)
+ m.Load("readView", &x.readView)
+ m.Load("readCM", &x.readCM)
+ m.Load("sender", &x.sender)
+ m.Load("sockOptTimestamp", &x.sockOptTimestamp)
+ m.Load("timestampValid", &x.timestampValid)
+ m.Load("timestampNS", &x.timestampNS)
+}
+
+func (x *Stack) beforeSave() {}
+func (x *Stack) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *Stack) load(m state.Map) {
+ m.AfterLoad(x.afterLoad)
+}
+
+func init() {
+ state.Register("epsocket.SocketOperations", (*SocketOperations)(nil), state.Fns{Save: (*SocketOperations).save, Load: (*SocketOperations).load})
+ state.Register("epsocket.Stack", (*Stack)(nil), state.Fns{Save: (*Stack).save, Load: (*Stack).load})
+}
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
new file mode 100644
index 000000000..ec930d8d5
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -0,0 +1,140 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epsocket
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// provider is an inet socket provider.
+type provider struct {
+ family int
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// getTransportProtocol figures out transport protocol. Currently only TCP,
+// UDP, and ICMP are supported.
+func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
+ switch stype {
+ case linux.SOCK_STREAM:
+ if protocol != 0 && protocol != syscall.IPPROTO_TCP {
+ return 0, syserr.ErrInvalidArgument
+ }
+ return tcp.ProtocolNumber, nil
+
+ case linux.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ return udp.ProtocolNumber, nil
+ case syscall.IPPROTO_ICMP:
+ return header.ICMPv4ProtocolNumber, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, nil
+ }
+
+ case linux.SOCK_RAW:
+ // Raw sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ switch protocol {
+ case syscall.IPPROTO_ICMP:
+ return header.ICMPv4ProtocolNumber, nil
+ case syscall.IPPROTO_UDP:
+ return header.UDPProtocolNumber, nil
+ case syscall.IPPROTO_TCP:
+ return header.TCPProtocolNumber, nil
+ }
+ }
+ return 0, syserr.ErrProtocolNotSupported
+}
+
+// Socket creates a new socket object for the AF_INET or AF_INET6 family.
+func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Fail right away if we don't have a stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ // Don't propagate an error here. Instead, allow the socket
+ // code to continue searching for another provider.
+ return nil, nil
+ }
+ eps, ok := stack.(*Stack)
+ if !ok {
+ return nil, nil
+ }
+
+ // Figure out the transport protocol.
+ transProto, err := getTransportProtocol(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create the endpoint.
+ var ep tcpip.Endpoint
+ var e *tcpip.Error
+ wq := &waiter.Queue{}
+ if stype == linux.SOCK_RAW {
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq)
+ } else {
+ ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+ }
+ if e != nil {
+ return nil, syserr.TranslateNetstackError(e)
+ }
+
+ return New(t, p.family, stype, wq, ep)
+}
+
+// Pair just returns nil sockets (not supported).
+func (*provider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+ return nil, nil, nil
+}
+
+// init registers socket providers for AF_INET and AF_INET6.
+func init() {
+ // Providers backed by netstack.
+ p := []provider{
+ {
+ family: linux.AF_INET,
+ netProto: ipv4.ProtocolNumber,
+ },
+ {
+ family: linux.AF_INET6,
+ netProto: ipv6.ProtocolNumber,
+ },
+ }
+
+ for i := range p {
+ socket.RegisterProvider(p[i].family, &p[i])
+ }
+}
diff --git a/pkg/sentry/socket/epsocket/save_restore.go b/pkg/sentry/socket/epsocket/save_restore.go
new file mode 100644
index 000000000..feaafb7cc
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epsocket
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// afterLoad is invoked by stateify.
+func (s *Stack) afterLoad() {
+ s.Stack = stack.StackFromEnv // FIXME(b/36201077)
+ if s.Stack == nil {
+ panic("can't restore without netstack/tcpip/stack.Stack")
+ }
+}
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go
new file mode 100644
index 000000000..edefa225b
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/stack.go
@@ -0,0 +1,140 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epsocket
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+)
+
+// Stack implements inet.Stack for netstack/tcpip/stack.Stack.
+//
+// +stateify savable
+type Stack struct {
+ Stack *stack.Stack `state:"manual"`
+}
+
+// SupportsIPv6 implements Stack.SupportsIPv6.
+func (s *Stack) SupportsIPv6() bool {
+ return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber)
+}
+
+// Interfaces implements inet.Stack.Interfaces.
+func (s *Stack) Interfaces() map[int32]inet.Interface {
+ is := make(map[int32]inet.Interface)
+ for id, ni := range s.Stack.NICInfo() {
+ var devType uint16
+ if ni.Flags.Loopback {
+ devType = linux.ARPHRD_LOOPBACK
+ }
+ is[int32(id)] = inet.Interface{
+ Name: ni.Name,
+ Addr: []byte(ni.LinkAddress),
+ Flags: uint32(nicStateFlagsToLinux(ni.Flags)),
+ DeviceType: devType,
+ MTU: ni.MTU,
+ }
+ }
+ return is
+}
+
+// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ nicAddrs := make(map[int32][]inet.InterfaceAddr)
+ for id, ni := range s.Stack.NICInfo() {
+ var addrs []inet.InterfaceAddr
+ for _, a := range ni.ProtocolAddresses {
+ var family uint8
+ switch a.Protocol {
+ case ipv4.ProtocolNumber:
+ family = linux.AF_INET
+ case ipv6.ProtocolNumber:
+ family = linux.AF_INET6
+ default:
+ log.Warningf("Unknown network protocol in %+v", a)
+ continue
+ }
+
+ addrs = append(addrs, inet.InterfaceAddr{
+ Family: family,
+ PrefixLen: uint8(len(a.Address) * 8),
+ Addr: []byte(a.Address),
+ // TODO(b/68878065): Other fields.
+ })
+ }
+ nicAddrs[int32(id)] = addrs
+ }
+ return nicAddrs
+}
+
+// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ var rs tcp.ReceiveBufferSizeOption
+ err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs)
+ return inet.TCPBufferSize{
+ Min: rs.Min,
+ Default: rs.Default,
+ Max: rs.Max,
+ }, syserr.TranslateNetstackError(err).ToError()
+}
+
+// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ rs := tcp.ReceiveBufferSizeOption{
+ Min: size.Min,
+ Default: size.Default,
+ Max: size.Max,
+ }
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, rs)).ToError()
+}
+
+// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ var ss tcp.SendBufferSizeOption
+ err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss)
+ return inet.TCPBufferSize{
+ Min: ss.Min,
+ Default: ss.Default,
+ Max: ss.Max,
+ }, syserr.TranslateNetstackError(err).ToError()
+}
+
+// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ ss := tcp.SendBufferSizeOption{
+ Min: size.Min,
+ Default: size.Default,
+ Max: size.Max,
+ }
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, ss)).ToError()
+}
+
+// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+func (s *Stack) TCPSACKEnabled() (bool, error) {
+ var sack tcp.SACKEnabled
+ err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack)
+ return bool(sack), syserr.TranslateNetstackError(err).ToError()
+}
+
+// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError()
+}
diff --git a/pkg/sentry/socket/hostinet/device.go b/pkg/sentry/socket/hostinet/device.go
new file mode 100644
index 000000000..4267e3691
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/device.go
@@ -0,0 +1,19 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+var socketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/hostinet/hostinet.go b/pkg/sentry/socket/hostinet/hostinet.go
new file mode 100644
index 000000000..0d6f51d2b
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/hostinet.go
@@ -0,0 +1,17 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostinet implements AF_INET and AF_INET6 sockets using the host's
+// network stack.
+package hostinet
diff --git a/pkg/sentry/socket/hostinet/hostinet_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go
new file mode 100755
index 000000000..0a5c7cdf3
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package hostinet
+
diff --git a/pkg/sentry/socket/hostinet/save_restore.go b/pkg/sentry/socket/hostinet/save_restore.go
new file mode 100644
index 000000000..1dec33897
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/save_restore.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+// beforeSave is invoked by stateify.
+func (*socketOperations) beforeSave() {
+ panic("host.socketOperations is not savable")
+}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
new file mode 100644
index 000000000..41f9693bb
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -0,0 +1,578 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/fdnotifier"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ sizeofInt32 = 4
+
+ // sizeofSockaddr is the size in bytes of the largest sockaddr type
+ // supported by this package.
+ sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
+)
+
+// socketOperations implements fs.FileOperations and socket.Socket for a socket
+// implemented using a host socket.
+type socketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ socket.SendReceiveTimeout
+
+ family int // Read-only.
+ fd int // must be O_NONBLOCK
+ queue waiter.Queue
+}
+
+var _ = socket.Socket(&socketOperations{})
+
+func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
+ s := &socketOperations{family: family, fd: fd}
+ if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ dirent := socket.NewDirent(ctx, socketDevice)
+ defer dirent.DecRef()
+ return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true}, s), nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOperations) Release() {
+ fdnotifier.RemoveFD(int32(s.fd))
+ syscall.Close(s.fd)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(s.fd))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketOperations) EventUnregister(e *waiter.Entry) {
+ s.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(s.fd))
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ if dsts.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ n, err := syscall.Read(s.fd, dsts.Head().ToSlice())
+ if err != nil {
+ return 0, translateIOSyscallError(err)
+ }
+ return uint64(n), nil
+ }
+ return readv(s.fd, iovecsFromBlockSeq(dsts))
+ }))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ n, err := src.CopyInTo(ctx, safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of src.Addrs was unusable.
+ if uint64(src.NumBytes()) != srcs.NumBytes() {
+ return 0, nil
+ }
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ if srcs.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ n, err := syscall.Write(s.fd, srcs.Head().ToSlice())
+ if err != nil {
+ return 0, translateIOSyscallError(err)
+ }
+ return uint64(n), nil
+ }
+ return writev(s.fd, iovecsFromBlockSeq(srcs))
+ }))
+ return int64(n), err
+}
+
+// Connect implements socket.Socket.Connect.
+func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ if len(sockaddr) > sizeofSockaddr {
+ sockaddr = sockaddr[:sizeofSockaddr]
+ }
+
+ _, _, errno := syscall.Syscall(syscall.SYS_CONNECT, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
+
+ if errno == 0 {
+ return nil
+ }
+ if errno != syscall.EINPROGRESS || !blocking {
+ return syserr.FromError(translateIOSyscallError(errno))
+ }
+
+ // "EINPROGRESS: The socket is nonblocking and the connection cannot be
+ // completed immediately. It is possible to select(2) or poll(2) for
+ // completion by selecting the socket for writing. After select(2)
+ // indicates writability, use getsockopt(2) to read the SO_ERROR option at
+ // level SOL-SOCKET to determine whether connect() completed successfully
+ // (SO_ERROR is zero) or unsuccessfully (SO_ERROR is one of the usual error
+ // codes listed here, explaining the reason for the failure)." - connect(2)
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+ if s.Readiness(waiter.EventOut)&waiter.EventOut == 0 {
+ if err := t.Block(ch); err != nil {
+ return syserr.FromError(err)
+ }
+ }
+ val, err := syscall.GetsockoptInt(s.fd, syscall.SOL_SOCKET, syscall.SO_ERROR)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ if val != 0 {
+ return syserr.FromError(syscall.Errno(uintptr(val)))
+ }
+ return nil
+}
+
+// Accept implements socket.Socket.Accept.
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+ var peerAddr []byte
+ var peerAddrlen uint32
+ var peerAddrPtr *byte
+ var peerAddrlenPtr *uint32
+ if peerRequested {
+ peerAddr = make([]byte, sizeofSockaddr)
+ peerAddrlen = uint32(len(peerAddr))
+ peerAddrPtr = &peerAddr[0]
+ peerAddrlenPtr = &peerAddrlen
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it.
+ fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
+ if blocking {
+ var ch chan struct{}
+ for syscallErr == syserror.ErrWouldBlock {
+ if ch != nil {
+ if syscallErr = t.Block(ch); syscallErr != nil {
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+ }
+ fd, syscallErr = accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
+ }
+ }
+
+ if peerRequested {
+ peerAddr = peerAddr[:peerAddrlen]
+ }
+ if syscallErr != nil {
+ return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
+ }
+
+ f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0)
+ if err != nil {
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ defer f.DecRef()
+
+ fdFlags := kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ }
+ kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits())
+ t.Kernel().RecordSocket(f, s.family)
+ return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
+}
+
+// Bind implements socket.Socket.Bind.
+func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ if len(sockaddr) > sizeofSockaddr {
+ sockaddr = sockaddr[:sizeofSockaddr]
+ }
+
+ _, _, errno := syscall.Syscall(syscall.SYS_BIND, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
+ if errno != 0 {
+ return syserr.FromError(errno)
+ }
+ return nil
+}
+
+// Listen implements socket.Socket.Listen.
+func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return syserr.FromError(syscall.Listen(s.fd, backlog))
+}
+
+// Shutdown implements socket.Socket.Shutdown.
+func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ switch how {
+ case syscall.SHUT_RD, syscall.SHUT_WR, syscall.SHUT_RDWR:
+ return syserr.FromError(syscall.Shutdown(s.fd, how))
+ default:
+ return syserr.ErrInvalidArgument
+ }
+}
+
+// GetSockOpt implements socket.Socket.GetSockOpt.
+func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+ if outLen < 0 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Whitelist options and constrain option length.
+ var optlen int
+ switch level {
+ case syscall.SOL_IPV6:
+ switch name {
+ case syscall.IPV6_V6ONLY:
+ optlen = sizeofInt32
+ }
+ case syscall.SOL_SOCKET:
+ switch name {
+ case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR, syscall.SO_TYPE:
+ optlen = sizeofInt32
+ case syscall.SO_LINGER:
+ optlen = syscall.SizeofLinger
+ }
+ case syscall.SOL_TCP:
+ switch name {
+ case syscall.TCP_NODELAY:
+ optlen = sizeofInt32
+ case syscall.TCP_INFO:
+ optlen = int(linux.SizeOfTCPInfo)
+ }
+ }
+ if optlen == 0 {
+ return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT
+ }
+ if outLen < optlen {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ opt, err := getsockopt(s.fd, level, name, optlen)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return opt, nil
+}
+
+// SetSockOpt implements socket.Socket.SetSockOpt.
+func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ // Whitelist options and constrain option length.
+ var optlen int
+ switch level {
+ case syscall.SOL_IPV6:
+ switch name {
+ case syscall.IPV6_V6ONLY:
+ optlen = sizeofInt32
+ }
+ case syscall.SOL_SOCKET:
+ switch name {
+ case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR:
+ optlen = sizeofInt32
+ }
+ case syscall.SOL_TCP:
+ switch name {
+ case syscall.TCP_NODELAY:
+ optlen = sizeofInt32
+ }
+ }
+ if optlen == 0 {
+ // Pretend to accept socket options we don't understand. This seems
+ // dangerous, but it's what netstack does...
+ return nil
+ }
+ if len(opt) < optlen {
+ return syserr.ErrInvalidArgument
+ }
+ opt = opt[:optlen]
+
+ _, _, errno := syscall.Syscall6(syscall.SYS_SETSOCKOPT, uintptr(s.fd), uintptr(level), uintptr(name), uintptr(firstBytePtr(opt)), uintptr(len(opt)), 0)
+ if errno != 0 {
+ return syserr.FromError(errno)
+ }
+ return nil
+}
+
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+ // Whitelist flags.
+ //
+ // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
+ // messages that netstack/tcpip/transport/unix doesn't understand. Kill the
+ // Socket interface's dependence on netstack.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
+ }
+
+ var senderAddr []byte
+ if senderRequested {
+ senderAddr = make([]byte, sizeofSockaddr)
+ }
+
+ var msgFlags int
+
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ if dsts.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddr)
+ }
+
+ iovs := iovecsFromBlockSeq(dsts)
+ msg := syscall.Msghdr{
+ Iov: &iovs[0],
+ Iovlen: uint64(len(iovs)),
+ }
+ if len(senderAddr) != 0 {
+ msg.Name = &senderAddr[0]
+ msg.Namelen = uint32(len(senderAddr))
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0, err
+ }
+ senderAddr = senderAddr[:msg.Namelen]
+ msgFlags = int(msg.Flags)
+ return n, nil
+ })
+
+ var ch chan struct{}
+ n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
+ if flags&syscall.MSG_DONTWAIT == 0 {
+ for err == syserror.ErrWouldBlock {
+ // We only expect blocking to come from the actual syscall, in which
+ // case it can't have returned any data.
+ if n != 0 {
+ panic(fmt.Sprintf("CopyOutFrom: got (%d, %v), wanted (0, %v)", n, err, err))
+ }
+ if ch != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+ }
+ n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
+ }
+
+ // We don't allow control messages.
+ msgFlags &^= linux.MSG_CTRUNC
+
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err)
+}
+
+// SendMsg implements socket.Socket.SendMsg.
+func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ // Whitelist flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of src.Addrs was unusable.
+ if uint64(src.NumBytes()) != srcs.NumBytes() {
+ return 0, nil
+ }
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+
+ // We always do a non-blocking send*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ if srcs.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ src := srcs.Head()
+ n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+ }
+
+ iovs := iovecsFromBlockSeq(srcs)
+ msg := syscall.Msghdr{
+ Iov: &iovs[0],
+ Iovlen: uint64(len(iovs)),
+ }
+ if len(to) != 0 {
+ msg.Name = &to[0]
+ msg.Namelen = uint32(len(to))
+ }
+ return sendmsg(s.fd, &msg, sysflags)
+ })
+
+ var ch chan struct{}
+ n, err := src.CopyInTo(t, sendmsgFromBlocks)
+ if flags&syscall.MSG_DONTWAIT == 0 {
+ for err == syserror.ErrWouldBlock {
+ // We only expect blocking to come from the actual syscall, in which
+ // case it can't have returned any data.
+ if n != 0 {
+ panic(fmt.Sprintf("CopyInTo: got (%d, %v), wanted (0, %v)", n, err, err))
+ }
+ if ch != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+ }
+ n, err = src.CopyInTo(t, sendmsgFromBlocks)
+ }
+ }
+
+ return int(n), syserr.FromError(err)
+}
+
+func iovecsFromBlockSeq(bs safemem.BlockSeq) []syscall.Iovec {
+ iovs := make([]syscall.Iovec, 0, bs.NumBlocks())
+ for ; !bs.IsEmpty(); bs = bs.Tail() {
+ b := bs.Head()
+ iovs = append(iovs, syscall.Iovec{
+ Base: &b.ToSlice()[0],
+ Len: uint64(b.Len()),
+ })
+ // We don't need to care about b.NeedSafecopy(), because the host
+ // kernel will handle such address ranges just fine (by returning
+ // EFAULT).
+ }
+ return iovs
+}
+
+func translateIOSyscallError(err error) error {
+ if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+}
+
+type socketProvider struct {
+ family int
+}
+
+// Socket implements socket.Provider.Socket.
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check that we are using the host network stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ return nil, nil
+ }
+ if _, ok := stack.(*Stack); !ok {
+ return nil, nil
+ }
+
+ // Only accept TCP and UDP.
+ stype := int(stypeflags) & linux.SOCK_TYPE_MASK
+ switch stype {
+ case syscall.SOCK_STREAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_TCP:
+ // ok
+ default:
+ return nil, nil
+ }
+ case syscall.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ // ok
+ default:
+ return nil, nil
+ }
+ default:
+ return nil, nil
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
+ // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
+ fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
+}
+
+// Pair implements socket.Provider.Pair.
+func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+}
+
+func init() {
+ for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
+ socket.RegisterProvider(family, &socketProvider{family})
+ }
+}
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
new file mode 100644
index 000000000..eed0c7837
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -0,0 +1,138 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+func firstBytePtr(bs []byte) unsafe.Pointer {
+ if bs == nil {
+ return nil
+ }
+ return unsafe.Pointer(&bs[0])
+}
+
+// Preconditions: len(dsts) != 0.
+func readv(fd int, dsts []syscall.Iovec) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&dsts[0])), uintptr(len(dsts)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
+
+// Preconditions: len(srcs) != 0.
+func writev(fd int, srcs []syscall.Iovec) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&srcs[0])), uintptr(len(srcs)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *socketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := uintptr(args[1].Int()); cmd {
+ case syscall.TIOCINQ, syscall.TIOCOUTQ:
+ var val int32
+ if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(s.fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ var buf [4]byte
+ usermem.ByteOrder.PutUint32(buf[:], uint32(val))
+ _, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+func accept4(fd int, addr *byte, addrlen *uint32, flags int) (int, error) {
+ afd, _, errno := syscall.Syscall6(syscall.SYS_ACCEPT4, uintptr(fd), uintptr(unsafe.Pointer(addr)), uintptr(unsafe.Pointer(addrlen)), uintptr(flags), 0, 0)
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return int(afd), nil
+}
+
+func getsockopt(fd int, level, name int, optlen int) ([]byte, error) {
+ opt := make([]byte, optlen)
+ optlen32 := int32(len(opt))
+ _, _, errno := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), uintptr(level), uintptr(name), uintptr(firstBytePtr(opt)), uintptr(unsafe.Pointer(&optlen32)), 0)
+ if errno != 0 {
+ return nil, errno
+ }
+ return opt[:optlen32], nil
+}
+
+// GetSockName implements socket.Socket.GetSockName.
+func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr := make([]byte, sizeofSockaddr)
+ addrlen := uint32(len(addr))
+ _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
+ if errno != 0 {
+ return nil, 0, syserr.FromError(errno)
+ }
+ return addr[:addrlen], addrlen, nil
+}
+
+// GetPeerName implements socket.Socket.GetPeerName.
+func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr := make([]byte, sizeofSockaddr)
+ addrlen := uint32(len(addr))
+ _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
+ if errno != 0 {
+ return nil, 0, syserr.FromError(errno)
+ }
+ return addr[:addrlen], addrlen, nil
+}
+
+func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) {
+ fromLen := uint32(len(*from))
+ n, _, errno := syscall.Syscall6(syscall.SYS_RECVFROM, uintptr(fd), uintptr(firstBytePtr(dst)), uintptr(len(dst)), uintptr(flags), uintptr(firstBytePtr(*from)), uintptr(unsafe.Pointer(&fromLen)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ *from = (*from)[:fromLen]
+ return uint64(n), nil
+}
+
+func recvmsg(fd int, msg *syscall.Msghdr, flags int) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
+
+func sendmsg(fd int, msg *syscall.Msghdr, flags int) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
new file mode 100644
index 000000000..9c45991ba
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -0,0 +1,246 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+var defaultRecvBufSize = inet.TCPBufferSize{
+ Min: 4096,
+ Default: 87380,
+ Max: 6291456,
+}
+
+var defaultSendBufSize = inet.TCPBufferSize{
+ Min: 4096,
+ Default: 16384,
+ Max: 4194304,
+}
+
+// Stack implements inet.Stack for host sockets.
+type Stack struct {
+ // Stack is immutable.
+ interfaces map[int32]inet.Interface
+ interfaceAddrs map[int32][]inet.InterfaceAddr
+ supportsIPv6 bool
+ tcpRecvBufSize inet.TCPBufferSize
+ tcpSendBufSize inet.TCPBufferSize
+ tcpSACKEnabled bool
+}
+
+// NewStack returns an empty Stack containing no configuration.
+func NewStack() *Stack {
+ return &Stack{
+ interfaces: make(map[int32]inet.Interface),
+ interfaceAddrs: make(map[int32][]inet.InterfaceAddr),
+ }
+}
+
+// Configure sets up the stack using the current state of the host network.
+func (s *Stack) Configure() error {
+ if err := addHostInterfaces(s); err != nil {
+ return err
+ }
+
+ if _, err := os.Stat("/proc/net/if_inet6"); err == nil {
+ s.supportsIPv6 = true
+ }
+
+ s.tcpRecvBufSize = defaultRecvBufSize
+ if tcpRMem, err := readTCPBufferSizeFile("/proc/sys/net/ipv4/tcp_rmem"); err == nil {
+ s.tcpRecvBufSize = tcpRMem
+ } else {
+ log.Warningf("Failed to read TCP receive buffer size, using default values")
+ }
+
+ s.tcpSendBufSize = defaultSendBufSize
+ if tcpWMem, err := readTCPBufferSizeFile("/proc/sys/net/ipv4/tcp_wmem"); err == nil {
+ s.tcpSendBufSize = tcpWMem
+ } else {
+ log.Warningf("Failed to read TCP send buffer size, using default values")
+ }
+
+ // SACK is important for performance and even compatibility, assume it's
+ // enabled if we can't find the actual value.
+ s.tcpSACKEnabled = true
+ if sack, err := ioutil.ReadFile("/proc/sys/net/ipv4/tcp_sack"); err == nil {
+ s.tcpSACKEnabled = strings.TrimSpace(string(sack)) != "0"
+ } else {
+ log.Warningf("Failed to read if TCP SACK if enabled, setting to true")
+ }
+
+ return nil
+}
+
+// ExtractHostInterfaces will populate an interface map and
+// interfaceAddrs map with the results of the equivalent
+// netlink messages.
+func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.NetlinkMessage, interfaces map[int32]inet.Interface, interfaceAddrs map[int32][]inet.InterfaceAddr) error {
+ for _, link := range links {
+ if link.Header.Type != syscall.RTM_NEWLINK {
+ continue
+ }
+ if len(link.Data) < syscall.SizeofIfInfomsg {
+ return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), syscall.SizeofIfInfomsg)
+ }
+ var ifinfo syscall.IfInfomsg
+ binary.Unmarshal(link.Data[:syscall.SizeofIfInfomsg], usermem.ByteOrder, &ifinfo)
+ inetIF := inet.Interface{
+ DeviceType: ifinfo.Type,
+ Flags: ifinfo.Flags,
+ }
+ // Not clearly documented: syscall.ParseNetlinkRouteAttr will check the
+ // syscall.NetlinkMessage.Header.Type and skip the struct ifinfomsg
+ // accordingly.
+ attrs, err := syscall.ParseNetlinkRouteAttr(&link)
+ if err != nil {
+ return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid rtattrs: %v", err)
+ }
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.IFLA_ADDRESS:
+ inetIF.Addr = attr.Value
+ case syscall.IFLA_IFNAME:
+ inetIF.Name = string(attr.Value[:len(attr.Value)-1])
+ }
+ }
+ interfaces[ifinfo.Index] = inetIF
+ }
+
+ for _, addr := range addrs {
+ if addr.Header.Type != syscall.RTM_NEWADDR {
+ continue
+ }
+ if len(addr.Data) < syscall.SizeofIfAddrmsg {
+ return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), syscall.SizeofIfAddrmsg)
+ }
+ var ifaddr syscall.IfAddrmsg
+ binary.Unmarshal(addr.Data[:syscall.SizeofIfAddrmsg], usermem.ByteOrder, &ifaddr)
+ inetAddr := inet.InterfaceAddr{
+ Family: ifaddr.Family,
+ PrefixLen: ifaddr.Prefixlen,
+ Flags: ifaddr.Flags,
+ }
+ attrs, err := syscall.ParseNetlinkRouteAttr(&addr)
+ if err != nil {
+ return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid rtattrs: %v", err)
+ }
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.IFA_ADDRESS:
+ inetAddr.Addr = attr.Value
+ }
+ }
+ interfaceAddrs[int32(ifaddr.Index)] = append(interfaceAddrs[int32(ifaddr.Index)], inetAddr)
+ }
+
+ return nil
+}
+
+func addHostInterfaces(s *Stack) error {
+ links, err := doNetlinkRouteRequest(syscall.RTM_GETLINK)
+ if err != nil {
+ return fmt.Errorf("RTM_GETLINK failed: %v", err)
+ }
+
+ addrs, err := doNetlinkRouteRequest(syscall.RTM_GETADDR)
+ if err != nil {
+ return fmt.Errorf("RTM_GETADDR failed: %v", err)
+ }
+
+ return ExtractHostInterfaces(links, addrs, s.interfaces, s.interfaceAddrs)
+}
+
+func doNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
+ data, err := syscall.NetlinkRIB(req, syscall.AF_UNSPEC)
+ if err != nil {
+ return nil, err
+ }
+ return syscall.ParseNetlinkMessage(data)
+}
+
+func readTCPBufferSizeFile(filename string) (inet.TCPBufferSize, error) {
+ contents, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to read %s: %v", filename, err)
+ }
+ ioseq := usermem.BytesIOSequence(contents)
+ fields := make([]int32, 3)
+ if n, err := usermem.CopyInt32StringsInVec(context.Background(), ioseq.IO, ioseq.Addrs, fields, ioseq.Opts); n != ioseq.NumBytes() || err != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to parse %s (%q): got %v after %d/%d bytes", filename, contents, err, n, ioseq.NumBytes())
+ }
+ return inet.TCPBufferSize{
+ Min: int(fields[0]),
+ Default: int(fields[1]),
+ Max: int(fields[2]),
+ }, nil
+}
+
+// Interfaces implements inet.Stack.Interfaces.
+func (s *Stack) Interfaces() map[int32]inet.Interface {
+ return s.interfaces
+}
+
+// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ return s.interfaceAddrs
+}
+
+// SupportsIPv6 implements inet.Stack.SupportsIPv6.
+func (s *Stack) SupportsIPv6() bool {
+ return s.supportsIPv6
+}
+
+// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpRecvBufSize, nil
+}
+
+// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ return syserror.EACCES
+}
+
+// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpSendBufSize, nil
+}
+
+// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ return syserror.EACCES
+}
+
+// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+func (s *Stack) TCPSACKEnabled() (bool, error) {
+ return s.tcpSACKEnabled, nil
+}
+
+// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ return syserror.EACCES
+}
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
new file mode 100644
index 000000000..5bd3b49ce
--- /dev/null
+++ b/pkg/sentry/socket/netlink/message.go
@@ -0,0 +1,159 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// alignUp rounds a length up to an alignment.
+//
+// Preconditions: align is a power of two.
+func alignUp(length int, align uint) int {
+ return (length + int(align) - 1) &^ (int(align) - 1)
+}
+
+// Message contains a complete serialized netlink message.
+type Message struct {
+ buf []byte
+}
+
+// NewMessage creates a new Message containing the passed header.
+//
+// The header length will be updated by Finalize.
+func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
+ return &Message{
+ buf: binary.Marshal(nil, usermem.ByteOrder, hdr),
+ }
+}
+
+// Finalize returns the []byte containing the entire message, with the total
+// length set in the message header. The Message must not be modified after
+// calling Finalize.
+func (m *Message) Finalize() []byte {
+ // Update length, which is the first 4 bytes of the header.
+ usermem.ByteOrder.PutUint32(m.buf, uint32(len(m.buf)))
+
+ // Align the message. Note that the message length in the header (set
+ // above) is the useful length of the message, not the total aligned
+ // length. See net/netlink/af_netlink.c:__nlmsg_put.
+ aligned := alignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ m.putZeros(aligned - len(m.buf))
+ return m.buf
+}
+
+// putZeros adds n zeros to the message.
+func (m *Message) putZeros(n int) {
+ for n > 0 {
+ m.buf = append(m.buf, 0)
+ n--
+ }
+}
+
+// Put serializes v into the message.
+func (m *Message) Put(v interface{}) {
+ m.buf = binary.Marshal(m.buf, usermem.ByteOrder, v)
+}
+
+// PutAttr adds v to the message as a netlink attribute.
+//
+// Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
+// binary.Size(v) fits in math.MaxUint16 bytes.
+func (m *Message) PutAttr(atype uint16, v interface{}) {
+ l := linux.NetlinkAttrHeaderSize + int(binary.Size(v))
+ if l > math.MaxUint16 {
+ panic(fmt.Sprintf("attribute too large: %d", l))
+ }
+
+ m.Put(linux.NetlinkAttrHeader{
+ Type: atype,
+ Length: uint16(l),
+ })
+ m.Put(v)
+
+ // Align the attribute.
+ aligned := alignUp(l, linux.NLA_ALIGNTO)
+ m.putZeros(aligned - l)
+}
+
+// PutAttrString adds s to the message as a netlink attribute.
+func (m *Message) PutAttrString(atype uint16, s string) {
+ l := linux.NetlinkAttrHeaderSize + len(s) + 1
+ m.Put(linux.NetlinkAttrHeader{
+ Type: atype,
+ Length: uint16(l),
+ })
+
+ // String + NUL-termination.
+ m.Put([]byte(s))
+ m.putZeros(1)
+
+ // Align the attribute.
+ aligned := alignUp(l, linux.NLA_ALIGNTO)
+ m.putZeros(aligned - l)
+}
+
+// MessageSet contains a series of netlink messages.
+type MessageSet struct {
+ // Multi indicates that this a multi-part message, to be terminated by
+ // NLMSG_DONE. NLMSG_DONE is sent even if the set contains only one
+ // Message.
+ //
+ // If Multi is set, all added messages will have NLM_F_MULTI set.
+ Multi bool
+
+ // PortID is the destination port for all messages.
+ PortID int32
+
+ // Seq is the sequence counter for all messages in the set.
+ Seq uint32
+
+ // Messages contains the messages in the set.
+ Messages []*Message
+}
+
+// NewMessageSet creates a new MessageSet.
+//
+// portID is the destination port to set as PortID in all messages.
+//
+// seq is the sequence counter to set as seq in all messages in the set.
+func NewMessageSet(portID int32, seq uint32) *MessageSet {
+ return &MessageSet{
+ PortID: portID,
+ Seq: seq,
+ }
+}
+
+// AddMessage adds a new message to the set and returns it for further
+// additions.
+//
+// The passed header will have Seq, PortID and the multi flag set
+// automatically.
+func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message {
+ hdr.Seq = ms.Seq
+ hdr.PortID = uint32(ms.PortID)
+ if ms.Multi {
+ hdr.Flags |= linux.NLM_F_MULTI
+ }
+
+ m := NewMessage(hdr)
+ ms.Messages = append(ms.Messages, m)
+ return m
+}
diff --git a/pkg/sentry/socket/netlink/netlink_state_autogen.go b/pkg/sentry/socket/netlink/netlink_state_autogen.go
new file mode 100755
index 000000000..59d902798
--- /dev/null
+++ b/pkg/sentry/socket/netlink/netlink_state_autogen.go
@@ -0,0 +1,36 @@
+// automatically generated by stateify.
+
+package netlink
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Socket) beforeSave() {}
+func (x *Socket) save(m state.Map) {
+ x.beforeSave()
+ m.Save("SendReceiveTimeout", &x.SendReceiveTimeout)
+ m.Save("ports", &x.ports)
+ m.Save("protocol", &x.protocol)
+ m.Save("ep", &x.ep)
+ m.Save("connection", &x.connection)
+ m.Save("bound", &x.bound)
+ m.Save("portID", &x.portID)
+ m.Save("sendBufferSize", &x.sendBufferSize)
+}
+
+func (x *Socket) afterLoad() {}
+func (x *Socket) load(m state.Map) {
+ m.Load("SendReceiveTimeout", &x.SendReceiveTimeout)
+ m.Load("ports", &x.ports)
+ m.Load("protocol", &x.protocol)
+ m.Load("ep", &x.ep)
+ m.Load("connection", &x.connection)
+ m.Load("bound", &x.bound)
+ m.Load("portID", &x.portID)
+ m.Load("sendBufferSize", &x.sendBufferSize)
+}
+
+func init() {
+ state.Register("netlink.Socket", (*Socket)(nil), state.Fns{Save: (*Socket).save, Load: (*Socket).load})
+}
diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go
new file mode 100644
index 000000000..e9d3275b1
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/port.go
@@ -0,0 +1,116 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package port provides port ID allocation for netlink sockets.
+//
+// A netlink port is any int32 value. Positive ports are typically equivalent
+// to the PID of the binding process. If that port is unavailable, negative
+// ports are searched to find a free port that will not conflict with other
+// PIDS.
+package port
+
+import (
+ "fmt"
+ "math"
+ "math/rand"
+ "sync"
+)
+
+// maxPorts is a sanity limit on the maximum number of ports to allocate per
+// protocol.
+const maxPorts = 10000
+
+// Manager allocates netlink port IDs.
+//
+// +stateify savable
+type Manager struct {
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // ports contains a map of allocated ports for each protocol.
+ ports map[int]map[int32]struct{}
+}
+
+// New creates a new Manager.
+func New() *Manager {
+ return &Manager{
+ ports: make(map[int]map[int32]struct{}),
+ }
+}
+
+// Allocate reserves a new port ID for protocol. hint will be taken if
+// available.
+func (m *Manager) Allocate(protocol int, hint int32) (int32, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ proto, ok := m.ports[protocol]
+ if !ok {
+ proto = make(map[int32]struct{})
+ // Port 0 is reserved for the kernel.
+ proto[0] = struct{}{}
+ m.ports[protocol] = proto
+ }
+
+ if len(proto) >= maxPorts {
+ return 0, false
+ }
+
+ if _, ok := proto[hint]; !ok {
+ // Hint is available, reserve it.
+ proto[hint] = struct{}{}
+ return hint, true
+ }
+
+ // Search for any free port in [math.MinInt32, -4096). The positive
+ // port space is left open for pid-based allocations. This behavior is
+ // consistent with Linux.
+ start := int32(math.MinInt32 + rand.Int63n(math.MaxInt32-4096+1))
+ curr := start
+ for {
+ if _, ok := proto[curr]; !ok {
+ proto[curr] = struct{}{}
+ return curr, true
+ }
+
+ curr--
+ if curr >= -4096 {
+ curr = -4097
+ }
+ if curr == start {
+ // Nothing found. We should always find a free port
+ // because maxPorts < -4096 - MinInt32.
+ panic(fmt.Sprintf("No free port found in %+v", proto))
+ }
+ }
+}
+
+// Release frees the specified port for protocol.
+//
+// Preconditions: port is already allocated.
+func (m *Manager) Release(protocol int, port int32) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ proto, ok := m.ports[protocol]
+ if !ok {
+ panic(fmt.Sprintf("Released port %d for protocol %d which has no allocations", port, protocol))
+ }
+
+ if _, ok := proto[port]; !ok {
+ panic(fmt.Sprintf("Released port %d for protocol %d is not allocated", port, protocol))
+ }
+
+ delete(proto, port)
+}
diff --git a/pkg/sentry/socket/netlink/port/port_state_autogen.go b/pkg/sentry/socket/netlink/port/port_state_autogen.go
new file mode 100755
index 000000000..f01d9704f
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/port_state_autogen.go
@@ -0,0 +1,22 @@
+// automatically generated by stateify.
+
+package port
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Manager) beforeSave() {}
+func (x *Manager) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ports", &x.ports)
+}
+
+func (x *Manager) afterLoad() {}
+func (x *Manager) load(m state.Map) {
+ m.Load("ports", &x.ports)
+}
+
+func init() {
+ state.Register("port.Manager", (*Manager)(nil), state.Fns{Save: (*Manager).save, Load: (*Manager).load})
+}
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
new file mode 100644
index 000000000..76cf12fd4
--- /dev/null
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+)
+
+// Protocol is the implementation of a netlink socket protocol.
+type Protocol interface {
+ // Protocol returns the Linux netlink protocol value.
+ Protocol() int
+
+ // ProcessMessage processes a single message from userspace.
+ //
+ // If err == nil, any messages added to ms will be sent back to the
+ // other end of the socket. Setting ms.Multi will cause an NLMSG_DONE
+ // message to be sent even if ms contains no messages.
+ ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *MessageSet) *syserr.Error
+}
+
+// Provider is a function that creates a new Protocol for a specific netlink
+// protocol.
+//
+// Note that this is distinct from socket.Provider, which is used for all
+// socket families.
+type Provider func(t *kernel.Task) (Protocol, *syserr.Error)
+
+// protocols holds a map of all known address protocols and their provider.
+var protocols = make(map[int]Provider)
+
+// RegisterProvider registers the provider of a given address protocol so that
+// netlink sockets of that type can be created via socket(2).
+//
+// Preconditions: May only be called before any netlink sockets are created.
+func RegisterProvider(protocol int, provider Provider) {
+ if p, ok := protocols[protocol]; ok {
+ panic(fmt.Sprintf("Netlink protocol %d already provided by %+v", protocol, p))
+ }
+
+ protocols[protocol] = provider
+}
+
+// socketProvider implements socket.Provider.
+type socketProvider struct {
+}
+
+// Socket implements socket.Provider.Socket.
+func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Netlink sockets must be specified as datagram or raw, but they
+ // behave the same regardless of type.
+ if stype != transport.SockDgram && stype != transport.SockRaw {
+ return nil, syserr.ErrSocketNotSupported
+ }
+
+ provider, ok := protocols[protocol]
+ if !ok {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ p, err := provider(t)
+ if err != nil {
+ return nil, err
+ }
+
+ s, err := NewSocket(t, p)
+ if err != nil {
+ return nil, err
+ }
+
+ d := socket.NewDirent(t, netlinkSocketDevice)
+ defer d.DecRef()
+ return fs.NewFile(t, d, fs.FileFlags{Read: true, Write: true}, s), nil
+}
+
+// Pair implements socket.Provider.Pair by returning an error.
+func (*socketProvider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+ // Netlink sockets never supports creating socket pairs.
+ return nil, nil, syserr.ErrNotSupported
+}
+
+// init registers the socket provider.
+func init() {
+ socket.RegisterProvider(linux.AF_NETLINK, &socketProvider{})
+}
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
new file mode 100644
index 000000000..9f0a81403
--- /dev/null
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -0,0 +1,197 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package route provides a NETLINK_ROUTE socket protocol.
+package route
+
+import (
+ "bytes"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+)
+
+// commandKind describes the operational class of a message type.
+//
+// The route message types use the lower 2 bits of the type to describe class
+// of command.
+type commandKind int
+
+const (
+ kindNew commandKind = 0x0
+ kindDel = 0x1
+ kindGet = 0x2
+ kindSet = 0x3
+)
+
+func typeKind(typ uint16) commandKind {
+ return commandKind(typ & 0x3)
+}
+
+// Protocol implements netlink.Protocol.
+//
+// +stateify savable
+type Protocol struct{}
+
+var _ netlink.Protocol = (*Protocol)(nil)
+
+// NewProtocol creates a NETLINK_ROUTE netlink.Protocol.
+func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
+ return &Protocol{}, nil
+}
+
+// Protocol implements netlink.Protocol.Protocol.
+func (p *Protocol) Protocol() int {
+ return linux.NETLINK_ROUTE
+}
+
+// dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests.
+func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+ // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an
+ // ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some
+ // userspace applications (including glibc) still include rtgenmsg.
+ // Linux has a workaround based on the total message length.
+ //
+ // We don't bother to check for either, since we don't support any
+ // extra attributes that may be included anyways.
+ //
+ // The message may also contain netlink attribute IFLA_EXT_MASK, which
+ // we don't support.
+
+ // The RTM_GETLINK dump response is a set of messages each containing
+ // an InterfaceInfoMessage followed by a set of netlink attributes.
+
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
+
+ for id, i := range stack.Interfaces() {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWLINK,
+ })
+
+ m.Put(linux.InterfaceInfoMessage{
+ Family: linux.AF_UNSPEC,
+ Type: i.DeviceType,
+ Index: id,
+ Flags: i.Flags,
+ })
+
+ m.PutAttrString(linux.IFLA_IFNAME, i.Name)
+ m.PutAttr(linux.IFLA_MTU, i.MTU)
+
+ mac := make([]byte, 6)
+ brd := mac
+ if len(i.Addr) > 0 {
+ mac = i.Addr
+ brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
+ }
+ m.PutAttr(linux.IFLA_ADDRESS, mac)
+ m.PutAttr(linux.IFLA_BROADCAST, brd)
+
+ // TODO(b/68878065): There are many more attributes.
+ }
+
+ return nil
+}
+
+// dumpAddrs handles RTM_GETADDR + NLM_F_DUMP requests.
+func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+ // RTM_GETADDR dump requests need not contain anything more than the
+ // netlink header and 1 byte protocol family common to all
+ // NETLINK_ROUTE requests.
+ //
+ // TODO(b/68878065): Filter output by passed protocol family.
+
+ // The RTM_GETADDR dump response is a set of RTM_NEWADDR messages each
+ // containing an InterfaceAddrMessage followed by a set of netlink
+ // attributes.
+
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
+
+ for id, as := range stack.InterfaceAddrs() {
+ for _, a := range as {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWADDR,
+ })
+
+ m.Put(linux.InterfaceAddrMessage{
+ Family: a.Family,
+ PrefixLen: a.PrefixLen,
+ Index: uint32(id),
+ })
+
+ m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
+
+ // TODO(b/68878065): There are many more attributes.
+ }
+ }
+
+ return nil
+}
+
+// ProcessMessage implements netlink.Protocol.ProcessMessage.
+func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+ // All messages start with a 1 byte protocol family.
+ if len(data) < 1 {
+ // Linux ignores messages missing the protocol family. See
+ // net/core/rtnetlink.c:rtnetlink_rcv_msg.
+ return nil
+ }
+
+ // Non-GET message types require CAP_NET_ADMIN.
+ if typeKind(hdr.Type) != kindGet {
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.HasCapability(linux.CAP_NET_ADMIN) {
+ return syserr.ErrPermissionDenied
+ }
+ }
+
+ // TODO(b/68878065): Only the dump variant of the types below are
+ // supported.
+ if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP {
+ return syserr.ErrNotSupported
+ }
+
+ switch hdr.Type {
+ case linux.RTM_GETLINK:
+ return p.dumpLinks(ctx, hdr, data, ms)
+ case linux.RTM_GETADDR:
+ return p.dumpAddrs(ctx, hdr, data, ms)
+ default:
+ return syserr.ErrNotSupported
+ }
+}
+
+// init registers the NETLINK_ROUTE provider.
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_ROUTE, NewProtocol)
+}
diff --git a/pkg/sentry/socket/netlink/route/route_state_autogen.go b/pkg/sentry/socket/netlink/route/route_state_autogen.go
new file mode 100755
index 000000000..8431bb3d5
--- /dev/null
+++ b/pkg/sentry/socket/netlink/route/route_state_autogen.go
@@ -0,0 +1,20 @@
+// automatically generated by stateify.
+
+package route
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Protocol) beforeSave() {}
+func (x *Protocol) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *Protocol) afterLoad() {}
+func (x *Protocol) load(m state.Map) {
+}
+
+func init() {
+ state.Register("route.Protocol", (*Protocol)(nil), state.Fns{Save: (*Protocol).save, Load: (*Protocol).load})
+}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
new file mode 100644
index 000000000..afd06ca33
--- /dev/null
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -0,0 +1,618 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netlink provides core functionality for netlink sockets.
+package netlink
+
+import (
+ "math"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink/port"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const sizeOfInt32 int = 4
+
+const (
+ // minBufferSize is the smallest size of a send buffer.
+ minSendBufferSize = 4 << 10 // 4096 bytes.
+
+ // defaultSendBufferSize is the default size for the send buffer.
+ defaultSendBufferSize = 16 * 1024
+
+ // maxBufferSize is the largest size a send buffer can grow to.
+ maxSendBufferSize = 4 << 20 // 4MB
+)
+
+// netlinkSocketDevice is the netlink socket virtual device.
+var netlinkSocketDevice = device.NewAnonDevice()
+
+// Socket is the base socket type for netlink sockets.
+//
+// This implementation only supports userspace sending and receiving messages
+// to/from the kernel.
+//
+// Socket implements socket.Socket.
+//
+// +stateify savable
+type Socket struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ socket.SendReceiveTimeout
+
+ // ports provides netlink port allocation.
+ ports *port.Manager
+
+ // protocol is the netlink protocol implementation.
+ protocol Protocol
+
+ // ep is a datagram unix endpoint used to buffer messages sent from the
+ // kernel to userspace. RecvMsg reads messages from this endpoint.
+ ep transport.Endpoint
+
+ // connection is the kernel's connection to ep, used to write messages
+ // sent to userspace.
+ connection transport.ConnectedEndpoint
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // bound indicates that portid is valid.
+ bound bool
+
+ // portID is the port ID allocated for this socket.
+ portID int32
+
+ // sendBufferSize is the send buffer "size". We don't actually have a
+ // fixed buffer but only consume this many bytes.
+ sendBufferSize uint32
+}
+
+var _ socket.Socket = (*Socket)(nil)
+
+// NewSocket creates a new Socket.
+func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) {
+ // Datagram endpoint used to buffer kernel -> user messages.
+ ep := transport.NewConnectionless()
+
+ // Bind the endpoint for good measure so we can connect to it. The
+ // bound address will never be exposed.
+ if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ // Create a connection from which the kernel can write messages.
+ connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect()
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ return &Socket{
+ ports: t.Kernel().NetlinkPorts(),
+ protocol: protocol,
+ ep: ep,
+ connection: connection,
+ sendBufferSize: defaultSendBufferSize,
+ }, nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *Socket) Release() {
+ s.connection.Release()
+ s.ep.Close()
+
+ if s.bound {
+ s.ports.Release(s.protocol.Protocol(), s.portID)
+ }
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // ep holds messages to be read and thus handles EventIn readiness.
+ ready := s.ep.Readiness(mask)
+
+ if mask&waiter.EventOut == waiter.EventOut {
+ // sendMsg handles messages synchronously and is thus always
+ // ready for writing.
+ ready |= waiter.EventOut
+ }
+
+ return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *Socket) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.ep.EventRegister(e, mask)
+ // Writable readiness never changes, so no registration is needed.
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *Socket) EventUnregister(e *waiter.Entry) {
+ s.ep.EventUnregister(e)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *Socket) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // TODO(b/68878065): no ioctls supported.
+ return 0, syserror.ENOTTY
+}
+
+// ExtractSockAddr extracts the SockAddrNetlink from b.
+func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
+ if len(b) < linux.SockAddrNetlinkSize {
+ return nil, syserr.ErrBadAddress
+ }
+
+ var sa linux.SockAddrNetlink
+ binary.Unmarshal(b[:linux.SockAddrNetlinkSize], usermem.ByteOrder, &sa)
+
+ if sa.Family != linux.AF_NETLINK {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return &sa, nil
+}
+
+// bindPort binds this socket to a port, preferring 'port' if it is available.
+//
+// port of 0 defaults to the ThreadGroup ID.
+//
+// Preconditions: mu is held.
+func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error {
+ if s.bound {
+ // Re-binding is only allowed if the port doesn't change.
+ if port != s.portID {
+ return syserr.ErrInvalidArgument
+ }
+
+ return nil
+ }
+
+ if port == 0 {
+ port = int32(t.ThreadGroup().ID())
+ }
+ port, ok := s.ports.Allocate(s.protocol.Protocol(), port)
+ if !ok {
+ return syserr.ErrBusy
+ }
+
+ s.portID = port
+ s.bound = true
+ return nil
+}
+
+// Bind implements socket.Socket.Bind.
+func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ a, err := ExtractSockAddr(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // No support for multicast groups yet.
+ if a.Groups != 0 {
+ return syserr.ErrPermissionDenied
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ return s.bindPort(t, int32(a.PortID))
+}
+
+// Connect implements socket.Socket.Connect.
+func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ a, err := ExtractSockAddr(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // No support for multicast groups yet.
+ if a.Groups != 0 {
+ return syserr.ErrPermissionDenied
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if a.PortID == 0 {
+ // Netlink sockets default to connected to the kernel, but
+ // connecting anyways automatically binds if not already bound.
+ if !s.bound {
+ // Pass port 0 to get an auto-selected port ID.
+ return s.bindPort(t, 0)
+ }
+ return nil
+ }
+
+ // We don't support non-kernel destination ports. Linux returns EPERM
+ // if applications attempt to do this without NL_CFG_F_NONROOT_SEND, so
+ // we emulate that.
+ return syserr.ErrPermissionDenied
+}
+
+// Accept implements socket.Socket.Accept.
+func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+ // Netlink sockets never support accept.
+ return 0, nil, 0, syserr.ErrNotSupported
+}
+
+// Listen implements socket.Socket.Listen.
+func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ // Netlink sockets never support listen.
+ return syserr.ErrNotSupported
+}
+
+// Shutdown implements socket.Socket.Shutdown.
+func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ // Netlink sockets never support shutdown.
+ return syserr.ErrNotSupported
+}
+
+// GetSockOpt implements socket.Socket.GetSockOpt.
+func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+ switch level {
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_SNDBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return int32(s.sendBufferSize), nil
+
+ case linux.SO_RCVBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // We don't have limit on receiving size.
+ return int32(math.MaxInt32), nil
+
+ default:
+ socket.GetSockOptEmitUnimplementedEvent(t, name)
+ }
+ case linux.SOL_NETLINK:
+ switch name {
+ case linux.NETLINK_BROADCAST_ERROR,
+ linux.NETLINK_CAP_ACK,
+ linux.NETLINK_DUMP_STRICT_CHK,
+ linux.NETLINK_EXT_ACK,
+ linux.NETLINK_LIST_MEMBERSHIPS,
+ linux.NETLINK_NO_ENOBUFS,
+ linux.NETLINK_PKTINFO:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+ }
+ // TODO(b/68878065): other sockopts are not supported.
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// SetSockOpt implements socket.Socket.SetSockOpt.
+func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ switch level {
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_SNDBUF:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ size := usermem.ByteOrder.Uint32(opt)
+ if size < minSendBufferSize {
+ size = minSendBufferSize
+ } else if size > maxSendBufferSize {
+ size = maxSendBufferSize
+ }
+ s.mu.Lock()
+ s.sendBufferSize = size
+ s.mu.Unlock()
+ return nil
+ case linux.SO_RCVBUF:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ // We don't have limit on receiving size. So just accept anything as
+ // valid for compatibility.
+ return nil
+ default:
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ case linux.SOL_NETLINK:
+ switch name {
+ case linux.NETLINK_ADD_MEMBERSHIP,
+ linux.NETLINK_BROADCAST_ERROR,
+ linux.NETLINK_CAP_ACK,
+ linux.NETLINK_DROP_MEMBERSHIP,
+ linux.NETLINK_DUMP_STRICT_CHK,
+ linux.NETLINK_EXT_ACK,
+ linux.NETLINK_LISTEN_ALL_NSID,
+ linux.NETLINK_NO_ENOBUFS,
+ linux.NETLINK_PKTINFO:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ }
+ // TODO(b/68878065): other sockopts are not supported.
+ return syserr.ErrProtocolNotAvailable
+}
+
+// GetSockName implements socket.Socket.GetSockName.
+func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ sa := linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ PortID: uint32(s.portID),
+ }
+ return sa, uint32(binary.Size(sa)), nil
+}
+
+// GetPeerName implements socket.Socket.GetPeerName.
+func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ sa := linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ // TODO(b/68878065): Support non-kernel peers. For now the peer
+ // must be the kernel.
+ PortID: 0,
+ }
+ return sa, uint32(binary.Size(sa)), nil
+}
+
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+ from := linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ PortID: 0,
+ }
+ fromLen := uint32(binary.Size(from))
+
+ trunc := flags&linux.MSG_TRUNC != 0
+
+ r := unix.EndpointReader{
+ Endpoint: s.ep,
+ Peek: flags&linux.MSG_PEEK != 0,
+ }
+
+ if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ var mflags int
+ if n < int64(r.MsgSize) {
+ mflags |= linux.MSG_TRUNC
+ }
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+ return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // receive all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ var mflags int
+ if n < int64(r.MsgSize) {
+ mflags |= linux.MSG_TRUNC
+ }
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+ return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &unix.EndpointReader{
+ Endpoint: s.ep,
+ })
+}
+
+// sendResponse sends the response messages in ms back to userspace.
+func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
+ // Linux combines multiple netlink messages into a single datagram.
+ bufs := make([][]byte, 0, len(ms.Messages))
+ for _, m := range ms.Messages {
+ bufs = append(bufs, m.Finalize())
+ }
+
+ if len(bufs) > 0 {
+ // RecvMsg never receives the address, so we don't need to send
+ // one.
+ _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{})
+ // If the buffer is full, we simply drop messages, just like
+ // Linux.
+ if err != nil && err != syserr.ErrWouldBlock {
+ return err
+ }
+ if notify {
+ s.connection.SendNotify()
+ }
+ }
+
+ // N.B. multi-part messages should still send NLMSG_DONE even if
+ // MessageSet contains no messages.
+ //
+ // N.B. NLMSG_DONE is always sent in a different datagram. See
+ // net/netlink/af_netlink.c:netlink_dump.
+ if ms.Multi {
+ m := NewMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_DONE,
+ Flags: linux.NLM_F_MULTI,
+ Seq: ms.Seq,
+ PortID: uint32(ms.PortID),
+ })
+
+ _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{})
+ if err != nil && err != syserr.ErrWouldBlock {
+ return err
+ }
+ if notify {
+ s.connection.SendNotify()
+ }
+ }
+
+ return nil
+}
+
+// processMessages handles each message in buf, passing it to the protocol
+// handler for final handling.
+func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
+ for len(buf) > 0 {
+ if len(buf) < linux.NetlinkMessageHeaderSize {
+ // Linux ignores messages that are too short. See
+ // net/netlink/af_netlink.c:netlink_rcv_skb.
+ break
+ }
+
+ var hdr linux.NetlinkMessageHeader
+ binary.Unmarshal(buf[:linux.NetlinkMessageHeaderSize], usermem.ByteOrder, &hdr)
+
+ if hdr.Length < linux.NetlinkMessageHeaderSize || uint64(hdr.Length) > uint64(len(buf)) {
+ // Linux ignores malformed messages. See
+ // net/netlink/af_netlink.c:netlink_rcv_skb.
+ break
+ }
+
+ // Data from this message.
+ data := buf[linux.NetlinkMessageHeaderSize:hdr.Length]
+
+ // Advance to the next message.
+ next := alignUp(int(hdr.Length), linux.NLMSG_ALIGNTO)
+ if next >= len(buf)-1 {
+ next = len(buf) - 1
+ }
+ buf = buf[next:]
+
+ // Ignore control messages.
+ if hdr.Type < linux.NLMSG_MIN_TYPE {
+ continue
+ }
+
+ // TODO(b/68877377): ACKs not supported yet.
+ if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
+ return syserr.ErrNotSupported
+ }
+
+ ms := NewMessageSet(s.portID, hdr.Seq)
+ if err := s.protocol.ProcessMessage(ctx, hdr, data, ms); err != nil {
+ return err
+ }
+
+ if err := s.sendResponse(ctx, ms); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// sendMsg is the core of message send, used for SendMsg and Write.
+func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ dstPort := int32(0)
+
+ if len(to) != 0 {
+ a, err := ExtractSockAddr(to)
+ if err != nil {
+ return 0, err
+ }
+
+ // No support for multicast groups yet.
+ if a.Groups != 0 {
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ dstPort = int32(a.PortID)
+ }
+
+ if dstPort != 0 {
+ // Non-kernel destinations not supported yet. Treat as if
+ // NL_CFG_F_NONROOT_SEND is not set.
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // For simplicity, and consistency with Linux, we copy in the entire
+ // message up front.
+ if src.NumBytes() > int64(s.sendBufferSize) {
+ return 0, syserr.ErrMessageTooLong
+ }
+
+ buf := make([]byte, src.NumBytes())
+ n, err := src.CopyIn(ctx, buf)
+ if err != nil {
+ // Don't partially consume messages.
+ return 0, syserr.FromError(err)
+ }
+
+ if err := s.processMessages(ctx, buf); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// SendMsg implements socket.Socket.SendMsg.
+func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ return s.sendMsg(t, src, to, flags, controlMessages)
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
+ return int64(n), err.ToError()
+}
diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go
new file mode 100644
index 000000000..f537c7f63
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/conn/conn.go
@@ -0,0 +1,187 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package conn is an RPC connection to a syscall RPC server.
+package conn
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "github.com/golang/protobuf/proto"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+
+ pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
+)
+
+type request struct {
+ response []byte
+ ready chan struct{}
+ ignoreResult bool
+}
+
+// RPCConnection represents a single RPC connection to a syscall gofer.
+type RPCConnection struct {
+ // reqID is the ID of the last request and must be accessed atomically.
+ reqID uint64
+
+ sendMu sync.Mutex
+ socket *unet.Socket
+
+ reqMu sync.Mutex
+ requests map[uint64]request
+}
+
+// NewRPCConnection initializes a RPC connection to a socket gofer.
+func NewRPCConnection(s *unet.Socket) *RPCConnection {
+ conn := &RPCConnection{socket: s, requests: map[uint64]request{}}
+ go func() { // S/R-FIXME(b/77962828)
+ var nums [16]byte
+ for {
+ for n := 0; n < len(nums); {
+ nn, err := conn.socket.Read(nums[n:])
+ if err != nil {
+ panic(fmt.Sprint("error reading length from socket rpc gofer: ", err))
+ }
+ n += nn
+ }
+
+ b := make([]byte, binary.LittleEndian.Uint64(nums[:8]))
+ id := binary.LittleEndian.Uint64(nums[8:])
+
+ for n := 0; n < len(b); {
+ nn, err := conn.socket.Read(b[n:])
+ if err != nil {
+ panic(fmt.Sprint("error reading request from socket rpc gofer: ", err))
+ }
+ n += nn
+ }
+
+ conn.reqMu.Lock()
+ r := conn.requests[id]
+ if r.ignoreResult {
+ delete(conn.requests, id)
+ } else {
+ r.response = b
+ conn.requests[id] = r
+ }
+ conn.reqMu.Unlock()
+ close(r.ready)
+ }
+ }()
+ return conn
+}
+
+// NewRequest makes a request to the RPC gofer and returns the request ID and a
+// channel which will be closed once the request completes.
+func (c *RPCConnection) NewRequest(req pb.SyscallRequest, ignoreResult bool) (uint64, chan struct{}) {
+ b, err := proto.Marshal(&req)
+ if err != nil {
+ panic(fmt.Sprint("invalid proto: ", err))
+ }
+
+ id := atomic.AddUint64(&c.reqID, 1)
+ ch := make(chan struct{})
+
+ c.reqMu.Lock()
+ c.requests[id] = request{ready: ch, ignoreResult: ignoreResult}
+ c.reqMu.Unlock()
+
+ c.sendMu.Lock()
+ defer c.sendMu.Unlock()
+
+ var nums [16]byte
+ binary.LittleEndian.PutUint64(nums[:8], uint64(len(b)))
+ binary.LittleEndian.PutUint64(nums[8:], id)
+ for n := 0; n < len(nums); {
+ nn, err := c.socket.Write(nums[n:])
+ if err != nil {
+ panic(fmt.Sprint("error writing length and ID to socket gofer: ", err))
+ }
+ n += nn
+ }
+
+ for n := 0; n < len(b); {
+ nn, err := c.socket.Write(b[n:])
+ if err != nil {
+ panic(fmt.Sprint("error writing request to socket gofer: ", err))
+ }
+ n += nn
+ }
+
+ return id, ch
+}
+
+// RPCReadFile will execute the ReadFile helper RPC method which avoids the
+// common pattern of open(2), read(2), close(2) by doing all three operations
+// as a single RPC. It will read the entire file or return EFBIG if the file
+// was too large.
+func (c *RPCConnection) RPCReadFile(path string) ([]byte, *syserr.Error) {
+ req := &pb.SyscallRequest_ReadFile{&pb.ReadFileRequest{
+ Path: path,
+ }}
+
+ id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+ <-ch
+
+ res := c.Request(id).Result.(*pb.SyscallResponse_ReadFile).ReadFile.Result
+ if e, ok := res.(*pb.ReadFileResponse_ErrorNumber); ok {
+ return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ return res.(*pb.ReadFileResponse_Data).Data, nil
+}
+
+// RPCWriteFile will execute the WriteFile helper RPC method which avoids the
+// common pattern of open(2), write(2), write(2), close(2) by doing all
+// operations as a single RPC.
+func (c *RPCConnection) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
+ req := &pb.SyscallRequest_WriteFile{&pb.WriteFileRequest{
+ Path: path,
+ Content: data,
+ }}
+
+ id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+ <-ch
+
+ res := c.Request(id).Result.(*pb.SyscallResponse_WriteFile).WriteFile
+ if e := res.ErrorNumber; e != 0 {
+ return int64(res.Written), syserr.FromHost(syscall.Errno(e))
+ }
+
+ return int64(res.Written), nil
+}
+
+// Request retrieves the request corresponding to the given request ID.
+//
+// The channel returned by NewRequest must have been closed before Request can
+// be called. This will happen automatically, do not manually close the
+// channel.
+func (c *RPCConnection) Request(id uint64) pb.SyscallResponse {
+ c.reqMu.Lock()
+ r := c.requests[id]
+ delete(c.requests, id)
+ c.reqMu.Unlock()
+
+ var resp pb.SyscallResponse
+ if err := proto.Unmarshal(r.response, &resp); err != nil {
+ panic(fmt.Sprint("invalid proto: ", err))
+ }
+
+ return resp
+}
diff --git a/pkg/sentry/socket/rpcinet/conn/conn_state_autogen.go b/pkg/sentry/socket/rpcinet/conn/conn_state_autogen.go
new file mode 100755
index 000000000..f6c927a60
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/conn/conn_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package conn
+
diff --git a/pkg/sentry/socket/rpcinet/device.go b/pkg/sentry/socket/rpcinet/device.go
new file mode 100644
index 000000000..44c0a39b7
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/device.go
@@ -0,0 +1,19 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpcinet
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+var socketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go
new file mode 100644
index 000000000..601e05994
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go
@@ -0,0 +1,230 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package notifier implements an FD notifier implementation over RPC.
+package notifier
+
+import (
+ "fmt"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn"
+ pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+type fdInfo struct {
+ queue *waiter.Queue
+ waiting bool
+}
+
+// Notifier holds all the state necessary to issue notifications when IO events
+// occur in the observed FDs.
+type Notifier struct {
+ // rpcConn is the connection that is used for sending RPCs.
+ rpcConn *conn.RPCConnection
+
+ // epFD is the epoll file descriptor used to register for io
+ // notifications.
+ epFD uint32
+
+ // mu protects fdMap.
+ mu sync.Mutex
+
+ // fdMap maps file descriptors to their notification queues and waiting
+ // status.
+ fdMap map[uint32]*fdInfo
+}
+
+// NewRPCNotifier creates a new notifier object.
+func NewRPCNotifier(cn *conn.RPCConnection) (*Notifier, error) {
+ id, c := cn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCreate1{&pb.EpollCreate1Request{}}}, false /* ignoreResult */)
+ <-c
+
+ res := cn.Request(id).Result.(*pb.SyscallResponse_EpollCreate1).EpollCreate1.Result
+ if e, ok := res.(*pb.EpollCreate1Response_ErrorNumber); ok {
+ return nil, syscall.Errno(e.ErrorNumber)
+ }
+
+ w := &Notifier{
+ rpcConn: cn,
+ epFD: res.(*pb.EpollCreate1Response_Fd).Fd,
+ fdMap: make(map[uint32]*fdInfo),
+ }
+
+ go w.waitAndNotify() // S/R-FIXME(b/77962828)
+
+ return w, nil
+}
+
+// waitFD waits on mask for fd. The fdMap mutex must be hold.
+func (n *Notifier) waitFD(fd uint32, fi *fdInfo, mask waiter.EventMask) error {
+ if !fi.waiting && mask == 0 {
+ return nil
+ }
+
+ e := pb.EpollEvent{
+ Events: mask.ToLinux() | -syscall.EPOLLET,
+ Fd: fd,
+ }
+
+ switch {
+ case !fi.waiting && mask != 0:
+ id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_ADD, Fd: fd, Event: &e}}}, false /* ignoreResult */)
+ <-c
+
+ e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber
+ if e != 0 {
+ return syscall.Errno(e)
+ }
+
+ fi.waiting = true
+ case fi.waiting && mask == 0:
+ id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_DEL, Fd: fd}}}, false /* ignoreResult */)
+ <-c
+ n.rpcConn.Request(id)
+
+ fi.waiting = false
+ case fi.waiting && mask != 0:
+ id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_MOD, Fd: fd, Event: &e}}}, false /* ignoreResult */)
+ <-c
+
+ e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber
+ if e != 0 {
+ return syscall.Errno(e)
+ }
+ }
+
+ return nil
+}
+
+// addFD adds an FD to the list of FDs observed by n.
+func (n *Notifier) addFD(fd uint32, queue *waiter.Queue) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Panic if we're already notifying on this FD.
+ if _, ok := n.fdMap[fd]; ok {
+ panic(fmt.Sprintf("File descriptor %d added twice", fd))
+ }
+
+ // We have nothing to wait for at the moment. Just add it to the map.
+ n.fdMap[fd] = &fdInfo{queue: queue}
+}
+
+// updateFD updates the set of events the FD needs to be notified on.
+func (n *Notifier) updateFD(fd uint32) error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if fi, ok := n.fdMap[fd]; ok {
+ return n.waitFD(fd, fi, fi.queue.Events())
+ }
+
+ return nil
+}
+
+// RemoveFD removes an FD from the list of FDs observed by n.
+func (n *Notifier) removeFD(fd uint32) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Remove from map, then from epoll object.
+ n.waitFD(fd, n.fdMap[fd], 0)
+ delete(n.fdMap, fd)
+}
+
+// hasFD returns true if the FD is in the list of observed FDs.
+func (n *Notifier) hasFD(fd uint32) bool {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ _, ok := n.fdMap[fd]
+ return ok
+}
+
+// waitAndNotify loops waiting for io event notifications from the epoll
+// object. Once notifications arrive, they are dispatched to the
+// registered queue.
+func (n *Notifier) waitAndNotify() error {
+ for {
+ id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollWait{&pb.EpollWaitRequest{Fd: n.epFD, NumEvents: 100, Msec: -1}}}, false /* ignoreResult */)
+ <-c
+
+ res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollWait).EpollWait.Result
+ if e, ok := res.(*pb.EpollWaitResponse_ErrorNumber); ok {
+ err := syscall.Errno(e.ErrorNumber)
+ // NOTE(magi): I don't think epoll_wait can return EAGAIN but I'm being
+ // conseratively careful here since exiting the notification thread
+ // would be really bad.
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ continue
+ }
+ return err
+ }
+
+ n.mu.Lock()
+ for _, e := range res.(*pb.EpollWaitResponse_Events).Events.Events {
+ if fi, ok := n.fdMap[e.Fd]; ok {
+ fi.queue.Notify(waiter.EventMaskFromLinux(e.Events))
+ }
+ }
+ n.mu.Unlock()
+ }
+}
+
+// AddFD adds an FD to the list of observed FDs.
+func (n *Notifier) AddFD(fd uint32, queue *waiter.Queue) error {
+ n.addFD(fd, queue)
+ return nil
+}
+
+// UpdateFD updates the set of events the FD needs to be notified on.
+func (n *Notifier) UpdateFD(fd uint32) error {
+ return n.updateFD(fd)
+}
+
+// RemoveFD removes an FD from the list of observed FDs.
+func (n *Notifier) RemoveFD(fd uint32) {
+ n.removeFD(fd)
+}
+
+// HasFD returns true if the FD is in the list of observed FDs.
+//
+// This should only be used by tests to assert that FDs are correctly
+// registered.
+func (n *Notifier) HasFD(fd uint32) bool {
+ return n.hasFD(fd)
+}
+
+// NonBlockingPoll polls the given fd in non-blocking fashion. It is used just
+// to query the FD's current state; this method will block on the RPC response
+// although the syscall is non-blocking.
+func (n *Notifier) NonBlockingPoll(fd uint32, mask waiter.EventMask) waiter.EventMask {
+ for {
+ id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Poll{&pb.PollRequest{Fd: fd, Events: mask.ToLinux()}}}, false /* ignoreResult */)
+ <-c
+
+ res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_Poll).Poll.Result
+ if e, ok := res.(*pb.PollResponse_ErrorNumber); ok {
+ if syscall.Errno(e.ErrorNumber) == syscall.EINTR {
+ continue
+ }
+ return mask
+ }
+
+ return waiter.EventMaskFromLinux(res.(*pb.PollResponse_Events).Events)
+ }
+}
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier_state_autogen.go b/pkg/sentry/socket/rpcinet/notifier/notifier_state_autogen.go
new file mode 100755
index 000000000..f108d91c1
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/notifier/notifier_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package notifier
+
diff --git a/pkg/sentry/socket/rpcinet/rpcinet.go b/pkg/sentry/socket/rpcinet/rpcinet.go
new file mode 100644
index 000000000..5d4fd4dac
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/rpcinet.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package rpcinet implements sockets using an RPC for each syscall.
+package rpcinet
diff --git a/pkg/sentry/socket/rpcinet/rpcinet_state_autogen.go b/pkg/sentry/socket/rpcinet/rpcinet_state_autogen.go
new file mode 100755
index 000000000..d3076c7e3
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/rpcinet_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package rpcinet
+
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
new file mode 100644
index 000000000..55e0b6665
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -0,0 +1,887 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpcinet
+
+import (
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier"
+ pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// socketOperations implements fs.FileOperations and socket.Socket for a socket
+// implemented using a host socket.
+type socketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ socket.SendReceiveTimeout
+
+ family int // Read-only.
+ fd uint32 // must be O_NONBLOCK
+ wq *waiter.Queue
+ rpcConn *conn.RPCConnection
+ notifier *notifier.Notifier
+
+ // shState is the state of the connection with respect to shutdown. Because
+ // we're mixing non-blocking semantics on the other side we have to adapt for
+ // some strange differences between blocking and non-blocking sockets.
+ shState int32
+}
+
+// Verify that we actually implement socket.Socket.
+var _ = socket.Socket(&socketOperations{})
+
+// New creates a new RPC socket.
+func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, protocol int) (*fs.File, *syserr.Error) {
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */)
+ <-c
+
+ res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
+ if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok {
+ return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+ fd := res.(*pb.SocketResponse_Fd).Fd
+
+ var wq waiter.Queue
+ stack.notifier.AddFD(fd, &wq)
+
+ dirent := socket.NewDirent(ctx, socketDevice)
+ defer dirent.DecRef()
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{
+ family: family,
+ wq: &wq,
+ fd: fd,
+ rpcConn: stack.rpcConn,
+ notifier: stack.notifier,
+ }), nil
+}
+
+func isBlockingErrno(err error) bool {
+ return err == syscall.EAGAIN || err == syscall.EWOULDBLOCK
+}
+
+func translateIOSyscallError(err error) error {
+ if isBlockingErrno(err) {
+ return syserror.ErrWouldBlock
+ }
+ return err
+}
+
+// setShutdownFlags will set the shutdown flag so we can handle blocking reads
+// after a read shutdown.
+func (s *socketOperations) setShutdownFlags(how int) {
+ var f tcpip.ShutdownFlags
+ switch how {
+ case linux.SHUT_RD:
+ f = tcpip.ShutdownRead
+ case linux.SHUT_WR:
+ f = tcpip.ShutdownWrite
+ case linux.SHUT_RDWR:
+ f = tcpip.ShutdownWrite | tcpip.ShutdownRead
+ }
+
+ // Atomically update the flags.
+ for {
+ old := atomic.LoadInt32(&s.shState)
+ if atomic.CompareAndSwapInt32(&s.shState, old, old|int32(f)) {
+ break
+ }
+ }
+}
+
+func (s *socketOperations) resetShutdownFlags() {
+ atomic.StoreInt32(&s.shState, 0)
+}
+
+func (s *socketOperations) isShutRdSet() bool {
+ return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownRead) != 0
+}
+
+func (s *socketOperations) isShutWrSet() bool {
+ return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownWrite) != 0
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOperations) Release() {
+ s.notifier.RemoveFD(s.fd)
+
+ // We always need to close the FD.
+ _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: s.fd}}}, true /* ignoreResult */)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.notifier.NonBlockingPoll(s.fd, mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.wq.EventRegister(e, mask)
+ s.notifier.UpdateFD(s.fd)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketOperations) EventUnregister(e *waiter.Entry) {
+ s.wq.EventUnregister(e)
+ s.notifier.UpdateFD(s.fd)
+}
+
+func rpcRead(t *kernel.Task, req *pb.SyscallRequest_Read) (*pb.ReadResponse_Data, *syserr.Error) {
+ s := t.NetworkContext().(*Stack)
+ id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+ <-c
+
+ res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Read).Read.Result
+ if e, ok := res.(*pb.ReadResponse_ErrorNumber); ok {
+ return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ return res.(*pb.ReadResponse_Data), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ req := &pb.SyscallRequest_Read{&pb.ReadRequest{
+ Fd: s.fd,
+ Length: uint32(dst.NumBytes()),
+ }}
+
+ res, se := rpcRead(ctx.(*kernel.Task), req)
+ if se == nil {
+ n, e := dst.CopyOut(ctx, res.Data)
+ return int64(n), e
+ }
+
+ return 0, se.ToError()
+}
+
+func rpcWrite(t *kernel.Task, req *pb.SyscallRequest_Write) (uint32, *syserr.Error) {
+ s := t.NetworkContext().(*Stack)
+ id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+ <-c
+
+ res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Write).Write.Result
+ if e, ok := res.(*pb.WriteResponse_ErrorNumber); ok {
+ return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ return res.(*pb.WriteResponse_Length).Length, nil
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ t := ctx.(*kernel.Task)
+ v := buffer.NewView(int(src.NumBytes()))
+
+ // Copy all the data into the buffer.
+ if _, err := src.CopyIn(t, v); err != nil {
+ return 0, err
+ }
+
+ n, err := rpcWrite(t, &pb.SyscallRequest_Write{&pb.WriteRequest{Fd: s.fd, Data: v}})
+ if n > 0 && n < uint32(src.NumBytes()) {
+ // The FileOperations.Write interface expects us to return ErrWouldBlock in
+ // the event of a partial write.
+ return int64(n), syserror.ErrWouldBlock
+ }
+ return int64(n), err.ToError()
+}
+
+func rpcConnect(t *kernel.Task, fd uint32, sockaddr []byte) *syserr.Error {
+ s := t.NetworkContext().(*Stack)
+ id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Connect{&pb.ConnectRequest{Fd: uint32(fd), Address: sockaddr}}}, false /* ignoreResult */)
+ <-c
+
+ if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Connect).Connect.ErrorNumber; e != 0 {
+ return syserr.FromHost(syscall.Errno(e))
+ }
+ return nil
+}
+
+// Connect implements socket.Socket.Connect.
+func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ if !blocking {
+ e := rpcConnect(t, s.fd, sockaddr)
+ if e == nil {
+ // Reset the shutdown state on new connects.
+ s.resetShutdownFlags()
+ }
+ return e
+ }
+
+ // Register for notification when the endpoint becomes writable, then
+ // initiate the connection.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut|waiter.EventIn|waiter.EventHUp)
+ defer s.EventUnregister(&e)
+ for {
+ if err := rpcConnect(t, s.fd, sockaddr); err == nil || err != syserr.ErrInProgress && err != syserr.ErrAlreadyInProgress {
+ if err == nil {
+ // Reset the shutdown state on new connects.
+ s.resetShutdownFlags()
+ }
+ return err
+ }
+
+ // It's pending, so we have to wait for a notification, and fetch the
+ // result once the wait completes.
+ if err := t.Block(ch); err != nil {
+ return syserr.FromError(err)
+ }
+ }
+}
+
+func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultPayload, *syserr.Error) {
+ stack := t.NetworkContext().(*Stack)
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Accept{&pb.AcceptRequest{Fd: fd, Peer: peer, Flags: syscall.SOCK_NONBLOCK}}}, false /* ignoreResult */)
+ <-c
+
+ res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Accept).Accept.Result
+ if e, ok := res.(*pb.AcceptResponse_ErrorNumber); ok {
+ return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+ return res.(*pb.AcceptResponse_Payload).Payload, nil
+}
+
+// Accept implements socket.Socket.Accept.
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+ payload, se := rpcAccept(t, s.fd, peerRequested)
+
+ // Check if we need to block.
+ if blocking && se == syserr.ErrTryAgain {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ // FIXME(b/119878986): This waiter.EventHUp is a partial
+ // measure, need to figure out how to translate linux events to
+ // internal events.
+ s.EventRegister(&e, waiter.EventIn|waiter.EventHUp)
+ defer s.EventUnregister(&e)
+
+ // Try to accept the connection again; if it fails, then wait until we
+ // get a notification.
+ for {
+ if payload, se = rpcAccept(t, s.fd, peerRequested); se != syserr.ErrTryAgain {
+ break
+ }
+
+ if err := t.Block(ch); err != nil {
+ return 0, nil, 0, syserr.FromError(err)
+ }
+ }
+ }
+
+ // Handle any error from accept.
+ if se != nil {
+ return 0, nil, 0, se
+ }
+
+ var wq waiter.Queue
+ s.notifier.AddFD(payload.Fd, &wq)
+
+ dirent := socket.NewDirent(t, socketDevice)
+ defer dirent.DecRef()
+ file := fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonBlocking: flags&linux.SOCK_NONBLOCK != 0}, &socketOperations{
+ wq: &wq,
+ fd: payload.Fd,
+ rpcConn: s.rpcConn,
+ notifier: s.notifier,
+ })
+ defer file.DecRef()
+
+ fdFlags := kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ }
+ fd, err := t.FDMap().NewFDFrom(0, file, fdFlags, t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, nil, 0, syserr.FromError(err)
+ }
+ t.Kernel().RecordSocket(file, s.family)
+
+ if peerRequested {
+ return fd, payload.Address.Address, payload.Address.Length, nil
+ }
+
+ return fd, nil, 0, nil
+}
+
+// Bind implements socket.Socket.Bind.
+func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ stack := t.NetworkContext().(*Stack)
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: s.fd, Address: sockaddr}}}, false /* ignoreResult */)
+ <-c
+
+ if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 {
+ return syserr.FromHost(syscall.Errno(e))
+ }
+ return nil
+}
+
+// Listen implements socket.Socket.Listen.
+func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ stack := t.NetworkContext().(*Stack)
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Listen{&pb.ListenRequest{Fd: s.fd, Backlog: int64(backlog)}}}, false /* ignoreResult */)
+ <-c
+
+ if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Listen).Listen.ErrorNumber; e != 0 {
+ return syserr.FromHost(syscall.Errno(e))
+ }
+ return nil
+}
+
+// Shutdown implements socket.Socket.Shutdown.
+func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ // We save the shutdown state because of strange differences on linux
+ // related to recvs on blocking vs. non-blocking sockets after a SHUT_RD.
+ // We need to emulate that behavior on the blocking side.
+ // TODO(b/120096741): There is a possible race that can exist with loopback,
+ // where data could possibly be lost.
+ s.setShutdownFlags(how)
+
+ stack := t.NetworkContext().(*Stack)
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Shutdown{&pb.ShutdownRequest{Fd: s.fd, How: int64(how)}}}, false /* ignoreResult */)
+ <-c
+
+ if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Shutdown).Shutdown.ErrorNumber; e != 0 {
+ return syserr.FromHost(syscall.Errno(e))
+ }
+
+ return nil
+}
+
+// GetSockOpt implements socket.Socket.GetSockOpt.
+func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+ // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed
+ // within the sentry.
+ if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.RecvTimeout()), nil
+ }
+ if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.SendTimeout()), nil
+ }
+
+ stack := t.NetworkContext().(*Stack)
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */)
+ <-c
+
+ res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockOpt).GetSockOpt.Result
+ if e, ok := res.(*pb.GetSockOptResponse_ErrorNumber); ok {
+ return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ return res.(*pb.GetSockOptResponse_Opt).Opt, nil
+}
+
+// SetSockOpt implements socket.Socket.SetSockOpt.
+func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ // Because blocking actually happens within the sentry we need to inspect
+ // this socket option to determine if it's a SO_RCVTIMEO or SO_SNDTIMEO,
+ // and if so, we will save it and use it as the deadline for recv(2)
+ // or send(2) related syscalls.
+ if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
+ if len(opt) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetRecvTimeout(v.ToNsecCapped())
+ return nil
+ }
+ if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
+ if len(opt) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetSendTimeout(v.ToNsecCapped())
+ return nil
+ }
+
+ stack := t.NetworkContext().(*Stack)
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */)
+ <-c
+
+ if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_SetSockOpt).SetSockOpt.ErrorNumber; e != 0 {
+ return syserr.FromHost(syscall.Errno(e))
+ }
+ return nil
+}
+
+// GetPeerName implements socket.Socket.GetPeerName.
+func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ stack := t.NetworkContext().(*Stack)
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
+ <-c
+
+ res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetPeerName).GetPeerName.Result
+ if e, ok := res.(*pb.GetPeerNameResponse_ErrorNumber); ok {
+ return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ addr := res.(*pb.GetPeerNameResponse_Address).Address
+ return addr.Address, addr.Length, nil
+}
+
+// GetSockName implements socket.Socket.GetSockName.
+func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ stack := t.NetworkContext().(*Stack)
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
+ <-c
+
+ res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockName).GetSockName.Result
+ if e, ok := res.(*pb.GetSockNameResponse_ErrorNumber); ok {
+ return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ addr := res.(*pb.GetSockNameResponse_Address).Address
+ return addr.Address, addr.Length, nil
+}
+
+func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) {
+ stack := t.NetworkContext().(*Stack)
+
+ id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Ioctl{&pb.IOCtlRequest{Fd: fd, Cmd: cmd, Arg: arg}}}, false /* ignoreResult */)
+ <-c
+
+ res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Ioctl).Ioctl.Result
+ if e, ok := res.(*pb.IOCtlResponse_ErrorNumber); ok {
+ return nil, syscall.Errno(e.ErrorNumber)
+ }
+
+ return res.(*pb.IOCtlResponse_Value).Value, nil
+}
+
+// ifconfIoctlFromStack populates a struct ifconf for the SIOCGIFCONF ioctl.
+func ifconfIoctlFromStack(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
+ // If Ptr is NULL, return the necessary buffer size via Len.
+ // Otherwise, write up to Len bytes starting at Ptr containing ifreq
+ // structs.
+ t := ctx.(*kernel.Task)
+ s := t.NetworkContext().(*Stack)
+ if s == nil {
+ return syserr.ErrNoDevice.ToError()
+ }
+
+ if ifc.Ptr == 0 {
+ ifc.Len = int32(len(s.Interfaces())) * int32(linux.SizeOfIFReq)
+ return nil
+ }
+
+ max := ifc.Len
+ ifc.Len = 0
+ for key, ifaceAddrs := range s.InterfaceAddrs() {
+ iface := s.Interfaces()[key]
+ for _, ifaceAddr := range ifaceAddrs {
+ // Don't write past the end of the buffer.
+ if ifc.Len+int32(linux.SizeOfIFReq) > max {
+ break
+ }
+ if ifaceAddr.Family != linux.AF_INET {
+ continue
+ }
+
+ // Populate ifr.ifr_addr.
+ ifr := linux.IFReq{}
+ ifr.SetName(iface.Name)
+ usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
+ usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ copy(ifr.Data[4:8], ifaceAddr.Addr[:4])
+
+ // Copy the ifr to userspace.
+ dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
+ ifc.Len += int32(linux.SizeOfIFReq)
+ if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *socketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := ctx.(*kernel.Task)
+
+ cmd := uint32(args[1].Int())
+ arg := args[2].Pointer()
+
+ var buf []byte
+ switch cmd {
+ // The following ioctls take 4 byte argument parameters.
+ case syscall.TIOCINQ,
+ syscall.TIOCOUTQ:
+ buf = make([]byte, 4)
+ // The following ioctls have args which are sizeof(struct ifreq).
+ case syscall.SIOCGIFADDR,
+ syscall.SIOCGIFBRDADDR,
+ syscall.SIOCGIFDSTADDR,
+ syscall.SIOCGIFFLAGS,
+ syscall.SIOCGIFHWADDR,
+ syscall.SIOCGIFINDEX,
+ syscall.SIOCGIFMAP,
+ syscall.SIOCGIFMETRIC,
+ syscall.SIOCGIFMTU,
+ syscall.SIOCGIFNAME,
+ syscall.SIOCGIFNETMASK,
+ syscall.SIOCGIFTXQLEN:
+ buf = make([]byte, linux.SizeOfIFReq)
+ case syscall.SIOCGIFCONF:
+ // SIOCGIFCONF has slightly different behavior than the others, in that it
+ // will need to populate the array of ifreqs.
+ var ifc linux.IFConf
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ if err := ifconfIoctlFromStack(ctx, io, &ifc); err != nil {
+ return 0, err
+ }
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+
+ return 0, err
+
+ case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
+ unimpl.EmitUnimplementedEvent(ctx)
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+
+ _, err := io.CopyIn(ctx, arg, buf, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+
+ if err != nil {
+ return 0, err
+ }
+
+ v, err := rpcIoctl(t, s.fd, cmd, buf)
+ if err != nil {
+ return 0, err
+ }
+
+ if len(v) != len(buf) {
+ return 0, syserror.EINVAL
+ }
+
+ _, err = io.CopyOut(ctx, arg, v, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+func rpcRecvMsg(t *kernel.Task, req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
+ s := t.NetworkContext().(*Stack)
+ id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+ <-c
+
+ res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
+ if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok {
+ return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ return res.(*pb.RecvmsgResponse_Payload).Payload, nil
+}
+
+// Because we only support SO_TIMESTAMP we will search control messages for
+// that value and set it if so, all other control messages will be ignored.
+func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_ResultPayload) socket.ControlMessages {
+ c := socket.ControlMessages{}
+ if len(payload.GetCmsgData()) > 0 {
+ // Parse the control messages looking for SO_TIMESTAMP.
+ msgs, e := syscall.ParseSocketControlMessage(payload.GetCmsgData())
+ if e != nil {
+ return socket.ControlMessages{}
+ }
+ for _, m := range msgs {
+ if m.Header.Level != linux.SOL_SOCKET || m.Header.Type != linux.SO_TIMESTAMP {
+ continue
+ }
+
+ // Let's parse the time stamp and set it.
+ if len(m.Data) < linux.SizeOfTimeval {
+ // Give up on locating the SO_TIMESTAMP option.
+ return socket.ControlMessages{}
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(m.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ c.IP.HasTimestamp = true
+ c.IP.Timestamp = v.ToNsecCapped()
+ break
+ }
+ }
+ return c
+}
+
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+ req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
+ Fd: s.fd,
+ Length: uint32(dst.NumBytes()),
+ Sender: senderRequested,
+ Trunc: flags&linux.MSG_TRUNC != 0,
+ Peek: flags&linux.MSG_PEEK != 0,
+ CmsgLength: uint32(controlDataLen),
+ }}
+
+ res, err := rpcRecvMsg(t, req)
+ if err == nil {
+ var e error
+ var n int
+ if len(res.Data) > 0 {
+ n, e = dst.CopyOut(t, res.Data)
+ if e == nil && n != len(res.Data) {
+ panic("CopyOut failed to copy full buffer")
+ }
+ }
+ c := s.extractControlMessages(res)
+ return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ }
+ if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ // We'll have to block. Register for notifications and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ res, err := rpcRecvMsg(t, req)
+ if err == nil {
+ var e error
+ var n int
+ if len(res.Data) > 0 {
+ n, e = dst.CopyOut(t, res.Data)
+ if e == nil && n != len(res.Data) {
+ panic("CopyOut failed to copy full buffer")
+ }
+ }
+ c := s.extractControlMessages(res)
+ return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ }
+ if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ if s.isShutRdSet() {
+ // Blocking would have caused us to block indefinitely so we return 0,
+ // this is the same behavior as Linux.
+ return 0, 0, nil, 0, socket.ControlMessages{}, nil
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
+ s := t.NetworkContext().(*Stack)
+ id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+ <-c
+
+ res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
+ if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok {
+ return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ return res.(*pb.SendmsgResponse_Length).Length, nil
+}
+
+// SendMsg implements socket.Socket.SendMsg.
+func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ // Whitelist flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ // Reject Unix control messages.
+ if !controlMessages.Unix.Empty() {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ v := buffer.NewView(int(src.NumBytes()))
+
+ // Copy all the data into the buffer.
+ if _, err := src.CopyIn(t, v); err != nil {
+ return 0, syserr.FromError(err)
+ }
+
+ // TODO(bgeffon): this needs to change to map directly to a SendMsg syscall
+ // in the RPC.
+ totalWritten := 0
+ n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
+ Fd: uint32(s.fd),
+ Data: v,
+ Address: to,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }})
+
+ if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
+ return int(n), err
+ }
+
+ if n > 0 {
+ totalWritten += int(n)
+ v.TrimFront(int(n))
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ for {
+ n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
+ Fd: uint32(s.fd),
+ Data: v,
+ Address: to,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }})
+
+ if n > 0 {
+ totalWritten += int(n)
+ v.TrimFront(int(n))
+
+ if err == nil && totalWritten < int(src.NumBytes()) {
+ continue
+ }
+ }
+
+ if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
+ // We eat the error in this situation.
+ return int(totalWritten), nil
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(totalWritten), syserr.ErrTryAgain
+ }
+ return int(totalWritten), syserr.FromError(err)
+ }
+ }
+}
+
+type socketProvider struct {
+ family int
+}
+
+// Socket implements socket.Provider.Socket.
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check that we are using the RPC network stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ return nil, nil
+ }
+
+ s, ok := stack.(*Stack)
+ if !ok {
+ return nil, nil
+ }
+
+ // Only accept TCP and UDP.
+ //
+ // Try to restrict the flags we will accept to minimize backwards
+ // incompatibility with netstack.
+ stype := int(stypeflags) & linux.SOCK_TYPE_MASK
+ switch stype {
+ case syscall.SOCK_STREAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_TCP:
+ // ok
+ default:
+ return nil, nil
+ }
+ case syscall.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ // ok
+ default:
+ return nil, nil
+ }
+ default:
+ return nil, nil
+ }
+
+ return newSocketFile(t, s, p.family, stype, 0)
+}
+
+// Pair implements socket.Provider.Pair.
+func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+}
+
+func init() {
+ for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
+ socket.RegisterProvider(family, &socketProvider{family})
+ }
+}
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
new file mode 100644
index 000000000..a1be711df
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/stack.go
@@ -0,0 +1,135 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpcinet
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/hostinet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+// Stack implements inet.Stack for RPC backed sockets.
+type Stack struct {
+ interfaces map[int32]inet.Interface
+ interfaceAddrs map[int32][]inet.InterfaceAddr
+ rpcConn *conn.RPCConnection
+ notifier *notifier.Notifier
+}
+
+// NewStack returns a Stack containing the current state of the host network
+// stack.
+func NewStack(fd int32) (*Stack, error) {
+ sock, err := unet.NewSocket(int(fd))
+ if err != nil {
+ return nil, err
+ }
+
+ stack := &Stack{
+ interfaces: make(map[int32]inet.Interface),
+ interfaceAddrs: make(map[int32][]inet.InterfaceAddr),
+ rpcConn: conn.NewRPCConnection(sock),
+ }
+
+ var e error
+ stack.notifier, e = notifier.NewRPCNotifier(stack.rpcConn)
+ if e != nil {
+ return nil, e
+ }
+
+ links, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETLINK)
+ if err != nil {
+ return nil, fmt.Errorf("RTM_GETLINK failed: %v", err)
+ }
+
+ addrs, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETADDR)
+ if err != nil {
+ return nil, fmt.Errorf("RTM_GETADDR failed: %v", err)
+ }
+
+ e = hostinet.ExtractHostInterfaces(links, addrs, stack.interfaces, stack.interfaceAddrs)
+ if e != nil {
+ return nil, e
+ }
+
+ return stack, nil
+}
+
+// RPCReadFile will execute the ReadFile helper RPC method which avoids the
+// common pattern of open(2), read(2), close(2) by doing all three operations
+// as a single RPC. It will read the entire file or return EFBIG if the file
+// was too large.
+func (s *Stack) RPCReadFile(path string) ([]byte, *syserr.Error) {
+ return s.rpcConn.RPCReadFile(path)
+}
+
+// RPCWriteFile will execute the WriteFile helper RPC method which avoids the
+// common pattern of open(2), write(2), write(2), close(2) by doing all
+// operations as a single RPC.
+func (s *Stack) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
+ return s.rpcConn.RPCWriteFile(path, data)
+}
+
+// Interfaces implements inet.Stack.Interfaces.
+func (s *Stack) Interfaces() map[int32]inet.Interface {
+ return s.interfaces
+}
+
+// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ return s.interfaceAddrs
+}
+
+// SupportsIPv6 implements inet.Stack.SupportsIPv6.
+func (s *Stack) SupportsIPv6() bool {
+ panic("rpcinet handles procfs directly this method should not be called")
+}
+
+// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ panic("rpcinet handles procfs directly this method should not be called")
+}
+
+// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ panic("rpcinet handles procfs directly this method should not be called")
+
+}
+
+// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ panic("rpcinet handles procfs directly this method should not be called")
+
+}
+
+// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ panic("rpcinet handles procfs directly this method should not be called")
+}
+
+// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+func (s *Stack) TCPSACKEnabled() (bool, error) {
+ panic("rpcinet handles procfs directly this method should not be called")
+}
+
+// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ panic("rpcinet handles procfs directly this method should not be called")
+}
diff --git a/pkg/sentry/socket/rpcinet/stack_unsafe.go b/pkg/sentry/socket/rpcinet/stack_unsafe.go
new file mode 100644
index 000000000..e53f578ba
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/stack_unsafe.go
@@ -0,0 +1,193 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpcinet
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+)
+
+// NewNetlinkRouteRequest builds a netlink message for getting the RIB,
+// the routing information base.
+func newNetlinkRouteRequest(proto, seq, family int) []byte {
+ rr := &syscall.NetlinkRouteRequest{}
+ rr.Header.Len = uint32(syscall.NLMSG_HDRLEN + syscall.SizeofRtGenmsg)
+ rr.Header.Type = uint16(proto)
+ rr.Header.Flags = syscall.NLM_F_DUMP | syscall.NLM_F_REQUEST
+ rr.Header.Seq = uint32(seq)
+ rr.Data.Family = uint8(family)
+ return netlinkRRtoWireFormat(rr)
+}
+
+func netlinkRRtoWireFormat(rr *syscall.NetlinkRouteRequest) []byte {
+ b := make([]byte, rr.Header.Len)
+ *(*uint32)(unsafe.Pointer(&b[0:4][0])) = rr.Header.Len
+ *(*uint16)(unsafe.Pointer(&b[4:6][0])) = rr.Header.Type
+ *(*uint16)(unsafe.Pointer(&b[6:8][0])) = rr.Header.Flags
+ *(*uint32)(unsafe.Pointer(&b[8:12][0])) = rr.Header.Seq
+ *(*uint32)(unsafe.Pointer(&b[12:16][0])) = rr.Header.Pid
+ b[16] = byte(rr.Data.Family)
+ return b
+}
+
+func (s *Stack) getNetlinkFd() (uint32, *syserr.Error) {
+ id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(syscall.AF_NETLINK), Type: int64(syscall.SOCK_RAW | syscall.SOCK_NONBLOCK), Protocol: int64(syscall.NETLINK_ROUTE)}}}, false /* ignoreResult */)
+ <-c
+
+ res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
+ if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok {
+ return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+ return res.(*pb.SocketResponse_Fd).Fd, nil
+}
+
+func (s *Stack) bindNetlinkFd(fd uint32, sockaddr []byte) *syserr.Error {
+ id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: fd, Address: sockaddr}}}, false /* ignoreResult */)
+ <-c
+
+ if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 {
+ return syserr.FromHost(syscall.Errno(e))
+ }
+ return nil
+}
+
+func (s *Stack) closeNetlinkFd(fd uint32) {
+ _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: fd}}}, true /* ignoreResult */)
+}
+
+func (s *Stack) rpcSendMsg(req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
+ id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+ <-c
+
+ res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
+ if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok {
+ return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ return res.(*pb.SendmsgResponse_Length).Length, nil
+}
+
+func (s *Stack) sendMsg(fd uint32, buf []byte, to []byte, flags int) (int, *syserr.Error) {
+ // Whitelist flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
+ Fd: fd,
+ Data: buf,
+ Address: to,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }}
+
+ n, err := s.rpcSendMsg(req)
+ return int(n), err
+}
+
+func (s *Stack) rpcRecvMsg(req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
+ id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+ <-c
+
+ res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
+ if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok {
+ return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ return res.(*pb.RecvmsgResponse_Payload).Payload, nil
+}
+
+func (s *Stack) recvMsg(fd, l, flags uint32) ([]byte, *syserr.Error) {
+ req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
+ Fd: fd,
+ Length: l,
+ Sender: false,
+ Trunc: flags&linux.MSG_TRUNC != 0,
+ Peek: flags&linux.MSG_PEEK != 0,
+ }}
+
+ res, err := s.rpcRecvMsg(req)
+ if err != nil {
+ return nil, err
+ }
+ return res.Data, nil
+}
+
+func (s *Stack) netlinkRequest(proto, family int) ([]byte, error) {
+ fd, err := s.getNetlinkFd()
+ if err != nil {
+ return nil, err.ToError()
+ }
+ defer s.closeNetlinkFd(fd)
+
+ lsa := syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
+ b := binary.Marshal(nil, usermem.ByteOrder, &lsa)
+ if err := s.bindNetlinkFd(fd, b); err != nil {
+ return nil, err.ToError()
+ }
+
+ wb := newNetlinkRouteRequest(proto, 1, family)
+ _, err = s.sendMsg(fd, wb, b, 0)
+ if err != nil {
+ return nil, err.ToError()
+ }
+
+ var tab []byte
+done:
+ for {
+ rb, err := s.recvMsg(fd, uint32(syscall.Getpagesize()), 0)
+ nr := len(rb)
+ if err != nil {
+ return nil, err.ToError()
+ }
+
+ if nr < syscall.NLMSG_HDRLEN {
+ return nil, syserr.ErrInvalidArgument.ToError()
+ }
+
+ tab = append(tab, rb...)
+ msgs, e := syscall.ParseNetlinkMessage(rb)
+ if e != nil {
+ return nil, e
+ }
+
+ for _, m := range msgs {
+ if m.Header.Type == syscall.NLMSG_DONE {
+ break done
+ }
+ if m.Header.Type == syscall.NLMSG_ERROR {
+ return nil, syserr.ErrInvalidArgument.ToError()
+ }
+ }
+ }
+
+ return tab, nil
+}
+
+// DoNetlinkRouteRequest returns routing information base, also known as RIB,
+// which consists of network facility information, states and parameters.
+func (s *Stack) DoNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
+ data, err := s.netlinkRequest(req, syscall.AF_UNSPEC)
+ if err != nil {
+ return nil, err
+ }
+ return syscall.ParseNetlinkMessage(data)
+}
diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto/syscall_rpc.pb.go b/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto/syscall_rpc.pb.go
new file mode 100755
index 000000000..fb68d5294
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto/syscall_rpc.pb.go
@@ -0,0 +1,3938 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: pkg/sentry/socket/rpcinet/syscall_rpc.proto
+
+package syscall_rpc
+
+import (
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type SendmsgRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
+ Address []byte `protobuf:"bytes,3,opt,name=address,proto3" json:"address,omitempty"`
+ More bool `protobuf:"varint,4,opt,name=more,proto3" json:"more,omitempty"`
+ EndOfRecord bool `protobuf:"varint,5,opt,name=end_of_record,json=endOfRecord,proto3" json:"end_of_record,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *SendmsgRequest) Reset() { *m = SendmsgRequest{} }
+func (m *SendmsgRequest) String() string { return proto.CompactTextString(m) }
+func (*SendmsgRequest) ProtoMessage() {}
+func (*SendmsgRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{0}
+}
+
+func (m *SendmsgRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_SendmsgRequest.Unmarshal(m, b)
+}
+func (m *SendmsgRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_SendmsgRequest.Marshal(b, m, deterministic)
+}
+func (m *SendmsgRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_SendmsgRequest.Merge(m, src)
+}
+func (m *SendmsgRequest) XXX_Size() int {
+ return xxx_messageInfo_SendmsgRequest.Size(m)
+}
+func (m *SendmsgRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_SendmsgRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_SendmsgRequest proto.InternalMessageInfo
+
+func (m *SendmsgRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *SendmsgRequest) GetData() []byte {
+ if m != nil {
+ return m.Data
+ }
+ return nil
+}
+
+func (m *SendmsgRequest) GetAddress() []byte {
+ if m != nil {
+ return m.Address
+ }
+ return nil
+}
+
+func (m *SendmsgRequest) GetMore() bool {
+ if m != nil {
+ return m.More
+ }
+ return false
+}
+
+func (m *SendmsgRequest) GetEndOfRecord() bool {
+ if m != nil {
+ return m.EndOfRecord
+ }
+ return false
+}
+
+type SendmsgResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *SendmsgResponse_ErrorNumber
+ // *SendmsgResponse_Length
+ Result isSendmsgResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *SendmsgResponse) Reset() { *m = SendmsgResponse{} }
+func (m *SendmsgResponse) String() string { return proto.CompactTextString(m) }
+func (*SendmsgResponse) ProtoMessage() {}
+func (*SendmsgResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{1}
+}
+
+func (m *SendmsgResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_SendmsgResponse.Unmarshal(m, b)
+}
+func (m *SendmsgResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_SendmsgResponse.Marshal(b, m, deterministic)
+}
+func (m *SendmsgResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_SendmsgResponse.Merge(m, src)
+}
+func (m *SendmsgResponse) XXX_Size() int {
+ return xxx_messageInfo_SendmsgResponse.Size(m)
+}
+func (m *SendmsgResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_SendmsgResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_SendmsgResponse proto.InternalMessageInfo
+
+type isSendmsgResponse_Result interface {
+ isSendmsgResponse_Result()
+}
+
+type SendmsgResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type SendmsgResponse_Length struct {
+ Length uint32 `protobuf:"varint,2,opt,name=length,proto3,oneof"`
+}
+
+func (*SendmsgResponse_ErrorNumber) isSendmsgResponse_Result() {}
+
+func (*SendmsgResponse_Length) isSendmsgResponse_Result() {}
+
+func (m *SendmsgResponse) GetResult() isSendmsgResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *SendmsgResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*SendmsgResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *SendmsgResponse) GetLength() uint32 {
+ if x, ok := m.GetResult().(*SendmsgResponse_Length); ok {
+ return x.Length
+ }
+ return 0
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*SendmsgResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*SendmsgResponse_ErrorNumber)(nil),
+ (*SendmsgResponse_Length)(nil),
+ }
+}
+
+type IOCtlRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Cmd uint32 `protobuf:"varint,2,opt,name=cmd,proto3" json:"cmd,omitempty"`
+ Arg []byte `protobuf:"bytes,3,opt,name=arg,proto3" json:"arg,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *IOCtlRequest) Reset() { *m = IOCtlRequest{} }
+func (m *IOCtlRequest) String() string { return proto.CompactTextString(m) }
+func (*IOCtlRequest) ProtoMessage() {}
+func (*IOCtlRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{2}
+}
+
+func (m *IOCtlRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_IOCtlRequest.Unmarshal(m, b)
+}
+func (m *IOCtlRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_IOCtlRequest.Marshal(b, m, deterministic)
+}
+func (m *IOCtlRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_IOCtlRequest.Merge(m, src)
+}
+func (m *IOCtlRequest) XXX_Size() int {
+ return xxx_messageInfo_IOCtlRequest.Size(m)
+}
+func (m *IOCtlRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_IOCtlRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_IOCtlRequest proto.InternalMessageInfo
+
+func (m *IOCtlRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *IOCtlRequest) GetCmd() uint32 {
+ if m != nil {
+ return m.Cmd
+ }
+ return 0
+}
+
+func (m *IOCtlRequest) GetArg() []byte {
+ if m != nil {
+ return m.Arg
+ }
+ return nil
+}
+
+type IOCtlResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *IOCtlResponse_ErrorNumber
+ // *IOCtlResponse_Value
+ Result isIOCtlResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *IOCtlResponse) Reset() { *m = IOCtlResponse{} }
+func (m *IOCtlResponse) String() string { return proto.CompactTextString(m) }
+func (*IOCtlResponse) ProtoMessage() {}
+func (*IOCtlResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{3}
+}
+
+func (m *IOCtlResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_IOCtlResponse.Unmarshal(m, b)
+}
+func (m *IOCtlResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_IOCtlResponse.Marshal(b, m, deterministic)
+}
+func (m *IOCtlResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_IOCtlResponse.Merge(m, src)
+}
+func (m *IOCtlResponse) XXX_Size() int {
+ return xxx_messageInfo_IOCtlResponse.Size(m)
+}
+func (m *IOCtlResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_IOCtlResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_IOCtlResponse proto.InternalMessageInfo
+
+type isIOCtlResponse_Result interface {
+ isIOCtlResponse_Result()
+}
+
+type IOCtlResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type IOCtlResponse_Value struct {
+ Value []byte `protobuf:"bytes,2,opt,name=value,proto3,oneof"`
+}
+
+func (*IOCtlResponse_ErrorNumber) isIOCtlResponse_Result() {}
+
+func (*IOCtlResponse_Value) isIOCtlResponse_Result() {}
+
+func (m *IOCtlResponse) GetResult() isIOCtlResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *IOCtlResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*IOCtlResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *IOCtlResponse) GetValue() []byte {
+ if x, ok := m.GetResult().(*IOCtlResponse_Value); ok {
+ return x.Value
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*IOCtlResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*IOCtlResponse_ErrorNumber)(nil),
+ (*IOCtlResponse_Value)(nil),
+ }
+}
+
+type RecvmsgRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Length uint32 `protobuf:"varint,2,opt,name=length,proto3" json:"length,omitempty"`
+ Sender bool `protobuf:"varint,3,opt,name=sender,proto3" json:"sender,omitempty"`
+ Peek bool `protobuf:"varint,4,opt,name=peek,proto3" json:"peek,omitempty"`
+ Trunc bool `protobuf:"varint,5,opt,name=trunc,proto3" json:"trunc,omitempty"`
+ CmsgLength uint32 `protobuf:"varint,6,opt,name=cmsg_length,json=cmsgLength,proto3" json:"cmsg_length,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *RecvmsgRequest) Reset() { *m = RecvmsgRequest{} }
+func (m *RecvmsgRequest) String() string { return proto.CompactTextString(m) }
+func (*RecvmsgRequest) ProtoMessage() {}
+func (*RecvmsgRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{4}
+}
+
+func (m *RecvmsgRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_RecvmsgRequest.Unmarshal(m, b)
+}
+func (m *RecvmsgRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_RecvmsgRequest.Marshal(b, m, deterministic)
+}
+func (m *RecvmsgRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_RecvmsgRequest.Merge(m, src)
+}
+func (m *RecvmsgRequest) XXX_Size() int {
+ return xxx_messageInfo_RecvmsgRequest.Size(m)
+}
+func (m *RecvmsgRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_RecvmsgRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_RecvmsgRequest proto.InternalMessageInfo
+
+func (m *RecvmsgRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *RecvmsgRequest) GetLength() uint32 {
+ if m != nil {
+ return m.Length
+ }
+ return 0
+}
+
+func (m *RecvmsgRequest) GetSender() bool {
+ if m != nil {
+ return m.Sender
+ }
+ return false
+}
+
+func (m *RecvmsgRequest) GetPeek() bool {
+ if m != nil {
+ return m.Peek
+ }
+ return false
+}
+
+func (m *RecvmsgRequest) GetTrunc() bool {
+ if m != nil {
+ return m.Trunc
+ }
+ return false
+}
+
+func (m *RecvmsgRequest) GetCmsgLength() uint32 {
+ if m != nil {
+ return m.CmsgLength
+ }
+ return 0
+}
+
+type OpenRequest struct {
+ Path []byte `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
+ Flags uint32 `protobuf:"varint,2,opt,name=flags,proto3" json:"flags,omitempty"`
+ Mode uint32 `protobuf:"varint,3,opt,name=mode,proto3" json:"mode,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *OpenRequest) Reset() { *m = OpenRequest{} }
+func (m *OpenRequest) String() string { return proto.CompactTextString(m) }
+func (*OpenRequest) ProtoMessage() {}
+func (*OpenRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{5}
+}
+
+func (m *OpenRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_OpenRequest.Unmarshal(m, b)
+}
+func (m *OpenRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_OpenRequest.Marshal(b, m, deterministic)
+}
+func (m *OpenRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_OpenRequest.Merge(m, src)
+}
+func (m *OpenRequest) XXX_Size() int {
+ return xxx_messageInfo_OpenRequest.Size(m)
+}
+func (m *OpenRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_OpenRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_OpenRequest proto.InternalMessageInfo
+
+func (m *OpenRequest) GetPath() []byte {
+ if m != nil {
+ return m.Path
+ }
+ return nil
+}
+
+func (m *OpenRequest) GetFlags() uint32 {
+ if m != nil {
+ return m.Flags
+ }
+ return 0
+}
+
+func (m *OpenRequest) GetMode() uint32 {
+ if m != nil {
+ return m.Mode
+ }
+ return 0
+}
+
+type OpenResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *OpenResponse_ErrorNumber
+ // *OpenResponse_Fd
+ Result isOpenResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *OpenResponse) Reset() { *m = OpenResponse{} }
+func (m *OpenResponse) String() string { return proto.CompactTextString(m) }
+func (*OpenResponse) ProtoMessage() {}
+func (*OpenResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{6}
+}
+
+func (m *OpenResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_OpenResponse.Unmarshal(m, b)
+}
+func (m *OpenResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_OpenResponse.Marshal(b, m, deterministic)
+}
+func (m *OpenResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_OpenResponse.Merge(m, src)
+}
+func (m *OpenResponse) XXX_Size() int {
+ return xxx_messageInfo_OpenResponse.Size(m)
+}
+func (m *OpenResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_OpenResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_OpenResponse proto.InternalMessageInfo
+
+type isOpenResponse_Result interface {
+ isOpenResponse_Result()
+}
+
+type OpenResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type OpenResponse_Fd struct {
+ Fd uint32 `protobuf:"varint,2,opt,name=fd,proto3,oneof"`
+}
+
+func (*OpenResponse_ErrorNumber) isOpenResponse_Result() {}
+
+func (*OpenResponse_Fd) isOpenResponse_Result() {}
+
+func (m *OpenResponse) GetResult() isOpenResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *OpenResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*OpenResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *OpenResponse) GetFd() uint32 {
+ if x, ok := m.GetResult().(*OpenResponse_Fd); ok {
+ return x.Fd
+ }
+ return 0
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*OpenResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*OpenResponse_ErrorNumber)(nil),
+ (*OpenResponse_Fd)(nil),
+ }
+}
+
+type ReadRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Length uint32 `protobuf:"varint,2,opt,name=length,proto3" json:"length,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ReadRequest) Reset() { *m = ReadRequest{} }
+func (m *ReadRequest) String() string { return proto.CompactTextString(m) }
+func (*ReadRequest) ProtoMessage() {}
+func (*ReadRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{7}
+}
+
+func (m *ReadRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ReadRequest.Unmarshal(m, b)
+}
+func (m *ReadRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ReadRequest.Marshal(b, m, deterministic)
+}
+func (m *ReadRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ReadRequest.Merge(m, src)
+}
+func (m *ReadRequest) XXX_Size() int {
+ return xxx_messageInfo_ReadRequest.Size(m)
+}
+func (m *ReadRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_ReadRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ReadRequest proto.InternalMessageInfo
+
+func (m *ReadRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *ReadRequest) GetLength() uint32 {
+ if m != nil {
+ return m.Length
+ }
+ return 0
+}
+
+type ReadResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *ReadResponse_ErrorNumber
+ // *ReadResponse_Data
+ Result isReadResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ReadResponse) Reset() { *m = ReadResponse{} }
+func (m *ReadResponse) String() string { return proto.CompactTextString(m) }
+func (*ReadResponse) ProtoMessage() {}
+func (*ReadResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{8}
+}
+
+func (m *ReadResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ReadResponse.Unmarshal(m, b)
+}
+func (m *ReadResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ReadResponse.Marshal(b, m, deterministic)
+}
+func (m *ReadResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ReadResponse.Merge(m, src)
+}
+func (m *ReadResponse) XXX_Size() int {
+ return xxx_messageInfo_ReadResponse.Size(m)
+}
+func (m *ReadResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_ReadResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ReadResponse proto.InternalMessageInfo
+
+type isReadResponse_Result interface {
+ isReadResponse_Result()
+}
+
+type ReadResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type ReadResponse_Data struct {
+ Data []byte `protobuf:"bytes,2,opt,name=data,proto3,oneof"`
+}
+
+func (*ReadResponse_ErrorNumber) isReadResponse_Result() {}
+
+func (*ReadResponse_Data) isReadResponse_Result() {}
+
+func (m *ReadResponse) GetResult() isReadResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *ReadResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*ReadResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *ReadResponse) GetData() []byte {
+ if x, ok := m.GetResult().(*ReadResponse_Data); ok {
+ return x.Data
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*ReadResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*ReadResponse_ErrorNumber)(nil),
+ (*ReadResponse_Data)(nil),
+ }
+}
+
+type ReadFileRequest struct {
+ Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ReadFileRequest) Reset() { *m = ReadFileRequest{} }
+func (m *ReadFileRequest) String() string { return proto.CompactTextString(m) }
+func (*ReadFileRequest) ProtoMessage() {}
+func (*ReadFileRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{9}
+}
+
+func (m *ReadFileRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ReadFileRequest.Unmarshal(m, b)
+}
+func (m *ReadFileRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ReadFileRequest.Marshal(b, m, deterministic)
+}
+func (m *ReadFileRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ReadFileRequest.Merge(m, src)
+}
+func (m *ReadFileRequest) XXX_Size() int {
+ return xxx_messageInfo_ReadFileRequest.Size(m)
+}
+func (m *ReadFileRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_ReadFileRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ReadFileRequest proto.InternalMessageInfo
+
+func (m *ReadFileRequest) GetPath() string {
+ if m != nil {
+ return m.Path
+ }
+ return ""
+}
+
+type ReadFileResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *ReadFileResponse_ErrorNumber
+ // *ReadFileResponse_Data
+ Result isReadFileResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ReadFileResponse) Reset() { *m = ReadFileResponse{} }
+func (m *ReadFileResponse) String() string { return proto.CompactTextString(m) }
+func (*ReadFileResponse) ProtoMessage() {}
+func (*ReadFileResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{10}
+}
+
+func (m *ReadFileResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ReadFileResponse.Unmarshal(m, b)
+}
+func (m *ReadFileResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ReadFileResponse.Marshal(b, m, deterministic)
+}
+func (m *ReadFileResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ReadFileResponse.Merge(m, src)
+}
+func (m *ReadFileResponse) XXX_Size() int {
+ return xxx_messageInfo_ReadFileResponse.Size(m)
+}
+func (m *ReadFileResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_ReadFileResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ReadFileResponse proto.InternalMessageInfo
+
+type isReadFileResponse_Result interface {
+ isReadFileResponse_Result()
+}
+
+type ReadFileResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type ReadFileResponse_Data struct {
+ Data []byte `protobuf:"bytes,2,opt,name=data,proto3,oneof"`
+}
+
+func (*ReadFileResponse_ErrorNumber) isReadFileResponse_Result() {}
+
+func (*ReadFileResponse_Data) isReadFileResponse_Result() {}
+
+func (m *ReadFileResponse) GetResult() isReadFileResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *ReadFileResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*ReadFileResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *ReadFileResponse) GetData() []byte {
+ if x, ok := m.GetResult().(*ReadFileResponse_Data); ok {
+ return x.Data
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*ReadFileResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*ReadFileResponse_ErrorNumber)(nil),
+ (*ReadFileResponse_Data)(nil),
+ }
+}
+
+type WriteRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *WriteRequest) Reset() { *m = WriteRequest{} }
+func (m *WriteRequest) String() string { return proto.CompactTextString(m) }
+func (*WriteRequest) ProtoMessage() {}
+func (*WriteRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{11}
+}
+
+func (m *WriteRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_WriteRequest.Unmarshal(m, b)
+}
+func (m *WriteRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_WriteRequest.Marshal(b, m, deterministic)
+}
+func (m *WriteRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_WriteRequest.Merge(m, src)
+}
+func (m *WriteRequest) XXX_Size() int {
+ return xxx_messageInfo_WriteRequest.Size(m)
+}
+func (m *WriteRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_WriteRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_WriteRequest proto.InternalMessageInfo
+
+func (m *WriteRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *WriteRequest) GetData() []byte {
+ if m != nil {
+ return m.Data
+ }
+ return nil
+}
+
+type WriteResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *WriteResponse_ErrorNumber
+ // *WriteResponse_Length
+ Result isWriteResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *WriteResponse) Reset() { *m = WriteResponse{} }
+func (m *WriteResponse) String() string { return proto.CompactTextString(m) }
+func (*WriteResponse) ProtoMessage() {}
+func (*WriteResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{12}
+}
+
+func (m *WriteResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_WriteResponse.Unmarshal(m, b)
+}
+func (m *WriteResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_WriteResponse.Marshal(b, m, deterministic)
+}
+func (m *WriteResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_WriteResponse.Merge(m, src)
+}
+func (m *WriteResponse) XXX_Size() int {
+ return xxx_messageInfo_WriteResponse.Size(m)
+}
+func (m *WriteResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_WriteResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_WriteResponse proto.InternalMessageInfo
+
+type isWriteResponse_Result interface {
+ isWriteResponse_Result()
+}
+
+type WriteResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type WriteResponse_Length struct {
+ Length uint32 `protobuf:"varint,2,opt,name=length,proto3,oneof"`
+}
+
+func (*WriteResponse_ErrorNumber) isWriteResponse_Result() {}
+
+func (*WriteResponse_Length) isWriteResponse_Result() {}
+
+func (m *WriteResponse) GetResult() isWriteResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *WriteResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*WriteResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *WriteResponse) GetLength() uint32 {
+ if x, ok := m.GetResult().(*WriteResponse_Length); ok {
+ return x.Length
+ }
+ return 0
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*WriteResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*WriteResponse_ErrorNumber)(nil),
+ (*WriteResponse_Length)(nil),
+ }
+}
+
+type WriteFileRequest struct {
+ Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
+ Content []byte `protobuf:"bytes,2,opt,name=content,proto3" json:"content,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *WriteFileRequest) Reset() { *m = WriteFileRequest{} }
+func (m *WriteFileRequest) String() string { return proto.CompactTextString(m) }
+func (*WriteFileRequest) ProtoMessage() {}
+func (*WriteFileRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{13}
+}
+
+func (m *WriteFileRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_WriteFileRequest.Unmarshal(m, b)
+}
+func (m *WriteFileRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_WriteFileRequest.Marshal(b, m, deterministic)
+}
+func (m *WriteFileRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_WriteFileRequest.Merge(m, src)
+}
+func (m *WriteFileRequest) XXX_Size() int {
+ return xxx_messageInfo_WriteFileRequest.Size(m)
+}
+func (m *WriteFileRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_WriteFileRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_WriteFileRequest proto.InternalMessageInfo
+
+func (m *WriteFileRequest) GetPath() string {
+ if m != nil {
+ return m.Path
+ }
+ return ""
+}
+
+func (m *WriteFileRequest) GetContent() []byte {
+ if m != nil {
+ return m.Content
+ }
+ return nil
+}
+
+type WriteFileResponse struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"`
+ Written uint32 `protobuf:"varint,2,opt,name=written,proto3" json:"written,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *WriteFileResponse) Reset() { *m = WriteFileResponse{} }
+func (m *WriteFileResponse) String() string { return proto.CompactTextString(m) }
+func (*WriteFileResponse) ProtoMessage() {}
+func (*WriteFileResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{14}
+}
+
+func (m *WriteFileResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_WriteFileResponse.Unmarshal(m, b)
+}
+func (m *WriteFileResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_WriteFileResponse.Marshal(b, m, deterministic)
+}
+func (m *WriteFileResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_WriteFileResponse.Merge(m, src)
+}
+func (m *WriteFileResponse) XXX_Size() int {
+ return xxx_messageInfo_WriteFileResponse.Size(m)
+}
+func (m *WriteFileResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_WriteFileResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_WriteFileResponse proto.InternalMessageInfo
+
+func (m *WriteFileResponse) GetErrorNumber() uint32 {
+ if m != nil {
+ return m.ErrorNumber
+ }
+ return 0
+}
+
+func (m *WriteFileResponse) GetWritten() uint32 {
+ if m != nil {
+ return m.Written
+ }
+ return 0
+}
+
+type AddressResponse struct {
+ Address []byte `protobuf:"bytes,1,opt,name=address,proto3" json:"address,omitempty"`
+ Length uint32 `protobuf:"varint,2,opt,name=length,proto3" json:"length,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *AddressResponse) Reset() { *m = AddressResponse{} }
+func (m *AddressResponse) String() string { return proto.CompactTextString(m) }
+func (*AddressResponse) ProtoMessage() {}
+func (*AddressResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{15}
+}
+
+func (m *AddressResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_AddressResponse.Unmarshal(m, b)
+}
+func (m *AddressResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_AddressResponse.Marshal(b, m, deterministic)
+}
+func (m *AddressResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_AddressResponse.Merge(m, src)
+}
+func (m *AddressResponse) XXX_Size() int {
+ return xxx_messageInfo_AddressResponse.Size(m)
+}
+func (m *AddressResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_AddressResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_AddressResponse proto.InternalMessageInfo
+
+func (m *AddressResponse) GetAddress() []byte {
+ if m != nil {
+ return m.Address
+ }
+ return nil
+}
+
+func (m *AddressResponse) GetLength() uint32 {
+ if m != nil {
+ return m.Length
+ }
+ return 0
+}
+
+type RecvmsgResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *RecvmsgResponse_ErrorNumber
+ // *RecvmsgResponse_Payload
+ Result isRecvmsgResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *RecvmsgResponse) Reset() { *m = RecvmsgResponse{} }
+func (m *RecvmsgResponse) String() string { return proto.CompactTextString(m) }
+func (*RecvmsgResponse) ProtoMessage() {}
+func (*RecvmsgResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{16}
+}
+
+func (m *RecvmsgResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_RecvmsgResponse.Unmarshal(m, b)
+}
+func (m *RecvmsgResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_RecvmsgResponse.Marshal(b, m, deterministic)
+}
+func (m *RecvmsgResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_RecvmsgResponse.Merge(m, src)
+}
+func (m *RecvmsgResponse) XXX_Size() int {
+ return xxx_messageInfo_RecvmsgResponse.Size(m)
+}
+func (m *RecvmsgResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_RecvmsgResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_RecvmsgResponse proto.InternalMessageInfo
+
+type isRecvmsgResponse_Result interface {
+ isRecvmsgResponse_Result()
+}
+
+type RecvmsgResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type RecvmsgResponse_Payload struct {
+ Payload *RecvmsgResponse_ResultPayload `protobuf:"bytes,2,opt,name=payload,proto3,oneof"`
+}
+
+func (*RecvmsgResponse_ErrorNumber) isRecvmsgResponse_Result() {}
+
+func (*RecvmsgResponse_Payload) isRecvmsgResponse_Result() {}
+
+func (m *RecvmsgResponse) GetResult() isRecvmsgResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *RecvmsgResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*RecvmsgResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *RecvmsgResponse) GetPayload() *RecvmsgResponse_ResultPayload {
+ if x, ok := m.GetResult().(*RecvmsgResponse_Payload); ok {
+ return x.Payload
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*RecvmsgResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*RecvmsgResponse_ErrorNumber)(nil),
+ (*RecvmsgResponse_Payload)(nil),
+ }
+}
+
+type RecvmsgResponse_ResultPayload struct {
+ Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
+ Address *AddressResponse `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
+ Length uint32 `protobuf:"varint,3,opt,name=length,proto3" json:"length,omitempty"`
+ CmsgData []byte `protobuf:"bytes,4,opt,name=cmsg_data,json=cmsgData,proto3" json:"cmsg_data,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *RecvmsgResponse_ResultPayload) Reset() { *m = RecvmsgResponse_ResultPayload{} }
+func (m *RecvmsgResponse_ResultPayload) String() string { return proto.CompactTextString(m) }
+func (*RecvmsgResponse_ResultPayload) ProtoMessage() {}
+func (*RecvmsgResponse_ResultPayload) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{16, 0}
+}
+
+func (m *RecvmsgResponse_ResultPayload) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_RecvmsgResponse_ResultPayload.Unmarshal(m, b)
+}
+func (m *RecvmsgResponse_ResultPayload) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_RecvmsgResponse_ResultPayload.Marshal(b, m, deterministic)
+}
+func (m *RecvmsgResponse_ResultPayload) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_RecvmsgResponse_ResultPayload.Merge(m, src)
+}
+func (m *RecvmsgResponse_ResultPayload) XXX_Size() int {
+ return xxx_messageInfo_RecvmsgResponse_ResultPayload.Size(m)
+}
+func (m *RecvmsgResponse_ResultPayload) XXX_DiscardUnknown() {
+ xxx_messageInfo_RecvmsgResponse_ResultPayload.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_RecvmsgResponse_ResultPayload proto.InternalMessageInfo
+
+func (m *RecvmsgResponse_ResultPayload) GetData() []byte {
+ if m != nil {
+ return m.Data
+ }
+ return nil
+}
+
+func (m *RecvmsgResponse_ResultPayload) GetAddress() *AddressResponse {
+ if m != nil {
+ return m.Address
+ }
+ return nil
+}
+
+func (m *RecvmsgResponse_ResultPayload) GetLength() uint32 {
+ if m != nil {
+ return m.Length
+ }
+ return 0
+}
+
+func (m *RecvmsgResponse_ResultPayload) GetCmsgData() []byte {
+ if m != nil {
+ return m.CmsgData
+ }
+ return nil
+}
+
+type BindRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Address []byte `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *BindRequest) Reset() { *m = BindRequest{} }
+func (m *BindRequest) String() string { return proto.CompactTextString(m) }
+func (*BindRequest) ProtoMessage() {}
+func (*BindRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{17}
+}
+
+func (m *BindRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_BindRequest.Unmarshal(m, b)
+}
+func (m *BindRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_BindRequest.Marshal(b, m, deterministic)
+}
+func (m *BindRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_BindRequest.Merge(m, src)
+}
+func (m *BindRequest) XXX_Size() int {
+ return xxx_messageInfo_BindRequest.Size(m)
+}
+func (m *BindRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_BindRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_BindRequest proto.InternalMessageInfo
+
+func (m *BindRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *BindRequest) GetAddress() []byte {
+ if m != nil {
+ return m.Address
+ }
+ return nil
+}
+
+type BindResponse struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *BindResponse) Reset() { *m = BindResponse{} }
+func (m *BindResponse) String() string { return proto.CompactTextString(m) }
+func (*BindResponse) ProtoMessage() {}
+func (*BindResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{18}
+}
+
+func (m *BindResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_BindResponse.Unmarshal(m, b)
+}
+func (m *BindResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_BindResponse.Marshal(b, m, deterministic)
+}
+func (m *BindResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_BindResponse.Merge(m, src)
+}
+func (m *BindResponse) XXX_Size() int {
+ return xxx_messageInfo_BindResponse.Size(m)
+}
+func (m *BindResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_BindResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_BindResponse proto.InternalMessageInfo
+
+func (m *BindResponse) GetErrorNumber() uint32 {
+ if m != nil {
+ return m.ErrorNumber
+ }
+ return 0
+}
+
+type AcceptRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Peer bool `protobuf:"varint,2,opt,name=peer,proto3" json:"peer,omitempty"`
+ Flags int64 `protobuf:"varint,3,opt,name=flags,proto3" json:"flags,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *AcceptRequest) Reset() { *m = AcceptRequest{} }
+func (m *AcceptRequest) String() string { return proto.CompactTextString(m) }
+func (*AcceptRequest) ProtoMessage() {}
+func (*AcceptRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{19}
+}
+
+func (m *AcceptRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_AcceptRequest.Unmarshal(m, b)
+}
+func (m *AcceptRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_AcceptRequest.Marshal(b, m, deterministic)
+}
+func (m *AcceptRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_AcceptRequest.Merge(m, src)
+}
+func (m *AcceptRequest) XXX_Size() int {
+ return xxx_messageInfo_AcceptRequest.Size(m)
+}
+func (m *AcceptRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_AcceptRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_AcceptRequest proto.InternalMessageInfo
+
+func (m *AcceptRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *AcceptRequest) GetPeer() bool {
+ if m != nil {
+ return m.Peer
+ }
+ return false
+}
+
+func (m *AcceptRequest) GetFlags() int64 {
+ if m != nil {
+ return m.Flags
+ }
+ return 0
+}
+
+type AcceptResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *AcceptResponse_ErrorNumber
+ // *AcceptResponse_Payload
+ Result isAcceptResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *AcceptResponse) Reset() { *m = AcceptResponse{} }
+func (m *AcceptResponse) String() string { return proto.CompactTextString(m) }
+func (*AcceptResponse) ProtoMessage() {}
+func (*AcceptResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{20}
+}
+
+func (m *AcceptResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_AcceptResponse.Unmarshal(m, b)
+}
+func (m *AcceptResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_AcceptResponse.Marshal(b, m, deterministic)
+}
+func (m *AcceptResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_AcceptResponse.Merge(m, src)
+}
+func (m *AcceptResponse) XXX_Size() int {
+ return xxx_messageInfo_AcceptResponse.Size(m)
+}
+func (m *AcceptResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_AcceptResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_AcceptResponse proto.InternalMessageInfo
+
+type isAcceptResponse_Result interface {
+ isAcceptResponse_Result()
+}
+
+type AcceptResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type AcceptResponse_Payload struct {
+ Payload *AcceptResponse_ResultPayload `protobuf:"bytes,2,opt,name=payload,proto3,oneof"`
+}
+
+func (*AcceptResponse_ErrorNumber) isAcceptResponse_Result() {}
+
+func (*AcceptResponse_Payload) isAcceptResponse_Result() {}
+
+func (m *AcceptResponse) GetResult() isAcceptResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *AcceptResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*AcceptResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *AcceptResponse) GetPayload() *AcceptResponse_ResultPayload {
+ if x, ok := m.GetResult().(*AcceptResponse_Payload); ok {
+ return x.Payload
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*AcceptResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*AcceptResponse_ErrorNumber)(nil),
+ (*AcceptResponse_Payload)(nil),
+ }
+}
+
+type AcceptResponse_ResultPayload struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Address *AddressResponse `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *AcceptResponse_ResultPayload) Reset() { *m = AcceptResponse_ResultPayload{} }
+func (m *AcceptResponse_ResultPayload) String() string { return proto.CompactTextString(m) }
+func (*AcceptResponse_ResultPayload) ProtoMessage() {}
+func (*AcceptResponse_ResultPayload) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{20, 0}
+}
+
+func (m *AcceptResponse_ResultPayload) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_AcceptResponse_ResultPayload.Unmarshal(m, b)
+}
+func (m *AcceptResponse_ResultPayload) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_AcceptResponse_ResultPayload.Marshal(b, m, deterministic)
+}
+func (m *AcceptResponse_ResultPayload) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_AcceptResponse_ResultPayload.Merge(m, src)
+}
+func (m *AcceptResponse_ResultPayload) XXX_Size() int {
+ return xxx_messageInfo_AcceptResponse_ResultPayload.Size(m)
+}
+func (m *AcceptResponse_ResultPayload) XXX_DiscardUnknown() {
+ xxx_messageInfo_AcceptResponse_ResultPayload.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_AcceptResponse_ResultPayload proto.InternalMessageInfo
+
+func (m *AcceptResponse_ResultPayload) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *AcceptResponse_ResultPayload) GetAddress() *AddressResponse {
+ if m != nil {
+ return m.Address
+ }
+ return nil
+}
+
+type ConnectRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Address []byte `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ConnectRequest) Reset() { *m = ConnectRequest{} }
+func (m *ConnectRequest) String() string { return proto.CompactTextString(m) }
+func (*ConnectRequest) ProtoMessage() {}
+func (*ConnectRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{21}
+}
+
+func (m *ConnectRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ConnectRequest.Unmarshal(m, b)
+}
+func (m *ConnectRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ConnectRequest.Marshal(b, m, deterministic)
+}
+func (m *ConnectRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ConnectRequest.Merge(m, src)
+}
+func (m *ConnectRequest) XXX_Size() int {
+ return xxx_messageInfo_ConnectRequest.Size(m)
+}
+func (m *ConnectRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_ConnectRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ConnectRequest proto.InternalMessageInfo
+
+func (m *ConnectRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *ConnectRequest) GetAddress() []byte {
+ if m != nil {
+ return m.Address
+ }
+ return nil
+}
+
+type ConnectResponse struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ConnectResponse) Reset() { *m = ConnectResponse{} }
+func (m *ConnectResponse) String() string { return proto.CompactTextString(m) }
+func (*ConnectResponse) ProtoMessage() {}
+func (*ConnectResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{22}
+}
+
+func (m *ConnectResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ConnectResponse.Unmarshal(m, b)
+}
+func (m *ConnectResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ConnectResponse.Marshal(b, m, deterministic)
+}
+func (m *ConnectResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ConnectResponse.Merge(m, src)
+}
+func (m *ConnectResponse) XXX_Size() int {
+ return xxx_messageInfo_ConnectResponse.Size(m)
+}
+func (m *ConnectResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_ConnectResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ConnectResponse proto.InternalMessageInfo
+
+func (m *ConnectResponse) GetErrorNumber() uint32 {
+ if m != nil {
+ return m.ErrorNumber
+ }
+ return 0
+}
+
+type ListenRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Backlog int64 `protobuf:"varint,2,opt,name=backlog,proto3" json:"backlog,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ListenRequest) Reset() { *m = ListenRequest{} }
+func (m *ListenRequest) String() string { return proto.CompactTextString(m) }
+func (*ListenRequest) ProtoMessage() {}
+func (*ListenRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{23}
+}
+
+func (m *ListenRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ListenRequest.Unmarshal(m, b)
+}
+func (m *ListenRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ListenRequest.Marshal(b, m, deterministic)
+}
+func (m *ListenRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ListenRequest.Merge(m, src)
+}
+func (m *ListenRequest) XXX_Size() int {
+ return xxx_messageInfo_ListenRequest.Size(m)
+}
+func (m *ListenRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_ListenRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ListenRequest proto.InternalMessageInfo
+
+func (m *ListenRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *ListenRequest) GetBacklog() int64 {
+ if m != nil {
+ return m.Backlog
+ }
+ return 0
+}
+
+type ListenResponse struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ListenResponse) Reset() { *m = ListenResponse{} }
+func (m *ListenResponse) String() string { return proto.CompactTextString(m) }
+func (*ListenResponse) ProtoMessage() {}
+func (*ListenResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{24}
+}
+
+func (m *ListenResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ListenResponse.Unmarshal(m, b)
+}
+func (m *ListenResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ListenResponse.Marshal(b, m, deterministic)
+}
+func (m *ListenResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ListenResponse.Merge(m, src)
+}
+func (m *ListenResponse) XXX_Size() int {
+ return xxx_messageInfo_ListenResponse.Size(m)
+}
+func (m *ListenResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_ListenResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ListenResponse proto.InternalMessageInfo
+
+func (m *ListenResponse) GetErrorNumber() uint32 {
+ if m != nil {
+ return m.ErrorNumber
+ }
+ return 0
+}
+
+type ShutdownRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ How int64 `protobuf:"varint,2,opt,name=how,proto3" json:"how,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ShutdownRequest) Reset() { *m = ShutdownRequest{} }
+func (m *ShutdownRequest) String() string { return proto.CompactTextString(m) }
+func (*ShutdownRequest) ProtoMessage() {}
+func (*ShutdownRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{25}
+}
+
+func (m *ShutdownRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ShutdownRequest.Unmarshal(m, b)
+}
+func (m *ShutdownRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ShutdownRequest.Marshal(b, m, deterministic)
+}
+func (m *ShutdownRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ShutdownRequest.Merge(m, src)
+}
+func (m *ShutdownRequest) XXX_Size() int {
+ return xxx_messageInfo_ShutdownRequest.Size(m)
+}
+func (m *ShutdownRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_ShutdownRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ShutdownRequest proto.InternalMessageInfo
+
+func (m *ShutdownRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *ShutdownRequest) GetHow() int64 {
+ if m != nil {
+ return m.How
+ }
+ return 0
+}
+
+type ShutdownResponse struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *ShutdownResponse) Reset() { *m = ShutdownResponse{} }
+func (m *ShutdownResponse) String() string { return proto.CompactTextString(m) }
+func (*ShutdownResponse) ProtoMessage() {}
+func (*ShutdownResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{26}
+}
+
+func (m *ShutdownResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_ShutdownResponse.Unmarshal(m, b)
+}
+func (m *ShutdownResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_ShutdownResponse.Marshal(b, m, deterministic)
+}
+func (m *ShutdownResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_ShutdownResponse.Merge(m, src)
+}
+func (m *ShutdownResponse) XXX_Size() int {
+ return xxx_messageInfo_ShutdownResponse.Size(m)
+}
+func (m *ShutdownResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_ShutdownResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_ShutdownResponse proto.InternalMessageInfo
+
+func (m *ShutdownResponse) GetErrorNumber() uint32 {
+ if m != nil {
+ return m.ErrorNumber
+ }
+ return 0
+}
+
+type CloseRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *CloseRequest) Reset() { *m = CloseRequest{} }
+func (m *CloseRequest) String() string { return proto.CompactTextString(m) }
+func (*CloseRequest) ProtoMessage() {}
+func (*CloseRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{27}
+}
+
+func (m *CloseRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_CloseRequest.Unmarshal(m, b)
+}
+func (m *CloseRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_CloseRequest.Marshal(b, m, deterministic)
+}
+func (m *CloseRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_CloseRequest.Merge(m, src)
+}
+func (m *CloseRequest) XXX_Size() int {
+ return xxx_messageInfo_CloseRequest.Size(m)
+}
+func (m *CloseRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_CloseRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_CloseRequest proto.InternalMessageInfo
+
+func (m *CloseRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+type CloseResponse struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *CloseResponse) Reset() { *m = CloseResponse{} }
+func (m *CloseResponse) String() string { return proto.CompactTextString(m) }
+func (*CloseResponse) ProtoMessage() {}
+func (*CloseResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{28}
+}
+
+func (m *CloseResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_CloseResponse.Unmarshal(m, b)
+}
+func (m *CloseResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_CloseResponse.Marshal(b, m, deterministic)
+}
+func (m *CloseResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_CloseResponse.Merge(m, src)
+}
+func (m *CloseResponse) XXX_Size() int {
+ return xxx_messageInfo_CloseResponse.Size(m)
+}
+func (m *CloseResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_CloseResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_CloseResponse proto.InternalMessageInfo
+
+func (m *CloseResponse) GetErrorNumber() uint32 {
+ if m != nil {
+ return m.ErrorNumber
+ }
+ return 0
+}
+
+type GetSockOptRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Level int64 `protobuf:"varint,2,opt,name=level,proto3" json:"level,omitempty"`
+ Name int64 `protobuf:"varint,3,opt,name=name,proto3" json:"name,omitempty"`
+ Length uint32 `protobuf:"varint,4,opt,name=length,proto3" json:"length,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *GetSockOptRequest) Reset() { *m = GetSockOptRequest{} }
+func (m *GetSockOptRequest) String() string { return proto.CompactTextString(m) }
+func (*GetSockOptRequest) ProtoMessage() {}
+func (*GetSockOptRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{29}
+}
+
+func (m *GetSockOptRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_GetSockOptRequest.Unmarshal(m, b)
+}
+func (m *GetSockOptRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_GetSockOptRequest.Marshal(b, m, deterministic)
+}
+func (m *GetSockOptRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_GetSockOptRequest.Merge(m, src)
+}
+func (m *GetSockOptRequest) XXX_Size() int {
+ return xxx_messageInfo_GetSockOptRequest.Size(m)
+}
+func (m *GetSockOptRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_GetSockOptRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_GetSockOptRequest proto.InternalMessageInfo
+
+func (m *GetSockOptRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *GetSockOptRequest) GetLevel() int64 {
+ if m != nil {
+ return m.Level
+ }
+ return 0
+}
+
+func (m *GetSockOptRequest) GetName() int64 {
+ if m != nil {
+ return m.Name
+ }
+ return 0
+}
+
+func (m *GetSockOptRequest) GetLength() uint32 {
+ if m != nil {
+ return m.Length
+ }
+ return 0
+}
+
+type GetSockOptResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *GetSockOptResponse_ErrorNumber
+ // *GetSockOptResponse_Opt
+ Result isGetSockOptResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *GetSockOptResponse) Reset() { *m = GetSockOptResponse{} }
+func (m *GetSockOptResponse) String() string { return proto.CompactTextString(m) }
+func (*GetSockOptResponse) ProtoMessage() {}
+func (*GetSockOptResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{30}
+}
+
+func (m *GetSockOptResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_GetSockOptResponse.Unmarshal(m, b)
+}
+func (m *GetSockOptResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_GetSockOptResponse.Marshal(b, m, deterministic)
+}
+func (m *GetSockOptResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_GetSockOptResponse.Merge(m, src)
+}
+func (m *GetSockOptResponse) XXX_Size() int {
+ return xxx_messageInfo_GetSockOptResponse.Size(m)
+}
+func (m *GetSockOptResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_GetSockOptResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_GetSockOptResponse proto.InternalMessageInfo
+
+type isGetSockOptResponse_Result interface {
+ isGetSockOptResponse_Result()
+}
+
+type GetSockOptResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type GetSockOptResponse_Opt struct {
+ Opt []byte `protobuf:"bytes,2,opt,name=opt,proto3,oneof"`
+}
+
+func (*GetSockOptResponse_ErrorNumber) isGetSockOptResponse_Result() {}
+
+func (*GetSockOptResponse_Opt) isGetSockOptResponse_Result() {}
+
+func (m *GetSockOptResponse) GetResult() isGetSockOptResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *GetSockOptResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*GetSockOptResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *GetSockOptResponse) GetOpt() []byte {
+ if x, ok := m.GetResult().(*GetSockOptResponse_Opt); ok {
+ return x.Opt
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*GetSockOptResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*GetSockOptResponse_ErrorNumber)(nil),
+ (*GetSockOptResponse_Opt)(nil),
+ }
+}
+
+type SetSockOptRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Level int64 `protobuf:"varint,2,opt,name=level,proto3" json:"level,omitempty"`
+ Name int64 `protobuf:"varint,3,opt,name=name,proto3" json:"name,omitempty"`
+ Opt []byte `protobuf:"bytes,4,opt,name=opt,proto3" json:"opt,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *SetSockOptRequest) Reset() { *m = SetSockOptRequest{} }
+func (m *SetSockOptRequest) String() string { return proto.CompactTextString(m) }
+func (*SetSockOptRequest) ProtoMessage() {}
+func (*SetSockOptRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{31}
+}
+
+func (m *SetSockOptRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_SetSockOptRequest.Unmarshal(m, b)
+}
+func (m *SetSockOptRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_SetSockOptRequest.Marshal(b, m, deterministic)
+}
+func (m *SetSockOptRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_SetSockOptRequest.Merge(m, src)
+}
+func (m *SetSockOptRequest) XXX_Size() int {
+ return xxx_messageInfo_SetSockOptRequest.Size(m)
+}
+func (m *SetSockOptRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_SetSockOptRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_SetSockOptRequest proto.InternalMessageInfo
+
+func (m *SetSockOptRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *SetSockOptRequest) GetLevel() int64 {
+ if m != nil {
+ return m.Level
+ }
+ return 0
+}
+
+func (m *SetSockOptRequest) GetName() int64 {
+ if m != nil {
+ return m.Name
+ }
+ return 0
+}
+
+func (m *SetSockOptRequest) GetOpt() []byte {
+ if m != nil {
+ return m.Opt
+ }
+ return nil
+}
+
+type SetSockOptResponse struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *SetSockOptResponse) Reset() { *m = SetSockOptResponse{} }
+func (m *SetSockOptResponse) String() string { return proto.CompactTextString(m) }
+func (*SetSockOptResponse) ProtoMessage() {}
+func (*SetSockOptResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{32}
+}
+
+func (m *SetSockOptResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_SetSockOptResponse.Unmarshal(m, b)
+}
+func (m *SetSockOptResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_SetSockOptResponse.Marshal(b, m, deterministic)
+}
+func (m *SetSockOptResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_SetSockOptResponse.Merge(m, src)
+}
+func (m *SetSockOptResponse) XXX_Size() int {
+ return xxx_messageInfo_SetSockOptResponse.Size(m)
+}
+func (m *SetSockOptResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_SetSockOptResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_SetSockOptResponse proto.InternalMessageInfo
+
+func (m *SetSockOptResponse) GetErrorNumber() uint32 {
+ if m != nil {
+ return m.ErrorNumber
+ }
+ return 0
+}
+
+type GetSockNameRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *GetSockNameRequest) Reset() { *m = GetSockNameRequest{} }
+func (m *GetSockNameRequest) String() string { return proto.CompactTextString(m) }
+func (*GetSockNameRequest) ProtoMessage() {}
+func (*GetSockNameRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{33}
+}
+
+func (m *GetSockNameRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_GetSockNameRequest.Unmarshal(m, b)
+}
+func (m *GetSockNameRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_GetSockNameRequest.Marshal(b, m, deterministic)
+}
+func (m *GetSockNameRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_GetSockNameRequest.Merge(m, src)
+}
+func (m *GetSockNameRequest) XXX_Size() int {
+ return xxx_messageInfo_GetSockNameRequest.Size(m)
+}
+func (m *GetSockNameRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_GetSockNameRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_GetSockNameRequest proto.InternalMessageInfo
+
+func (m *GetSockNameRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+type GetSockNameResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *GetSockNameResponse_ErrorNumber
+ // *GetSockNameResponse_Address
+ Result isGetSockNameResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *GetSockNameResponse) Reset() { *m = GetSockNameResponse{} }
+func (m *GetSockNameResponse) String() string { return proto.CompactTextString(m) }
+func (*GetSockNameResponse) ProtoMessage() {}
+func (*GetSockNameResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{34}
+}
+
+func (m *GetSockNameResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_GetSockNameResponse.Unmarshal(m, b)
+}
+func (m *GetSockNameResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_GetSockNameResponse.Marshal(b, m, deterministic)
+}
+func (m *GetSockNameResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_GetSockNameResponse.Merge(m, src)
+}
+func (m *GetSockNameResponse) XXX_Size() int {
+ return xxx_messageInfo_GetSockNameResponse.Size(m)
+}
+func (m *GetSockNameResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_GetSockNameResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_GetSockNameResponse proto.InternalMessageInfo
+
+type isGetSockNameResponse_Result interface {
+ isGetSockNameResponse_Result()
+}
+
+type GetSockNameResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type GetSockNameResponse_Address struct {
+ Address *AddressResponse `protobuf:"bytes,2,opt,name=address,proto3,oneof"`
+}
+
+func (*GetSockNameResponse_ErrorNumber) isGetSockNameResponse_Result() {}
+
+func (*GetSockNameResponse_Address) isGetSockNameResponse_Result() {}
+
+func (m *GetSockNameResponse) GetResult() isGetSockNameResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *GetSockNameResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*GetSockNameResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *GetSockNameResponse) GetAddress() *AddressResponse {
+ if x, ok := m.GetResult().(*GetSockNameResponse_Address); ok {
+ return x.Address
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*GetSockNameResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*GetSockNameResponse_ErrorNumber)(nil),
+ (*GetSockNameResponse_Address)(nil),
+ }
+}
+
+type GetPeerNameRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *GetPeerNameRequest) Reset() { *m = GetPeerNameRequest{} }
+func (m *GetPeerNameRequest) String() string { return proto.CompactTextString(m) }
+func (*GetPeerNameRequest) ProtoMessage() {}
+func (*GetPeerNameRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{35}
+}
+
+func (m *GetPeerNameRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_GetPeerNameRequest.Unmarshal(m, b)
+}
+func (m *GetPeerNameRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_GetPeerNameRequest.Marshal(b, m, deterministic)
+}
+func (m *GetPeerNameRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_GetPeerNameRequest.Merge(m, src)
+}
+func (m *GetPeerNameRequest) XXX_Size() int {
+ return xxx_messageInfo_GetPeerNameRequest.Size(m)
+}
+func (m *GetPeerNameRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_GetPeerNameRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_GetPeerNameRequest proto.InternalMessageInfo
+
+func (m *GetPeerNameRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+type GetPeerNameResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *GetPeerNameResponse_ErrorNumber
+ // *GetPeerNameResponse_Address
+ Result isGetPeerNameResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *GetPeerNameResponse) Reset() { *m = GetPeerNameResponse{} }
+func (m *GetPeerNameResponse) String() string { return proto.CompactTextString(m) }
+func (*GetPeerNameResponse) ProtoMessage() {}
+func (*GetPeerNameResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{36}
+}
+
+func (m *GetPeerNameResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_GetPeerNameResponse.Unmarshal(m, b)
+}
+func (m *GetPeerNameResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_GetPeerNameResponse.Marshal(b, m, deterministic)
+}
+func (m *GetPeerNameResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_GetPeerNameResponse.Merge(m, src)
+}
+func (m *GetPeerNameResponse) XXX_Size() int {
+ return xxx_messageInfo_GetPeerNameResponse.Size(m)
+}
+func (m *GetPeerNameResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_GetPeerNameResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_GetPeerNameResponse proto.InternalMessageInfo
+
+type isGetPeerNameResponse_Result interface {
+ isGetPeerNameResponse_Result()
+}
+
+type GetPeerNameResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type GetPeerNameResponse_Address struct {
+ Address *AddressResponse `protobuf:"bytes,2,opt,name=address,proto3,oneof"`
+}
+
+func (*GetPeerNameResponse_ErrorNumber) isGetPeerNameResponse_Result() {}
+
+func (*GetPeerNameResponse_Address) isGetPeerNameResponse_Result() {}
+
+func (m *GetPeerNameResponse) GetResult() isGetPeerNameResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *GetPeerNameResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*GetPeerNameResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *GetPeerNameResponse) GetAddress() *AddressResponse {
+ if x, ok := m.GetResult().(*GetPeerNameResponse_Address); ok {
+ return x.Address
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*GetPeerNameResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*GetPeerNameResponse_ErrorNumber)(nil),
+ (*GetPeerNameResponse_Address)(nil),
+ }
+}
+
+type SocketRequest struct {
+ Family int64 `protobuf:"varint,1,opt,name=family,proto3" json:"family,omitempty"`
+ Type int64 `protobuf:"varint,2,opt,name=type,proto3" json:"type,omitempty"`
+ Protocol int64 `protobuf:"varint,3,opt,name=protocol,proto3" json:"protocol,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *SocketRequest) Reset() { *m = SocketRequest{} }
+func (m *SocketRequest) String() string { return proto.CompactTextString(m) }
+func (*SocketRequest) ProtoMessage() {}
+func (*SocketRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{37}
+}
+
+func (m *SocketRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_SocketRequest.Unmarshal(m, b)
+}
+func (m *SocketRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_SocketRequest.Marshal(b, m, deterministic)
+}
+func (m *SocketRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_SocketRequest.Merge(m, src)
+}
+func (m *SocketRequest) XXX_Size() int {
+ return xxx_messageInfo_SocketRequest.Size(m)
+}
+func (m *SocketRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_SocketRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_SocketRequest proto.InternalMessageInfo
+
+func (m *SocketRequest) GetFamily() int64 {
+ if m != nil {
+ return m.Family
+ }
+ return 0
+}
+
+func (m *SocketRequest) GetType() int64 {
+ if m != nil {
+ return m.Type
+ }
+ return 0
+}
+
+func (m *SocketRequest) GetProtocol() int64 {
+ if m != nil {
+ return m.Protocol
+ }
+ return 0
+}
+
+type SocketResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *SocketResponse_ErrorNumber
+ // *SocketResponse_Fd
+ Result isSocketResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *SocketResponse) Reset() { *m = SocketResponse{} }
+func (m *SocketResponse) String() string { return proto.CompactTextString(m) }
+func (*SocketResponse) ProtoMessage() {}
+func (*SocketResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{38}
+}
+
+func (m *SocketResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_SocketResponse.Unmarshal(m, b)
+}
+func (m *SocketResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_SocketResponse.Marshal(b, m, deterministic)
+}
+func (m *SocketResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_SocketResponse.Merge(m, src)
+}
+func (m *SocketResponse) XXX_Size() int {
+ return xxx_messageInfo_SocketResponse.Size(m)
+}
+func (m *SocketResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_SocketResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_SocketResponse proto.InternalMessageInfo
+
+type isSocketResponse_Result interface {
+ isSocketResponse_Result()
+}
+
+type SocketResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type SocketResponse_Fd struct {
+ Fd uint32 `protobuf:"varint,2,opt,name=fd,proto3,oneof"`
+}
+
+func (*SocketResponse_ErrorNumber) isSocketResponse_Result() {}
+
+func (*SocketResponse_Fd) isSocketResponse_Result() {}
+
+func (m *SocketResponse) GetResult() isSocketResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *SocketResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*SocketResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *SocketResponse) GetFd() uint32 {
+ if x, ok := m.GetResult().(*SocketResponse_Fd); ok {
+ return x.Fd
+ }
+ return 0
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*SocketResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*SocketResponse_ErrorNumber)(nil),
+ (*SocketResponse_Fd)(nil),
+ }
+}
+
+type EpollWaitRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ NumEvents uint32 `protobuf:"varint,2,opt,name=num_events,json=numEvents,proto3" json:"num_events,omitempty"`
+ Msec int64 `protobuf:"zigzag64,3,opt,name=msec,proto3" json:"msec,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *EpollWaitRequest) Reset() { *m = EpollWaitRequest{} }
+func (m *EpollWaitRequest) String() string { return proto.CompactTextString(m) }
+func (*EpollWaitRequest) ProtoMessage() {}
+func (*EpollWaitRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{39}
+}
+
+func (m *EpollWaitRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_EpollWaitRequest.Unmarshal(m, b)
+}
+func (m *EpollWaitRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_EpollWaitRequest.Marshal(b, m, deterministic)
+}
+func (m *EpollWaitRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_EpollWaitRequest.Merge(m, src)
+}
+func (m *EpollWaitRequest) XXX_Size() int {
+ return xxx_messageInfo_EpollWaitRequest.Size(m)
+}
+func (m *EpollWaitRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_EpollWaitRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EpollWaitRequest proto.InternalMessageInfo
+
+func (m *EpollWaitRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *EpollWaitRequest) GetNumEvents() uint32 {
+ if m != nil {
+ return m.NumEvents
+ }
+ return 0
+}
+
+func (m *EpollWaitRequest) GetMsec() int64 {
+ if m != nil {
+ return m.Msec
+ }
+ return 0
+}
+
+type EpollEvent struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Events uint32 `protobuf:"varint,2,opt,name=events,proto3" json:"events,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *EpollEvent) Reset() { *m = EpollEvent{} }
+func (m *EpollEvent) String() string { return proto.CompactTextString(m) }
+func (*EpollEvent) ProtoMessage() {}
+func (*EpollEvent) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{40}
+}
+
+func (m *EpollEvent) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_EpollEvent.Unmarshal(m, b)
+}
+func (m *EpollEvent) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_EpollEvent.Marshal(b, m, deterministic)
+}
+func (m *EpollEvent) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_EpollEvent.Merge(m, src)
+}
+func (m *EpollEvent) XXX_Size() int {
+ return xxx_messageInfo_EpollEvent.Size(m)
+}
+func (m *EpollEvent) XXX_DiscardUnknown() {
+ xxx_messageInfo_EpollEvent.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EpollEvent proto.InternalMessageInfo
+
+func (m *EpollEvent) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *EpollEvent) GetEvents() uint32 {
+ if m != nil {
+ return m.Events
+ }
+ return 0
+}
+
+type EpollEvents struct {
+ Events []*EpollEvent `protobuf:"bytes,1,rep,name=events,proto3" json:"events,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *EpollEvents) Reset() { *m = EpollEvents{} }
+func (m *EpollEvents) String() string { return proto.CompactTextString(m) }
+func (*EpollEvents) ProtoMessage() {}
+func (*EpollEvents) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{41}
+}
+
+func (m *EpollEvents) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_EpollEvents.Unmarshal(m, b)
+}
+func (m *EpollEvents) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_EpollEvents.Marshal(b, m, deterministic)
+}
+func (m *EpollEvents) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_EpollEvents.Merge(m, src)
+}
+func (m *EpollEvents) XXX_Size() int {
+ return xxx_messageInfo_EpollEvents.Size(m)
+}
+func (m *EpollEvents) XXX_DiscardUnknown() {
+ xxx_messageInfo_EpollEvents.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EpollEvents proto.InternalMessageInfo
+
+func (m *EpollEvents) GetEvents() []*EpollEvent {
+ if m != nil {
+ return m.Events
+ }
+ return nil
+}
+
+type EpollWaitResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *EpollWaitResponse_ErrorNumber
+ // *EpollWaitResponse_Events
+ Result isEpollWaitResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *EpollWaitResponse) Reset() { *m = EpollWaitResponse{} }
+func (m *EpollWaitResponse) String() string { return proto.CompactTextString(m) }
+func (*EpollWaitResponse) ProtoMessage() {}
+func (*EpollWaitResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{42}
+}
+
+func (m *EpollWaitResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_EpollWaitResponse.Unmarshal(m, b)
+}
+func (m *EpollWaitResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_EpollWaitResponse.Marshal(b, m, deterministic)
+}
+func (m *EpollWaitResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_EpollWaitResponse.Merge(m, src)
+}
+func (m *EpollWaitResponse) XXX_Size() int {
+ return xxx_messageInfo_EpollWaitResponse.Size(m)
+}
+func (m *EpollWaitResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_EpollWaitResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EpollWaitResponse proto.InternalMessageInfo
+
+type isEpollWaitResponse_Result interface {
+ isEpollWaitResponse_Result()
+}
+
+type EpollWaitResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type EpollWaitResponse_Events struct {
+ Events *EpollEvents `protobuf:"bytes,2,opt,name=events,proto3,oneof"`
+}
+
+func (*EpollWaitResponse_ErrorNumber) isEpollWaitResponse_Result() {}
+
+func (*EpollWaitResponse_Events) isEpollWaitResponse_Result() {}
+
+func (m *EpollWaitResponse) GetResult() isEpollWaitResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *EpollWaitResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*EpollWaitResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *EpollWaitResponse) GetEvents() *EpollEvents {
+ if x, ok := m.GetResult().(*EpollWaitResponse_Events); ok {
+ return x.Events
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*EpollWaitResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*EpollWaitResponse_ErrorNumber)(nil),
+ (*EpollWaitResponse_Events)(nil),
+ }
+}
+
+type EpollCtlRequest struct {
+ Epfd uint32 `protobuf:"varint,1,opt,name=epfd,proto3" json:"epfd,omitempty"`
+ Op int64 `protobuf:"varint,2,opt,name=op,proto3" json:"op,omitempty"`
+ Fd uint32 `protobuf:"varint,3,opt,name=fd,proto3" json:"fd,omitempty"`
+ Event *EpollEvent `protobuf:"bytes,4,opt,name=event,proto3" json:"event,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *EpollCtlRequest) Reset() { *m = EpollCtlRequest{} }
+func (m *EpollCtlRequest) String() string { return proto.CompactTextString(m) }
+func (*EpollCtlRequest) ProtoMessage() {}
+func (*EpollCtlRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{43}
+}
+
+func (m *EpollCtlRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_EpollCtlRequest.Unmarshal(m, b)
+}
+func (m *EpollCtlRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_EpollCtlRequest.Marshal(b, m, deterministic)
+}
+func (m *EpollCtlRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_EpollCtlRequest.Merge(m, src)
+}
+func (m *EpollCtlRequest) XXX_Size() int {
+ return xxx_messageInfo_EpollCtlRequest.Size(m)
+}
+func (m *EpollCtlRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_EpollCtlRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EpollCtlRequest proto.InternalMessageInfo
+
+func (m *EpollCtlRequest) GetEpfd() uint32 {
+ if m != nil {
+ return m.Epfd
+ }
+ return 0
+}
+
+func (m *EpollCtlRequest) GetOp() int64 {
+ if m != nil {
+ return m.Op
+ }
+ return 0
+}
+
+func (m *EpollCtlRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *EpollCtlRequest) GetEvent() *EpollEvent {
+ if m != nil {
+ return m.Event
+ }
+ return nil
+}
+
+type EpollCtlResponse struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *EpollCtlResponse) Reset() { *m = EpollCtlResponse{} }
+func (m *EpollCtlResponse) String() string { return proto.CompactTextString(m) }
+func (*EpollCtlResponse) ProtoMessage() {}
+func (*EpollCtlResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{44}
+}
+
+func (m *EpollCtlResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_EpollCtlResponse.Unmarshal(m, b)
+}
+func (m *EpollCtlResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_EpollCtlResponse.Marshal(b, m, deterministic)
+}
+func (m *EpollCtlResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_EpollCtlResponse.Merge(m, src)
+}
+func (m *EpollCtlResponse) XXX_Size() int {
+ return xxx_messageInfo_EpollCtlResponse.Size(m)
+}
+func (m *EpollCtlResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_EpollCtlResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EpollCtlResponse proto.InternalMessageInfo
+
+func (m *EpollCtlResponse) GetErrorNumber() uint32 {
+ if m != nil {
+ return m.ErrorNumber
+ }
+ return 0
+}
+
+type EpollCreate1Request struct {
+ Flag int64 `protobuf:"varint,1,opt,name=flag,proto3" json:"flag,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *EpollCreate1Request) Reset() { *m = EpollCreate1Request{} }
+func (m *EpollCreate1Request) String() string { return proto.CompactTextString(m) }
+func (*EpollCreate1Request) ProtoMessage() {}
+func (*EpollCreate1Request) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{45}
+}
+
+func (m *EpollCreate1Request) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_EpollCreate1Request.Unmarshal(m, b)
+}
+func (m *EpollCreate1Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_EpollCreate1Request.Marshal(b, m, deterministic)
+}
+func (m *EpollCreate1Request) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_EpollCreate1Request.Merge(m, src)
+}
+func (m *EpollCreate1Request) XXX_Size() int {
+ return xxx_messageInfo_EpollCreate1Request.Size(m)
+}
+func (m *EpollCreate1Request) XXX_DiscardUnknown() {
+ xxx_messageInfo_EpollCreate1Request.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EpollCreate1Request proto.InternalMessageInfo
+
+func (m *EpollCreate1Request) GetFlag() int64 {
+ if m != nil {
+ return m.Flag
+ }
+ return 0
+}
+
+type EpollCreate1Response struct {
+ // Types that are valid to be assigned to Result:
+ // *EpollCreate1Response_ErrorNumber
+ // *EpollCreate1Response_Fd
+ Result isEpollCreate1Response_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *EpollCreate1Response) Reset() { *m = EpollCreate1Response{} }
+func (m *EpollCreate1Response) String() string { return proto.CompactTextString(m) }
+func (*EpollCreate1Response) ProtoMessage() {}
+func (*EpollCreate1Response) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{46}
+}
+
+func (m *EpollCreate1Response) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_EpollCreate1Response.Unmarshal(m, b)
+}
+func (m *EpollCreate1Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_EpollCreate1Response.Marshal(b, m, deterministic)
+}
+func (m *EpollCreate1Response) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_EpollCreate1Response.Merge(m, src)
+}
+func (m *EpollCreate1Response) XXX_Size() int {
+ return xxx_messageInfo_EpollCreate1Response.Size(m)
+}
+func (m *EpollCreate1Response) XXX_DiscardUnknown() {
+ xxx_messageInfo_EpollCreate1Response.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_EpollCreate1Response proto.InternalMessageInfo
+
+type isEpollCreate1Response_Result interface {
+ isEpollCreate1Response_Result()
+}
+
+type EpollCreate1Response_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type EpollCreate1Response_Fd struct {
+ Fd uint32 `protobuf:"varint,2,opt,name=fd,proto3,oneof"`
+}
+
+func (*EpollCreate1Response_ErrorNumber) isEpollCreate1Response_Result() {}
+
+func (*EpollCreate1Response_Fd) isEpollCreate1Response_Result() {}
+
+func (m *EpollCreate1Response) GetResult() isEpollCreate1Response_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *EpollCreate1Response) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*EpollCreate1Response_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *EpollCreate1Response) GetFd() uint32 {
+ if x, ok := m.GetResult().(*EpollCreate1Response_Fd); ok {
+ return x.Fd
+ }
+ return 0
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*EpollCreate1Response) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*EpollCreate1Response_ErrorNumber)(nil),
+ (*EpollCreate1Response_Fd)(nil),
+ }
+}
+
+type PollRequest struct {
+ Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"`
+ Events uint32 `protobuf:"varint,2,opt,name=events,proto3" json:"events,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *PollRequest) Reset() { *m = PollRequest{} }
+func (m *PollRequest) String() string { return proto.CompactTextString(m) }
+func (*PollRequest) ProtoMessage() {}
+func (*PollRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{47}
+}
+
+func (m *PollRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_PollRequest.Unmarshal(m, b)
+}
+func (m *PollRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_PollRequest.Marshal(b, m, deterministic)
+}
+func (m *PollRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_PollRequest.Merge(m, src)
+}
+func (m *PollRequest) XXX_Size() int {
+ return xxx_messageInfo_PollRequest.Size(m)
+}
+func (m *PollRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_PollRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_PollRequest proto.InternalMessageInfo
+
+func (m *PollRequest) GetFd() uint32 {
+ if m != nil {
+ return m.Fd
+ }
+ return 0
+}
+
+func (m *PollRequest) GetEvents() uint32 {
+ if m != nil {
+ return m.Events
+ }
+ return 0
+}
+
+type PollResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *PollResponse_ErrorNumber
+ // *PollResponse_Events
+ Result isPollResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *PollResponse) Reset() { *m = PollResponse{} }
+func (m *PollResponse) String() string { return proto.CompactTextString(m) }
+func (*PollResponse) ProtoMessage() {}
+func (*PollResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{48}
+}
+
+func (m *PollResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_PollResponse.Unmarshal(m, b)
+}
+func (m *PollResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_PollResponse.Marshal(b, m, deterministic)
+}
+func (m *PollResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_PollResponse.Merge(m, src)
+}
+func (m *PollResponse) XXX_Size() int {
+ return xxx_messageInfo_PollResponse.Size(m)
+}
+func (m *PollResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_PollResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_PollResponse proto.InternalMessageInfo
+
+type isPollResponse_Result interface {
+ isPollResponse_Result()
+}
+
+type PollResponse_ErrorNumber struct {
+ ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"`
+}
+
+type PollResponse_Events struct {
+ Events uint32 `protobuf:"varint,2,opt,name=events,proto3,oneof"`
+}
+
+func (*PollResponse_ErrorNumber) isPollResponse_Result() {}
+
+func (*PollResponse_Events) isPollResponse_Result() {}
+
+func (m *PollResponse) GetResult() isPollResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *PollResponse) GetErrorNumber() uint32 {
+ if x, ok := m.GetResult().(*PollResponse_ErrorNumber); ok {
+ return x.ErrorNumber
+ }
+ return 0
+}
+
+func (m *PollResponse) GetEvents() uint32 {
+ if x, ok := m.GetResult().(*PollResponse_Events); ok {
+ return x.Events
+ }
+ return 0
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*PollResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*PollResponse_ErrorNumber)(nil),
+ (*PollResponse_Events)(nil),
+ }
+}
+
+type SyscallRequest struct {
+ // Types that are valid to be assigned to Args:
+ // *SyscallRequest_Socket
+ // *SyscallRequest_Sendmsg
+ // *SyscallRequest_Recvmsg
+ // *SyscallRequest_Bind
+ // *SyscallRequest_Accept
+ // *SyscallRequest_Connect
+ // *SyscallRequest_Listen
+ // *SyscallRequest_Shutdown
+ // *SyscallRequest_Close
+ // *SyscallRequest_GetSockOpt
+ // *SyscallRequest_SetSockOpt
+ // *SyscallRequest_GetSockName
+ // *SyscallRequest_GetPeerName
+ // *SyscallRequest_EpollWait
+ // *SyscallRequest_EpollCtl
+ // *SyscallRequest_EpollCreate1
+ // *SyscallRequest_Poll
+ // *SyscallRequest_Read
+ // *SyscallRequest_Write
+ // *SyscallRequest_Open
+ // *SyscallRequest_Ioctl
+ // *SyscallRequest_WriteFile
+ // *SyscallRequest_ReadFile
+ Args isSyscallRequest_Args `protobuf_oneof:"args"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *SyscallRequest) Reset() { *m = SyscallRequest{} }
+func (m *SyscallRequest) String() string { return proto.CompactTextString(m) }
+func (*SyscallRequest) ProtoMessage() {}
+func (*SyscallRequest) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{49}
+}
+
+func (m *SyscallRequest) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_SyscallRequest.Unmarshal(m, b)
+}
+func (m *SyscallRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_SyscallRequest.Marshal(b, m, deterministic)
+}
+func (m *SyscallRequest) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_SyscallRequest.Merge(m, src)
+}
+func (m *SyscallRequest) XXX_Size() int {
+ return xxx_messageInfo_SyscallRequest.Size(m)
+}
+func (m *SyscallRequest) XXX_DiscardUnknown() {
+ xxx_messageInfo_SyscallRequest.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_SyscallRequest proto.InternalMessageInfo
+
+type isSyscallRequest_Args interface {
+ isSyscallRequest_Args()
+}
+
+type SyscallRequest_Socket struct {
+ Socket *SocketRequest `protobuf:"bytes,1,opt,name=socket,proto3,oneof"`
+}
+
+type SyscallRequest_Sendmsg struct {
+ Sendmsg *SendmsgRequest `protobuf:"bytes,2,opt,name=sendmsg,proto3,oneof"`
+}
+
+type SyscallRequest_Recvmsg struct {
+ Recvmsg *RecvmsgRequest `protobuf:"bytes,3,opt,name=recvmsg,proto3,oneof"`
+}
+
+type SyscallRequest_Bind struct {
+ Bind *BindRequest `protobuf:"bytes,4,opt,name=bind,proto3,oneof"`
+}
+
+type SyscallRequest_Accept struct {
+ Accept *AcceptRequest `protobuf:"bytes,5,opt,name=accept,proto3,oneof"`
+}
+
+type SyscallRequest_Connect struct {
+ Connect *ConnectRequest `protobuf:"bytes,6,opt,name=connect,proto3,oneof"`
+}
+
+type SyscallRequest_Listen struct {
+ Listen *ListenRequest `protobuf:"bytes,7,opt,name=listen,proto3,oneof"`
+}
+
+type SyscallRequest_Shutdown struct {
+ Shutdown *ShutdownRequest `protobuf:"bytes,8,opt,name=shutdown,proto3,oneof"`
+}
+
+type SyscallRequest_Close struct {
+ Close *CloseRequest `protobuf:"bytes,9,opt,name=close,proto3,oneof"`
+}
+
+type SyscallRequest_GetSockOpt struct {
+ GetSockOpt *GetSockOptRequest `protobuf:"bytes,10,opt,name=get_sock_opt,json=getSockOpt,proto3,oneof"`
+}
+
+type SyscallRequest_SetSockOpt struct {
+ SetSockOpt *SetSockOptRequest `protobuf:"bytes,11,opt,name=set_sock_opt,json=setSockOpt,proto3,oneof"`
+}
+
+type SyscallRequest_GetSockName struct {
+ GetSockName *GetSockNameRequest `protobuf:"bytes,12,opt,name=get_sock_name,json=getSockName,proto3,oneof"`
+}
+
+type SyscallRequest_GetPeerName struct {
+ GetPeerName *GetPeerNameRequest `protobuf:"bytes,13,opt,name=get_peer_name,json=getPeerName,proto3,oneof"`
+}
+
+type SyscallRequest_EpollWait struct {
+ EpollWait *EpollWaitRequest `protobuf:"bytes,14,opt,name=epoll_wait,json=epollWait,proto3,oneof"`
+}
+
+type SyscallRequest_EpollCtl struct {
+ EpollCtl *EpollCtlRequest `protobuf:"bytes,15,opt,name=epoll_ctl,json=epollCtl,proto3,oneof"`
+}
+
+type SyscallRequest_EpollCreate1 struct {
+ EpollCreate1 *EpollCreate1Request `protobuf:"bytes,16,opt,name=epoll_create1,json=epollCreate1,proto3,oneof"`
+}
+
+type SyscallRequest_Poll struct {
+ Poll *PollRequest `protobuf:"bytes,17,opt,name=poll,proto3,oneof"`
+}
+
+type SyscallRequest_Read struct {
+ Read *ReadRequest `protobuf:"bytes,18,opt,name=read,proto3,oneof"`
+}
+
+type SyscallRequest_Write struct {
+ Write *WriteRequest `protobuf:"bytes,19,opt,name=write,proto3,oneof"`
+}
+
+type SyscallRequest_Open struct {
+ Open *OpenRequest `protobuf:"bytes,20,opt,name=open,proto3,oneof"`
+}
+
+type SyscallRequest_Ioctl struct {
+ Ioctl *IOCtlRequest `protobuf:"bytes,21,opt,name=ioctl,proto3,oneof"`
+}
+
+type SyscallRequest_WriteFile struct {
+ WriteFile *WriteFileRequest `protobuf:"bytes,22,opt,name=write_file,json=writeFile,proto3,oneof"`
+}
+
+type SyscallRequest_ReadFile struct {
+ ReadFile *ReadFileRequest `protobuf:"bytes,23,opt,name=read_file,json=readFile,proto3,oneof"`
+}
+
+func (*SyscallRequest_Socket) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Sendmsg) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Recvmsg) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Bind) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Accept) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Connect) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Listen) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Shutdown) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Close) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_GetSockOpt) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_SetSockOpt) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_GetSockName) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_GetPeerName) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_EpollWait) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_EpollCtl) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_EpollCreate1) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Poll) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Read) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Write) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Open) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_Ioctl) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_WriteFile) isSyscallRequest_Args() {}
+
+func (*SyscallRequest_ReadFile) isSyscallRequest_Args() {}
+
+func (m *SyscallRequest) GetArgs() isSyscallRequest_Args {
+ if m != nil {
+ return m.Args
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetSocket() *SocketRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Socket); ok {
+ return x.Socket
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetSendmsg() *SendmsgRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Sendmsg); ok {
+ return x.Sendmsg
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetRecvmsg() *RecvmsgRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Recvmsg); ok {
+ return x.Recvmsg
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetBind() *BindRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Bind); ok {
+ return x.Bind
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetAccept() *AcceptRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Accept); ok {
+ return x.Accept
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetConnect() *ConnectRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Connect); ok {
+ return x.Connect
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetListen() *ListenRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Listen); ok {
+ return x.Listen
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetShutdown() *ShutdownRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Shutdown); ok {
+ return x.Shutdown
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetClose() *CloseRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Close); ok {
+ return x.Close
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetGetSockOpt() *GetSockOptRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_GetSockOpt); ok {
+ return x.GetSockOpt
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetSetSockOpt() *SetSockOptRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_SetSockOpt); ok {
+ return x.SetSockOpt
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetGetSockName() *GetSockNameRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_GetSockName); ok {
+ return x.GetSockName
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetGetPeerName() *GetPeerNameRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_GetPeerName); ok {
+ return x.GetPeerName
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetEpollWait() *EpollWaitRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_EpollWait); ok {
+ return x.EpollWait
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetEpollCtl() *EpollCtlRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_EpollCtl); ok {
+ return x.EpollCtl
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetEpollCreate1() *EpollCreate1Request {
+ if x, ok := m.GetArgs().(*SyscallRequest_EpollCreate1); ok {
+ return x.EpollCreate1
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetPoll() *PollRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Poll); ok {
+ return x.Poll
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetRead() *ReadRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Read); ok {
+ return x.Read
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetWrite() *WriteRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Write); ok {
+ return x.Write
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetOpen() *OpenRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Open); ok {
+ return x.Open
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetIoctl() *IOCtlRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_Ioctl); ok {
+ return x.Ioctl
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetWriteFile() *WriteFileRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_WriteFile); ok {
+ return x.WriteFile
+ }
+ return nil
+}
+
+func (m *SyscallRequest) GetReadFile() *ReadFileRequest {
+ if x, ok := m.GetArgs().(*SyscallRequest_ReadFile); ok {
+ return x.ReadFile
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*SyscallRequest) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*SyscallRequest_Socket)(nil),
+ (*SyscallRequest_Sendmsg)(nil),
+ (*SyscallRequest_Recvmsg)(nil),
+ (*SyscallRequest_Bind)(nil),
+ (*SyscallRequest_Accept)(nil),
+ (*SyscallRequest_Connect)(nil),
+ (*SyscallRequest_Listen)(nil),
+ (*SyscallRequest_Shutdown)(nil),
+ (*SyscallRequest_Close)(nil),
+ (*SyscallRequest_GetSockOpt)(nil),
+ (*SyscallRequest_SetSockOpt)(nil),
+ (*SyscallRequest_GetSockName)(nil),
+ (*SyscallRequest_GetPeerName)(nil),
+ (*SyscallRequest_EpollWait)(nil),
+ (*SyscallRequest_EpollCtl)(nil),
+ (*SyscallRequest_EpollCreate1)(nil),
+ (*SyscallRequest_Poll)(nil),
+ (*SyscallRequest_Read)(nil),
+ (*SyscallRequest_Write)(nil),
+ (*SyscallRequest_Open)(nil),
+ (*SyscallRequest_Ioctl)(nil),
+ (*SyscallRequest_WriteFile)(nil),
+ (*SyscallRequest_ReadFile)(nil),
+ }
+}
+
+type SyscallResponse struct {
+ // Types that are valid to be assigned to Result:
+ // *SyscallResponse_Socket
+ // *SyscallResponse_Sendmsg
+ // *SyscallResponse_Recvmsg
+ // *SyscallResponse_Bind
+ // *SyscallResponse_Accept
+ // *SyscallResponse_Connect
+ // *SyscallResponse_Listen
+ // *SyscallResponse_Shutdown
+ // *SyscallResponse_Close
+ // *SyscallResponse_GetSockOpt
+ // *SyscallResponse_SetSockOpt
+ // *SyscallResponse_GetSockName
+ // *SyscallResponse_GetPeerName
+ // *SyscallResponse_EpollWait
+ // *SyscallResponse_EpollCtl
+ // *SyscallResponse_EpollCreate1
+ // *SyscallResponse_Poll
+ // *SyscallResponse_Read
+ // *SyscallResponse_Write
+ // *SyscallResponse_Open
+ // *SyscallResponse_Ioctl
+ // *SyscallResponse_WriteFile
+ // *SyscallResponse_ReadFile
+ Result isSyscallResponse_Result `protobuf_oneof:"result"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *SyscallResponse) Reset() { *m = SyscallResponse{} }
+func (m *SyscallResponse) String() string { return proto.CompactTextString(m) }
+func (*SyscallResponse) ProtoMessage() {}
+func (*SyscallResponse) Descriptor() ([]byte, []int) {
+ return fileDescriptor_dd04f3a8f0c5288b, []int{50}
+}
+
+func (m *SyscallResponse) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_SyscallResponse.Unmarshal(m, b)
+}
+func (m *SyscallResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_SyscallResponse.Marshal(b, m, deterministic)
+}
+func (m *SyscallResponse) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_SyscallResponse.Merge(m, src)
+}
+func (m *SyscallResponse) XXX_Size() int {
+ return xxx_messageInfo_SyscallResponse.Size(m)
+}
+func (m *SyscallResponse) XXX_DiscardUnknown() {
+ xxx_messageInfo_SyscallResponse.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_SyscallResponse proto.InternalMessageInfo
+
+type isSyscallResponse_Result interface {
+ isSyscallResponse_Result()
+}
+
+type SyscallResponse_Socket struct {
+ Socket *SocketResponse `protobuf:"bytes,1,opt,name=socket,proto3,oneof"`
+}
+
+type SyscallResponse_Sendmsg struct {
+ Sendmsg *SendmsgResponse `protobuf:"bytes,2,opt,name=sendmsg,proto3,oneof"`
+}
+
+type SyscallResponse_Recvmsg struct {
+ Recvmsg *RecvmsgResponse `protobuf:"bytes,3,opt,name=recvmsg,proto3,oneof"`
+}
+
+type SyscallResponse_Bind struct {
+ Bind *BindResponse `protobuf:"bytes,4,opt,name=bind,proto3,oneof"`
+}
+
+type SyscallResponse_Accept struct {
+ Accept *AcceptResponse `protobuf:"bytes,5,opt,name=accept,proto3,oneof"`
+}
+
+type SyscallResponse_Connect struct {
+ Connect *ConnectResponse `protobuf:"bytes,6,opt,name=connect,proto3,oneof"`
+}
+
+type SyscallResponse_Listen struct {
+ Listen *ListenResponse `protobuf:"bytes,7,opt,name=listen,proto3,oneof"`
+}
+
+type SyscallResponse_Shutdown struct {
+ Shutdown *ShutdownResponse `protobuf:"bytes,8,opt,name=shutdown,proto3,oneof"`
+}
+
+type SyscallResponse_Close struct {
+ Close *CloseResponse `protobuf:"bytes,9,opt,name=close,proto3,oneof"`
+}
+
+type SyscallResponse_GetSockOpt struct {
+ GetSockOpt *GetSockOptResponse `protobuf:"bytes,10,opt,name=get_sock_opt,json=getSockOpt,proto3,oneof"`
+}
+
+type SyscallResponse_SetSockOpt struct {
+ SetSockOpt *SetSockOptResponse `protobuf:"bytes,11,opt,name=set_sock_opt,json=setSockOpt,proto3,oneof"`
+}
+
+type SyscallResponse_GetSockName struct {
+ GetSockName *GetSockNameResponse `protobuf:"bytes,12,opt,name=get_sock_name,json=getSockName,proto3,oneof"`
+}
+
+type SyscallResponse_GetPeerName struct {
+ GetPeerName *GetPeerNameResponse `protobuf:"bytes,13,opt,name=get_peer_name,json=getPeerName,proto3,oneof"`
+}
+
+type SyscallResponse_EpollWait struct {
+ EpollWait *EpollWaitResponse `protobuf:"bytes,14,opt,name=epoll_wait,json=epollWait,proto3,oneof"`
+}
+
+type SyscallResponse_EpollCtl struct {
+ EpollCtl *EpollCtlResponse `protobuf:"bytes,15,opt,name=epoll_ctl,json=epollCtl,proto3,oneof"`
+}
+
+type SyscallResponse_EpollCreate1 struct {
+ EpollCreate1 *EpollCreate1Response `protobuf:"bytes,16,opt,name=epoll_create1,json=epollCreate1,proto3,oneof"`
+}
+
+type SyscallResponse_Poll struct {
+ Poll *PollResponse `protobuf:"bytes,17,opt,name=poll,proto3,oneof"`
+}
+
+type SyscallResponse_Read struct {
+ Read *ReadResponse `protobuf:"bytes,18,opt,name=read,proto3,oneof"`
+}
+
+type SyscallResponse_Write struct {
+ Write *WriteResponse `protobuf:"bytes,19,opt,name=write,proto3,oneof"`
+}
+
+type SyscallResponse_Open struct {
+ Open *OpenResponse `protobuf:"bytes,20,opt,name=open,proto3,oneof"`
+}
+
+type SyscallResponse_Ioctl struct {
+ Ioctl *IOCtlResponse `protobuf:"bytes,21,opt,name=ioctl,proto3,oneof"`
+}
+
+type SyscallResponse_WriteFile struct {
+ WriteFile *WriteFileResponse `protobuf:"bytes,22,opt,name=write_file,json=writeFile,proto3,oneof"`
+}
+
+type SyscallResponse_ReadFile struct {
+ ReadFile *ReadFileResponse `protobuf:"bytes,23,opt,name=read_file,json=readFile,proto3,oneof"`
+}
+
+func (*SyscallResponse_Socket) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Sendmsg) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Recvmsg) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Bind) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Accept) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Connect) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Listen) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Shutdown) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Close) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_GetSockOpt) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_SetSockOpt) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_GetSockName) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_GetPeerName) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_EpollWait) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_EpollCtl) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_EpollCreate1) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Poll) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Read) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Write) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Open) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_Ioctl) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_WriteFile) isSyscallResponse_Result() {}
+
+func (*SyscallResponse_ReadFile) isSyscallResponse_Result() {}
+
+func (m *SyscallResponse) GetResult() isSyscallResponse_Result {
+ if m != nil {
+ return m.Result
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetSocket() *SocketResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Socket); ok {
+ return x.Socket
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetSendmsg() *SendmsgResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Sendmsg); ok {
+ return x.Sendmsg
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetRecvmsg() *RecvmsgResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Recvmsg); ok {
+ return x.Recvmsg
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetBind() *BindResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Bind); ok {
+ return x.Bind
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetAccept() *AcceptResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Accept); ok {
+ return x.Accept
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetConnect() *ConnectResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Connect); ok {
+ return x.Connect
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetListen() *ListenResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Listen); ok {
+ return x.Listen
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetShutdown() *ShutdownResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Shutdown); ok {
+ return x.Shutdown
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetClose() *CloseResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Close); ok {
+ return x.Close
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetGetSockOpt() *GetSockOptResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_GetSockOpt); ok {
+ return x.GetSockOpt
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetSetSockOpt() *SetSockOptResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_SetSockOpt); ok {
+ return x.SetSockOpt
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetGetSockName() *GetSockNameResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_GetSockName); ok {
+ return x.GetSockName
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetGetPeerName() *GetPeerNameResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_GetPeerName); ok {
+ return x.GetPeerName
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetEpollWait() *EpollWaitResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_EpollWait); ok {
+ return x.EpollWait
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetEpollCtl() *EpollCtlResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_EpollCtl); ok {
+ return x.EpollCtl
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetEpollCreate1() *EpollCreate1Response {
+ if x, ok := m.GetResult().(*SyscallResponse_EpollCreate1); ok {
+ return x.EpollCreate1
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetPoll() *PollResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Poll); ok {
+ return x.Poll
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetRead() *ReadResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Read); ok {
+ return x.Read
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetWrite() *WriteResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Write); ok {
+ return x.Write
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetOpen() *OpenResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Open); ok {
+ return x.Open
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetIoctl() *IOCtlResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_Ioctl); ok {
+ return x.Ioctl
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetWriteFile() *WriteFileResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_WriteFile); ok {
+ return x.WriteFile
+ }
+ return nil
+}
+
+func (m *SyscallResponse) GetReadFile() *ReadFileResponse {
+ if x, ok := m.GetResult().(*SyscallResponse_ReadFile); ok {
+ return x.ReadFile
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*SyscallResponse) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*SyscallResponse_Socket)(nil),
+ (*SyscallResponse_Sendmsg)(nil),
+ (*SyscallResponse_Recvmsg)(nil),
+ (*SyscallResponse_Bind)(nil),
+ (*SyscallResponse_Accept)(nil),
+ (*SyscallResponse_Connect)(nil),
+ (*SyscallResponse_Listen)(nil),
+ (*SyscallResponse_Shutdown)(nil),
+ (*SyscallResponse_Close)(nil),
+ (*SyscallResponse_GetSockOpt)(nil),
+ (*SyscallResponse_SetSockOpt)(nil),
+ (*SyscallResponse_GetSockName)(nil),
+ (*SyscallResponse_GetPeerName)(nil),
+ (*SyscallResponse_EpollWait)(nil),
+ (*SyscallResponse_EpollCtl)(nil),
+ (*SyscallResponse_EpollCreate1)(nil),
+ (*SyscallResponse_Poll)(nil),
+ (*SyscallResponse_Read)(nil),
+ (*SyscallResponse_Write)(nil),
+ (*SyscallResponse_Open)(nil),
+ (*SyscallResponse_Ioctl)(nil),
+ (*SyscallResponse_WriteFile)(nil),
+ (*SyscallResponse_ReadFile)(nil),
+ }
+}
+
+func init() {
+ proto.RegisterType((*SendmsgRequest)(nil), "syscall_rpc.SendmsgRequest")
+ proto.RegisterType((*SendmsgResponse)(nil), "syscall_rpc.SendmsgResponse")
+ proto.RegisterType((*IOCtlRequest)(nil), "syscall_rpc.IOCtlRequest")
+ proto.RegisterType((*IOCtlResponse)(nil), "syscall_rpc.IOCtlResponse")
+ proto.RegisterType((*RecvmsgRequest)(nil), "syscall_rpc.RecvmsgRequest")
+ proto.RegisterType((*OpenRequest)(nil), "syscall_rpc.OpenRequest")
+ proto.RegisterType((*OpenResponse)(nil), "syscall_rpc.OpenResponse")
+ proto.RegisterType((*ReadRequest)(nil), "syscall_rpc.ReadRequest")
+ proto.RegisterType((*ReadResponse)(nil), "syscall_rpc.ReadResponse")
+ proto.RegisterType((*ReadFileRequest)(nil), "syscall_rpc.ReadFileRequest")
+ proto.RegisterType((*ReadFileResponse)(nil), "syscall_rpc.ReadFileResponse")
+ proto.RegisterType((*WriteRequest)(nil), "syscall_rpc.WriteRequest")
+ proto.RegisterType((*WriteResponse)(nil), "syscall_rpc.WriteResponse")
+ proto.RegisterType((*WriteFileRequest)(nil), "syscall_rpc.WriteFileRequest")
+ proto.RegisterType((*WriteFileResponse)(nil), "syscall_rpc.WriteFileResponse")
+ proto.RegisterType((*AddressResponse)(nil), "syscall_rpc.AddressResponse")
+ proto.RegisterType((*RecvmsgResponse)(nil), "syscall_rpc.RecvmsgResponse")
+ proto.RegisterType((*RecvmsgResponse_ResultPayload)(nil), "syscall_rpc.RecvmsgResponse.ResultPayload")
+ proto.RegisterType((*BindRequest)(nil), "syscall_rpc.BindRequest")
+ proto.RegisterType((*BindResponse)(nil), "syscall_rpc.BindResponse")
+ proto.RegisterType((*AcceptRequest)(nil), "syscall_rpc.AcceptRequest")
+ proto.RegisterType((*AcceptResponse)(nil), "syscall_rpc.AcceptResponse")
+ proto.RegisterType((*AcceptResponse_ResultPayload)(nil), "syscall_rpc.AcceptResponse.ResultPayload")
+ proto.RegisterType((*ConnectRequest)(nil), "syscall_rpc.ConnectRequest")
+ proto.RegisterType((*ConnectResponse)(nil), "syscall_rpc.ConnectResponse")
+ proto.RegisterType((*ListenRequest)(nil), "syscall_rpc.ListenRequest")
+ proto.RegisterType((*ListenResponse)(nil), "syscall_rpc.ListenResponse")
+ proto.RegisterType((*ShutdownRequest)(nil), "syscall_rpc.ShutdownRequest")
+ proto.RegisterType((*ShutdownResponse)(nil), "syscall_rpc.ShutdownResponse")
+ proto.RegisterType((*CloseRequest)(nil), "syscall_rpc.CloseRequest")
+ proto.RegisterType((*CloseResponse)(nil), "syscall_rpc.CloseResponse")
+ proto.RegisterType((*GetSockOptRequest)(nil), "syscall_rpc.GetSockOptRequest")
+ proto.RegisterType((*GetSockOptResponse)(nil), "syscall_rpc.GetSockOptResponse")
+ proto.RegisterType((*SetSockOptRequest)(nil), "syscall_rpc.SetSockOptRequest")
+ proto.RegisterType((*SetSockOptResponse)(nil), "syscall_rpc.SetSockOptResponse")
+ proto.RegisterType((*GetSockNameRequest)(nil), "syscall_rpc.GetSockNameRequest")
+ proto.RegisterType((*GetSockNameResponse)(nil), "syscall_rpc.GetSockNameResponse")
+ proto.RegisterType((*GetPeerNameRequest)(nil), "syscall_rpc.GetPeerNameRequest")
+ proto.RegisterType((*GetPeerNameResponse)(nil), "syscall_rpc.GetPeerNameResponse")
+ proto.RegisterType((*SocketRequest)(nil), "syscall_rpc.SocketRequest")
+ proto.RegisterType((*SocketResponse)(nil), "syscall_rpc.SocketResponse")
+ proto.RegisterType((*EpollWaitRequest)(nil), "syscall_rpc.EpollWaitRequest")
+ proto.RegisterType((*EpollEvent)(nil), "syscall_rpc.EpollEvent")
+ proto.RegisterType((*EpollEvents)(nil), "syscall_rpc.EpollEvents")
+ proto.RegisterType((*EpollWaitResponse)(nil), "syscall_rpc.EpollWaitResponse")
+ proto.RegisterType((*EpollCtlRequest)(nil), "syscall_rpc.EpollCtlRequest")
+ proto.RegisterType((*EpollCtlResponse)(nil), "syscall_rpc.EpollCtlResponse")
+ proto.RegisterType((*EpollCreate1Request)(nil), "syscall_rpc.EpollCreate1Request")
+ proto.RegisterType((*EpollCreate1Response)(nil), "syscall_rpc.EpollCreate1Response")
+ proto.RegisterType((*PollRequest)(nil), "syscall_rpc.PollRequest")
+ proto.RegisterType((*PollResponse)(nil), "syscall_rpc.PollResponse")
+ proto.RegisterType((*SyscallRequest)(nil), "syscall_rpc.SyscallRequest")
+ proto.RegisterType((*SyscallResponse)(nil), "syscall_rpc.SyscallResponse")
+}
+
+func init() {
+ proto.RegisterFile("pkg/sentry/socket/rpcinet/syscall_rpc.proto", fileDescriptor_dd04f3a8f0c5288b)
+}
+
+var fileDescriptor_dd04f3a8f0c5288b = []byte{
+ // 1838 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xbc, 0x98, 0xdd, 0x52, 0xe3, 0xc8,
+ 0x15, 0x80, 0x6d, 0x6c, 0xc0, 0x1c, 0xff, 0xa2, 0x21, 0xac, 0xf8, 0x67, 0x95, 0xa4, 0x8a, 0x4d,
+ 0x2a, 0xb8, 0xc6, 0x33, 0xec, 0x90, 0xdd, 0xad, 0xdd, 0x2c, 0x04, 0xd6, 0x53, 0x99, 0x1a, 0x88,
+ 0x5c, 0x09, 0xa9, 0xe4, 0xc2, 0x25, 0xa4, 0xb6, 0x71, 0x21, 0x4b, 0x8a, 0x24, 0x43, 0x71, 0x93,
+ 0x07, 0xc8, 0x75, 0xee, 0x72, 0x91, 0x87, 0xca, 0x03, 0xe4, 0x09, 0xf2, 0x0e, 0xa9, 0xd3, 0xdd,
+ 0x92, 0xba, 0x45, 0x8b, 0xc1, 0x29, 0x6a, 0xef, 0xd4, 0xad, 0xf3, 0xd7, 0xe7, 0x74, 0x7f, 0x3a,
+ 0x2d, 0xf8, 0x65, 0x70, 0x3b, 0xee, 0x46, 0xc4, 0x8b, 0xc3, 0x87, 0x6e, 0xe4, 0xdb, 0xb7, 0x24,
+ 0xee, 0x86, 0x81, 0x3d, 0xf1, 0x48, 0xdc, 0x8d, 0x1e, 0x22, 0xdb, 0x72, 0xdd, 0x61, 0x18, 0xd8,
+ 0x87, 0x41, 0xe8, 0xc7, 0xbe, 0x56, 0x17, 0xa6, 0x8c, 0xbf, 0x97, 0xa1, 0x35, 0x20, 0x9e, 0x33,
+ 0x8d, 0xc6, 0x26, 0xf9, 0xeb, 0x8c, 0x44, 0xb1, 0xd6, 0x82, 0x85, 0x91, 0xa3, 0x97, 0xf7, 0xcb,
+ 0x07, 0x4d, 0x73, 0x61, 0xe4, 0x68, 0xeb, 0x50, 0x75, 0xac, 0xd8, 0xd2, 0x17, 0xf6, 0xcb, 0x07,
+ 0x8d, 0x93, 0x85, 0x5a, 0xd9, 0xa4, 0x63, 0x4d, 0x87, 0x65, 0xcb, 0x71, 0x42, 0x12, 0x45, 0x7a,
+ 0x05, 0x5f, 0x99, 0xc9, 0x50, 0xd3, 0xa0, 0x3a, 0xf5, 0x43, 0xa2, 0x57, 0xf7, 0xcb, 0x07, 0x35,
+ 0x93, 0x3e, 0x6b, 0x06, 0x34, 0x89, 0xe7, 0x0c, 0xfd, 0xd1, 0x30, 0x24, 0xb6, 0x1f, 0x3a, 0xfa,
+ 0x22, 0x7d, 0x59, 0x27, 0x9e, 0x73, 0x31, 0x32, 0xe9, 0x94, 0xf1, 0x67, 0x68, 0xa7, 0xb1, 0x44,
+ 0x81, 0xef, 0x45, 0x44, 0xfb, 0x29, 0x34, 0x48, 0x18, 0xfa, 0xe1, 0xd0, 0x9b, 0x4d, 0xaf, 0x49,
+ 0xc8, 0xc2, 0xea, 0x97, 0xcc, 0x3a, 0x9d, 0xfd, 0x48, 0x27, 0x35, 0x1d, 0x96, 0x5c, 0xe2, 0x8d,
+ 0xe3, 0x1b, 0x1a, 0x23, 0xbe, 0xe6, 0xe3, 0x93, 0x1a, 0x2c, 0x85, 0x24, 0x9a, 0xb9, 0xb1, 0x71,
+ 0x02, 0x8d, 0xf7, 0x17, 0xa7, 0xb1, 0x5b, 0xb4, 0xca, 0x0e, 0x54, 0xec, 0xa9, 0xc3, 0x0c, 0x98,
+ 0xf8, 0x88, 0x33, 0x56, 0x38, 0xe6, 0x6b, 0xc3, 0x47, 0xe3, 0x8f, 0xd0, 0xe4, 0x36, 0xe6, 0x89,
+ 0x6e, 0x1d, 0x16, 0xef, 0x2c, 0x77, 0x46, 0x58, 0x02, 0xfb, 0x25, 0x93, 0x0d, 0x85, 0xd8, 0xfe,
+ 0x59, 0x86, 0x96, 0x49, 0xec, 0xbb, 0x27, 0x8b, 0x20, 0x2d, 0x31, 0x59, 0x20, 0xce, 0x47, 0xc4,
+ 0x73, 0x48, 0x48, 0xe3, 0xac, 0x99, 0x7c, 0x84, 0x25, 0x08, 0x08, 0xb9, 0x4d, 0x4a, 0x80, 0xcf,
+ 0xda, 0x1a, 0x2c, 0xc6, 0xe1, 0xcc, 0xb3, 0x79, 0xea, 0xd9, 0x40, 0xdb, 0x83, 0xba, 0x3d, 0x8d,
+ 0xc6, 0x43, 0x6e, 0x7e, 0x89, 0x9a, 0x07, 0x9c, 0xfa, 0x40, 0x67, 0x8c, 0xdf, 0x41, 0xfd, 0x22,
+ 0x20, 0x5e, 0x12, 0x19, 0x5a, 0xb6, 0xe2, 0x1b, 0x1a, 0x5b, 0xc3, 0xa4, 0xcf, 0x68, 0x79, 0xe4,
+ 0x5a, 0xe3, 0x88, 0x07, 0xc7, 0x06, 0x6c, 0x1b, 0x38, 0x84, 0x46, 0xd6, 0x34, 0xe9, 0xb3, 0x71,
+ 0x01, 0x0d, 0x66, 0x6c, 0x9e, 0x0c, 0x76, 0x68, 0x32, 0x92, 0xda, 0x2e, 0x8c, 0x1c, 0x21, 0x77,
+ 0x47, 0x50, 0x37, 0x89, 0xe5, 0xcc, 0x99, 0x37, 0xe3, 0x0a, 0x1a, 0x4c, 0x6d, 0xbe, 0x7d, 0x96,
+ 0x3b, 0x09, 0xfd, 0x12, 0x3b, 0x0b, 0x42, 0x3c, 0x3f, 0x87, 0x36, 0x1a, 0x3e, 0x9f, 0xb8, 0x44,
+ 0x95, 0xb1, 0x15, 0x96, 0x31, 0xe3, 0x2f, 0xd0, 0xc9, 0xc4, 0x5e, 0x3a, 0x86, 0x2f, 0xa1, 0x71,
+ 0x15, 0x4e, 0x62, 0x32, 0xe7, 0x89, 0x36, 0xfe, 0x04, 0x4d, 0xae, 0xf7, 0xd2, 0xa7, 0xef, 0x37,
+ 0xd0, 0xa1, 0x96, 0x3f, 0x91, 0x16, 0x64, 0x8a, 0xed, 0x7b, 0x31, 0xf1, 0x62, 0x16, 0x9c, 0x99,
+ 0x0c, 0x8d, 0x4b, 0x58, 0x15, 0x2c, 0xf0, 0xf8, 0x3e, 0x57, 0xc5, 0x97, 0x8f, 0x6e, 0xf9, 0x3e,
+ 0x9c, 0xc4, 0x31, 0xf1, 0xf8, 0x0e, 0x48, 0x86, 0xc6, 0x29, 0xb4, 0xbf, 0x67, 0xc0, 0x4a, 0xed,
+ 0x09, 0x48, 0x2b, 0xcb, 0x48, 0x2b, 0xda, 0x47, 0xff, 0x5a, 0xc0, 0x7a, 0xf3, 0xa3, 0x3b, 0x4f,
+ 0xd6, 0xce, 0x61, 0x39, 0xb0, 0x1e, 0x5c, 0xdf, 0x62, 0x1b, 0xbb, 0xde, 0xfb, 0xc5, 0xa1, 0x88,
+ 0xea, 0x9c, 0xcd, 0x43, 0x93, 0xe6, 0xf1, 0x92, 0x69, 0xf4, 0x4b, 0x66, 0xa2, 0xbc, 0xf9, 0x8f,
+ 0x32, 0x34, 0xa5, 0x97, 0x69, 0x75, 0xcb, 0x39, 0x5e, 0x7f, 0x99, 0x2d, 0x8e, 0x79, 0xdc, 0x96,
+ 0x3c, 0xe6, 0x72, 0xa1, 0x5a, 0x7a, 0x45, 0x42, 0xcf, 0x16, 0xac, 0x50, 0x70, 0x50, 0x67, 0x55,
+ 0x9a, 0xae, 0x1a, 0x4e, 0xfc, 0x56, 0xde, 0x8c, 0xef, 0xa0, 0x7e, 0x32, 0xf1, 0x0a, 0x0f, 0xa8,
+ 0x2e, 0x47, 0x95, 0xa5, 0xdc, 0x78, 0x0d, 0x0d, 0xa6, 0xf8, 0xec, 0x62, 0x1b, 0xef, 0xa1, 0xf9,
+ 0xbd, 0x6d, 0x93, 0x20, 0x2e, 0xf2, 0xc6, 0xb0, 0x18, 0x52, 0x57, 0x0c, 0x8b, 0x61, 0x06, 0x2f,
+ 0x5c, 0x5e, 0x85, 0xc3, 0xcb, 0xf8, 0x4f, 0x19, 0x5a, 0x89, 0xad, 0x79, 0xea, 0x7a, 0x96, 0xaf,
+ 0xeb, 0x17, 0x72, 0x96, 0x25, 0x93, 0xc5, 0x65, 0xbd, 0xca, 0x57, 0x35, 0xbf, 0x92, 0xff, 0xb3,
+ 0x9a, 0x42, 0x61, 0xbe, 0x82, 0xd6, 0xa9, 0xef, 0x79, 0xc4, 0x8e, 0xe7, 0xaf, 0xcd, 0x5b, 0x68,
+ 0xa7, 0xba, 0xcf, 0x2f, 0xcf, 0xaf, 0xa1, 0xf9, 0x61, 0x12, 0xc5, 0xd9, 0xb7, 0x44, 0xe1, 0xf0,
+ 0xda, 0xb2, 0x6f, 0x5d, 0x7f, 0x4c, 0x1d, 0x56, 0xcc, 0x64, 0x68, 0xbc, 0x81, 0x56, 0xa2, 0xfa,
+ 0x7c, 0x7f, 0x6f, 0xa0, 0x3d, 0xb8, 0x99, 0xc5, 0x8e, 0x7f, 0xef, 0x3d, 0xf1, 0xd9, 0xbf, 0xf1,
+ 0xef, 0xb9, 0x37, 0x7c, 0x34, 0x8e, 0xa0, 0x93, 0x29, 0x3d, 0xdf, 0xd7, 0x2e, 0x34, 0x4e, 0x5d,
+ 0x3f, 0x2a, 0x62, 0xae, 0xd1, 0x83, 0x26, 0x7f, 0xff, 0x7c, 0x9b, 0x04, 0x56, 0x7f, 0x20, 0xf1,
+ 0xc0, 0xb7, 0x6f, 0x2f, 0x8a, 0xb7, 0xf4, 0x1a, 0x2c, 0xba, 0xe4, 0x8e, 0xb8, 0x7c, 0x0d, 0x6c,
+ 0x80, 0x1b, 0xdd, 0xb3, 0xa6, 0x84, 0xef, 0x69, 0xfa, 0x2c, 0x1c, 0xe4, 0x6a, 0xee, 0x5b, 0xa8,
+ 0x89, 0x6e, 0xe6, 0xd9, 0xed, 0x1a, 0x54, 0xfc, 0x20, 0x4e, 0x3b, 0x1b, 0x1c, 0x08, 0x3b, 0x6c,
+ 0x08, 0xab, 0x83, 0x17, 0x8c, 0xbf, 0xc3, 0x9c, 0x31, 0xd4, 0xe0, 0xa3, 0xf1, 0x0e, 0xb4, 0xc1,
+ 0xe3, 0xc8, 0x9f, 0x91, 0xd9, 0x9f, 0xa5, 0x4b, 0xfe, 0x68, 0x4d, 0x0b, 0x6b, 0xf6, 0x37, 0x78,
+ 0x25, 0x49, 0xcd, 0x93, 0x99, 0xe3, 0xb9, 0xce, 0x27, 0x1e, 0xfd, 0xc7, 0x27, 0x94, 0x45, 0x79,
+ 0x49, 0x48, 0xf8, 0xe9, 0x28, 0x33, 0xa9, 0x1f, 0x3b, 0xca, 0x2b, 0x68, 0x0e, 0xe8, 0x9d, 0x23,
+ 0x09, 0x70, 0x1d, 0x96, 0x46, 0xd6, 0x74, 0xe2, 0x3e, 0x50, 0x9f, 0x15, 0x93, 0x8f, 0xb0, 0xa6,
+ 0xf1, 0x43, 0x40, 0x78, 0xa1, 0xe9, 0xb3, 0xb6, 0x09, 0x35, 0x7a, 0x2b, 0xb1, 0x7d, 0x97, 0xd7,
+ 0x3a, 0x1d, 0x1b, 0xbf, 0x87, 0x56, 0x62, 0xf8, 0xa5, 0xba, 0xc5, 0x3f, 0x40, 0xe7, 0x2c, 0xf0,
+ 0x5d, 0xf7, 0xca, 0x9a, 0x14, 0x6e, 0xc8, 0x1d, 0x00, 0x6f, 0x36, 0x1d, 0x92, 0x3b, 0xe2, 0xc5,
+ 0x49, 0x47, 0xbb, 0xe2, 0xcd, 0xa6, 0x67, 0x74, 0x82, 0x76, 0xb5, 0x11, 0xb1, 0x69, 0xb4, 0x9a,
+ 0x49, 0x9f, 0x8d, 0xb7, 0x00, 0xd4, 0x2c, 0x15, 0x51, 0xf5, 0xa0, 0x92, 0x31, 0x3e, 0x32, 0xbe,
+ 0x85, 0x7a, 0xa6, 0x15, 0x69, 0xdd, 0x54, 0xac, 0xbc, 0x5f, 0x39, 0xa8, 0xf7, 0x3e, 0x93, 0x4a,
+ 0x91, 0x49, 0xa6, 0xfa, 0x77, 0xb0, 0x2a, 0x2c, 0x66, 0x9e, 0x14, 0xf5, 0xa4, 0x88, 0xea, 0x3d,
+ 0xbd, 0xc0, 0x55, 0x84, 0xcd, 0x1c, 0x93, 0x14, 0x92, 0x18, 0x43, 0x9b, 0x8a, 0x08, 0xb7, 0x29,
+ 0x0d, 0xaa, 0x24, 0x48, 0x17, 0x4d, 0x9f, 0x31, 0x0d, 0x7e, 0xc0, 0x8b, 0xbd, 0xe0, 0x07, 0x3c,
+ 0x2d, 0x95, 0x34, 0x2d, 0xbf, 0x82, 0x45, 0x6a, 0x9a, 0x1e, 0xe8, 0x27, 0x96, 0xcb, 0xa4, 0x90,
+ 0xcb, 0x99, 0xd7, 0xe7, 0x9f, 0xf4, 0x2f, 0xe0, 0x15, 0x53, 0x0b, 0x89, 0x15, 0x93, 0xd7, 0x42,
+ 0xc0, 0xf8, 0x9d, 0xe7, 0x3b, 0x94, 0x3e, 0x1b, 0x57, 0xb0, 0x26, 0x8b, 0xbe, 0xe0, 0x1d, 0xe5,
+ 0xd2, 0x77, 0xdd, 0x27, 0xee, 0x28, 0xca, 0xfd, 0x71, 0x05, 0x0d, 0xa6, 0x36, 0x67, 0x37, 0x2e,
+ 0x1a, 0x53, 0x16, 0xf0, 0xdf, 0x00, 0xad, 0x01, 0x4b, 0x76, 0x12, 0xd3, 0x5b, 0x58, 0x62, 0x3f,
+ 0x0e, 0xa8, 0xd5, 0x7a, 0x6f, 0x53, 0xaa, 0x86, 0x74, 0xbe, 0xd1, 0x24, 0x93, 0xd5, 0xde, 0xc1,
+ 0x72, 0xc4, 0x2e, 0xec, 0x7c, 0x23, 0x6d, 0xc9, 0x6a, 0xd2, 0x8f, 0x05, 0xa4, 0x07, 0x97, 0x46,
+ 0xc5, 0x90, 0x75, 0xb8, 0x74, 0x43, 0xe4, 0x15, 0xe5, 0xcb, 0x30, 0x2a, 0x72, 0x69, 0xed, 0x10,
+ 0xaa, 0xd7, 0x13, 0xcf, 0xe1, 0x7b, 0x46, 0xde, 0xb7, 0x42, 0x9b, 0x89, 0x97, 0x22, 0x94, 0xc3,
+ 0x75, 0x59, 0xb4, 0xe5, 0xa2, 0x97, 0xde, 0xfc, 0xba, 0xa4, 0x66, 0x11, 0xd7, 0xc5, 0x64, 0x31,
+ 0x3c, 0x9b, 0xb5, 0x37, 0xf4, 0x3e, 0x9c, 0x0f, 0x4f, 0x6e, 0x9b, 0x30, 0x3c, 0x2e, 0x8d, 0xee,
+ 0x5c, 0xda, 0xa6, 0xe8, 0xcb, 0x0a, 0x77, 0x52, 0xf3, 0x43, 0xef, 0x49, 0x74, 0x42, 0xfb, 0x0a,
+ 0x6a, 0x11, 0x6f, 0x39, 0xf4, 0x9a, 0x02, 0xc3, 0xb9, 0x26, 0xa6, 0x5f, 0x32, 0x53, 0x79, 0xed,
+ 0x35, 0x2c, 0xda, 0xd8, 0x57, 0xe8, 0x2b, 0x54, 0x71, 0x43, 0x0e, 0x54, 0xe8, 0x48, 0xfa, 0x25,
+ 0x93, 0x49, 0x6a, 0x27, 0xd0, 0x18, 0x93, 0x78, 0x88, 0x35, 0x1c, 0xe2, 0x07, 0x15, 0xa8, 0xe6,
+ 0xae, 0xa4, 0xf9, 0xa8, 0xef, 0xe8, 0x97, 0x4c, 0x18, 0xa7, 0x93, 0x68, 0x23, 0x12, 0x6d, 0xd4,
+ 0x15, 0x36, 0x06, 0x2a, 0x1b, 0x51, 0x66, 0xe3, 0x0c, 0x9a, 0x69, 0x1c, 0xf4, 0x63, 0xdf, 0xa0,
+ 0x46, 0xf6, 0x54, 0x81, 0x08, 0x1f, 0x40, 0xdc, 0xf1, 0xe3, 0x6c, 0x36, 0x31, 0x83, 0xbd, 0x3c,
+ 0x33, 0xd3, 0x54, 0x9b, 0xc9, 0x7d, 0x47, 0xb9, 0x99, 0x64, 0x56, 0xfb, 0x16, 0x80, 0xe0, 0xe9,
+ 0x1f, 0xde, 0x5b, 0x93, 0x58, 0x6f, 0x51, 0x1b, 0x3b, 0x8f, 0x99, 0x24, 0x7c, 0x39, 0xfa, 0x25,
+ 0x73, 0x85, 0x24, 0x73, 0xda, 0xd7, 0xc0, 0x06, 0x43, 0x3b, 0x76, 0xf5, 0xb6, 0xa2, 0x8a, 0x39,
+ 0x66, 0x62, 0x15, 0x09, 0x9f, 0xd2, 0x7e, 0x80, 0x26, 0x57, 0x66, 0xec, 0xd1, 0x3b, 0xd4, 0xc0,
+ 0xbe, 0xc2, 0x80, 0xc4, 0xb1, 0x7e, 0xc9, 0x6c, 0x10, 0x61, 0x1a, 0xcf, 0x07, 0x0e, 0xf5, 0x55,
+ 0xc5, 0xf9, 0x10, 0x18, 0x84, 0xe7, 0x03, 0xe5, 0x50, 0x3e, 0x24, 0x96, 0xa3, 0x6b, 0x0a, 0x79,
+ 0xe1, 0xbf, 0x0a, 0xca, 0xa3, 0x1c, 0x6e, 0x37, 0xbc, 0x3f, 0x13, 0xfd, 0x95, 0x62, 0xbb, 0x89,
+ 0x3f, 0x1d, 0x70, 0xbb, 0x51, 0x49, 0x74, 0xe1, 0x07, 0xc4, 0xd3, 0xd7, 0x14, 0x2e, 0x84, 0x1f,
+ 0x4b, 0xe8, 0x02, 0xe5, 0xd0, 0xc5, 0xc4, 0xc7, 0x24, 0xfe, 0x44, 0xe1, 0x42, 0xfc, 0x87, 0x87,
+ 0x2e, 0xa8, 0x24, 0xd6, 0x8e, 0xfa, 0x1a, 0x8e, 0x26, 0x2e, 0xd1, 0xd7, 0x15, 0xb5, 0xcb, 0xff,
+ 0x7d, 0xc0, 0xda, 0xdd, 0x27, 0x73, 0x58, 0x3b, 0x5c, 0x1d, 0x53, 0xff, 0x4c, 0x51, 0xbb, 0xdc,
+ 0x2f, 0x1d, 0xac, 0x5d, 0xc8, 0xa7, 0x4e, 0x96, 0xa0, 0x6a, 0x85, 0xe3, 0xc8, 0xf8, 0x2f, 0x40,
+ 0x3b, 0xa5, 0x2a, 0x47, 0xf6, 0x51, 0x0e, 0xab, 0x5b, 0x4a, 0xac, 0xa6, 0xdd, 0x55, 0xc2, 0xd5,
+ 0xe3, 0x3c, 0x57, 0xb7, 0xd5, 0x5c, 0xcd, 0xda, 0xb2, 0x04, 0xac, 0xc7, 0x79, 0xb0, 0x6e, 0x3f,
+ 0xf5, 0x5b, 0x41, 0x24, 0x6b, 0x57, 0x22, 0xeb, 0x86, 0x82, 0xac, 0xa9, 0x0e, 0x43, 0xeb, 0x51,
+ 0x0e, 0xad, 0x5b, 0x4f, 0x5c, 0x74, 0x05, 0xb6, 0x1e, 0xe7, 0xd9, 0xba, 0xad, 0x66, 0x6b, 0x16,
+ 0x61, 0x02, 0xd7, 0xa3, 0x1c, 0x5c, 0xb7, 0x94, 0x70, 0xcd, 0x1c, 0x72, 0xba, 0x7e, 0xfd, 0x88,
+ 0xae, 0x3b, 0x05, 0x74, 0x4d, 0x55, 0x33, 0xbc, 0xf6, 0x64, 0xbc, 0x6e, 0xaa, 0xf0, 0x9a, 0xaa,
+ 0x71, 0xbe, 0x9e, 0x2a, 0xf9, 0xba, 0x57, 0xc8, 0xd7, 0x54, 0x5f, 0x04, 0xec, 0xa9, 0x12, 0xb0,
+ 0x7b, 0x85, 0x80, 0xcd, 0x8c, 0x08, 0x84, 0x3d, 0x57, 0x13, 0x76, 0xbf, 0x98, 0xb0, 0xa9, 0x19,
+ 0x09, 0xb1, 0xe7, 0x6a, 0xc4, 0xee, 0x17, 0x23, 0x56, 0xb2, 0x93, 0x32, 0xf6, 0x3b, 0x05, 0x63,
+ 0x77, 0x8b, 0x18, 0x9b, 0x9a, 0x10, 0x20, 0xfb, 0xcd, 0x63, 0xc8, 0xee, 0x14, 0x40, 0x36, 0x2b,
+ 0x66, 0x4a, 0xd9, 0xbe, 0x9a, 0xb2, 0x9f, 0x3f, 0x41, 0xd9, 0xd4, 0x8a, 0x8c, 0xd9, 0xae, 0x84,
+ 0xd9, 0x0d, 0x05, 0x66, 0xb3, 0xc3, 0x42, 0x39, 0xdb, 0x95, 0x38, 0xbb, 0xa1, 0xe0, 0x6c, 0xa6,
+ 0x40, 0x41, 0xdb, 0x93, 0x41, 0xbb, 0xa9, 0x02, 0x6d, 0xb6, 0xf1, 0x18, 0x69, 0xbb, 0x12, 0x69,
+ 0x37, 0x14, 0xa4, 0xcd, 0x9c, 0x50, 0xd4, 0xf6, 0x64, 0xd4, 0x6e, 0xaa, 0x50, 0x9b, 0x39, 0x61,
+ 0xac, 0xfd, 0x4e, 0xc1, 0xda, 0xdd, 0x22, 0xd6, 0x66, 0x35, 0xcc, 0x60, 0xfb, 0xcd, 0x63, 0xd8,
+ 0xee, 0x14, 0xc0, 0x36, 0xab, 0x61, 0x4a, 0xdb, 0xb4, 0x8b, 0xbd, 0x5e, 0xa2, 0x17, 0xc5, 0x37,
+ 0xff, 0x0b, 0x00, 0x00, 0xff, 0xff, 0xd5, 0x22, 0x29, 0x56, 0xfd, 0x1a, 0x00, 0x00,
+}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
new file mode 100644
index 000000000..9393acd28
--- /dev/null
+++ b/pkg/sentry/socket/socket.go
@@ -0,0 +1,336 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package socket provides the interfaces that need to be provided by socket
+// implementations and providers, as well as per family demultiplexing of socket
+// creation.
+package socket
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// ControlMessages represents the union of unix control messages and tcpip
+// control messages.
+type ControlMessages struct {
+ Unix transport.ControlMessages
+ IP tcpip.ControlMessages
+}
+
+// Socket is the interface containing socket syscalls used by the syscall layer
+// to redirect them to the appropriate implementation.
+type Socket interface {
+ fs.FileOperations
+
+ // Connect implements the connect(2) linux syscall.
+ Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error
+
+ // Accept implements the accept4(2) linux syscall.
+ // Returns fd, real peer address length and error. Real peer address
+ // length is only set if len(peer) > 0.
+ Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error)
+
+ // Bind implements the bind(2) linux syscall.
+ Bind(t *kernel.Task, sockaddr []byte) *syserr.Error
+
+ // Listen implements the listen(2) linux syscall.
+ Listen(t *kernel.Task, backlog int) *syserr.Error
+
+ // Shutdown implements the shutdown(2) linux syscall.
+ Shutdown(t *kernel.Task, how int) *syserr.Error
+
+ // GetSockOpt implements the getsockopt(2) linux syscall.
+ GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error)
+
+ // SetSockOpt implements the setsockopt(2) linux syscall.
+ SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
+
+ // GetSockName implements the getsockname(2) linux syscall.
+ //
+ // addrLen is the address length to be returned to the application, not
+ // necessarily the actual length of the address.
+ GetSockName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error)
+
+ // GetPeerName implements the getpeername(2) linux syscall.
+ //
+ // addrLen is the address length to be returned to the application, not
+ // necessarily the actual length of the address.
+ GetPeerName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error)
+
+ // RecvMsg implements the recvmsg(2) linux syscall.
+ //
+ // senderAddrLen is the address length to be returned to the application,
+ // not necessarily the actual length of the address.
+ //
+ // flags control how RecvMsg should be completed. msgFlags indicate how
+ // the RecvMsg call was completed. Note that control message truncation
+ // may still be required even if the MSG_CTRUNC bit is not set in
+ // msgFlags. In that case, the caller should set MSG_CTRUNC appropriately.
+ //
+ // If err != nil, the recv was not successful.
+ RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
+
+ // SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take
+ // ownership of the ControlMessage on error.
+ //
+ // If n > 0, err will either be nil or an error from t.Block.
+ SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages ControlMessages) (n int, err *syserr.Error)
+
+ // SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means
+ // no timeout, and negative means DONTWAIT.
+ SetRecvTimeout(nanoseconds int64)
+
+ // RecvTimeout gets the current timeout (in ns) for recv operations. Zero
+ // means no timeout, and negative means DONTWAIT.
+ RecvTimeout() int64
+
+ // SetSendTimeout sets the timeout (in ns) for send operations. Zero means
+ // no timeout, and negative means DONTWAIT.
+ SetSendTimeout(nanoseconds int64)
+
+ // SendTimeout gets the current timeout (in ns) for send operations. Zero
+ // means no timeout, and negative means DONTWAIT.
+ SendTimeout() int64
+}
+
+// Provider is the interface implemented by providers of sockets for specific
+// address families (e.g., AF_INET).
+type Provider interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error)
+}
+
+// families holds a map of all known address families and their providers.
+var families = make(map[int][]Provider)
+
+// RegisterProvider registers the provider of a given address family so that
+// sockets of that type can be created via socket() and/or socketpair()
+// syscalls.
+func RegisterProvider(family int, provider Provider) {
+ families[family] = append(families[family], provider)
+}
+
+// New creates a new socket with the given family, type and protocol.
+func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+ for _, p := range families[family] {
+ s, err := p.Socket(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+ if s != nil {
+ t.Kernel().RecordSocket(s, family)
+ return s, nil
+ }
+ }
+
+ return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// Pair creates a new connected socket pair with the given family, type and
+// protocol.
+func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ providers, ok := families[family]
+ if !ok {
+ return nil, nil, syserr.ErrAddressFamilyNotSupported
+ }
+
+ for _, p := range providers {
+ s1, s2, err := p.Pair(t, stype, protocol)
+ if err != nil {
+ return nil, nil, err
+ }
+ if s1 != nil && s2 != nil {
+ k := t.Kernel()
+ k.RecordSocket(s1, family)
+ k.RecordSocket(s2, family)
+ return s1, s2, nil
+ }
+ }
+
+ return nil, nil, syserr.ErrSocketNotSupported
+}
+
+// NewDirent returns a sockfs fs.Dirent that resides on device d.
+func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
+ ino := d.NextIno()
+ iops := &fsutil.SimpleFileInode{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.FileOwnerFromContext(ctx), fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+ }, linux.SOCKFS_MAGIC),
+ }
+ inode := fs.NewInode(iops, fs.NewPseudoMountSource(), fs.StableAttr{
+ Type: fs.Socket,
+ DeviceID: d.DeviceID(),
+ InodeID: ino,
+ BlockSize: usermem.PageSize,
+ })
+
+ // Dirent name matches net/socket.c:sockfs_dname.
+ return fs.NewDirent(inode, fmt.Sprintf("socket:[%d]", ino))
+}
+
+// SendReceiveTimeout stores timeouts for send and receive calls.
+//
+// It is meant to be embedded into Socket implementations to help satisfy the
+// interface.
+//
+// Care must be taken when copying SendReceiveTimeout as it contains atomic
+// variables.
+//
+// +stateify savable
+type SendReceiveTimeout struct {
+ // send is length of the send timeout in nanoseconds.
+ //
+ // send must be accessed atomically.
+ send int64
+
+ // recv is length of the receive timeout in nanoseconds.
+ //
+ // recv must be accessed atomically.
+ recv int64
+}
+
+// SetRecvTimeout implements Socket.SetRecvTimeout.
+func (to *SendReceiveTimeout) SetRecvTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.recv, nanoseconds)
+}
+
+// RecvTimeout implements Socket.RecvTimeout.
+func (to *SendReceiveTimeout) RecvTimeout() int64 {
+ return atomic.LoadInt64(&to.recv)
+}
+
+// SetSendTimeout implements Socket.SetSendTimeout.
+func (to *SendReceiveTimeout) SetSendTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.send, nanoseconds)
+}
+
+// SendTimeout implements Socket.SendTimeout.
+func (to *SendReceiveTimeout) SendTimeout() int64 {
+ return atomic.LoadInt64(&to.send)
+}
+
+// GetSockOptEmitUnimplementedEvent emits unimplemented event if name is valid.
+// It contains names that are valid for GetSockOpt when level is SOL_SOCKET.
+func GetSockOptEmitUnimplementedEvent(t *kernel.Task, name int) {
+ switch name {
+ case linux.SO_ACCEPTCONN,
+ linux.SO_BPF_EXTENSIONS,
+ linux.SO_COOKIE,
+ linux.SO_DOMAIN,
+ linux.SO_ERROR,
+ linux.SO_GET_FILTER,
+ linux.SO_INCOMING_NAPI_ID,
+ linux.SO_MEMINFO,
+ linux.SO_PEERCRED,
+ linux.SO_PEERGROUPS,
+ linux.SO_PEERNAME,
+ linux.SO_PEERSEC,
+ linux.SO_PROTOCOL,
+ linux.SO_SNDLOWAT,
+ linux.SO_TYPE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEvent(t, name)
+ }
+}
+
+// SetSockOptEmitUnimplementedEvent emits unimplemented event if name is valid.
+// It contains names that are valid for SetSockOpt when level is SOL_SOCKET.
+func SetSockOptEmitUnimplementedEvent(t *kernel.Task, name int) {
+ switch name {
+ case linux.SO_ATTACH_BPF,
+ linux.SO_ATTACH_FILTER,
+ linux.SO_ATTACH_REUSEPORT_CBPF,
+ linux.SO_ATTACH_REUSEPORT_EBPF,
+ linux.SO_CNX_ADVICE,
+ linux.SO_DETACH_FILTER,
+ linux.SO_RCVBUFFORCE,
+ linux.SO_SNDBUFFORCE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEvent(t, name)
+ }
+}
+
+// emitUnimplementedEvent emits unimplemented event if name is valid. It
+// contains names that are common between Get and SetSocketOpt when level is
+// SOL_SOCKET.
+func emitUnimplementedEvent(t *kernel.Task, name int) {
+ switch name {
+ case linux.SO_BINDTODEVICE,
+ linux.SO_BROADCAST,
+ linux.SO_BSDCOMPAT,
+ linux.SO_BUSY_POLL,
+ linux.SO_DEBUG,
+ linux.SO_DONTROUTE,
+ linux.SO_INCOMING_CPU,
+ linux.SO_KEEPALIVE,
+ linux.SO_LINGER,
+ linux.SO_LOCK_FILTER,
+ linux.SO_MARK,
+ linux.SO_MAX_PACING_RATE,
+ linux.SO_NOFCS,
+ linux.SO_NO_CHECK,
+ linux.SO_OOBINLINE,
+ linux.SO_PASSCRED,
+ linux.SO_PASSSEC,
+ linux.SO_PEEK_OFF,
+ linux.SO_PRIORITY,
+ linux.SO_RCVBUF,
+ linux.SO_RCVLOWAT,
+ linux.SO_RCVTIMEO,
+ linux.SO_REUSEADDR,
+ linux.SO_REUSEPORT,
+ linux.SO_RXQ_OVFL,
+ linux.SO_SELECT_ERR_QUEUE,
+ linux.SO_SNDBUF,
+ linux.SO_SNDTIMEO,
+ linux.SO_TIMESTAMP,
+ linux.SO_TIMESTAMPING,
+ linux.SO_TIMESTAMPNS,
+ linux.SO_TXTIME,
+ linux.SO_WIFI_STATUS,
+ linux.SO_ZEROCOPY:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
diff --git a/pkg/sentry/socket/socket_state_autogen.go b/pkg/sentry/socket/socket_state_autogen.go
new file mode 100755
index 000000000..f3c899200
--- /dev/null
+++ b/pkg/sentry/socket/socket_state_autogen.go
@@ -0,0 +1,24 @@
+// automatically generated by stateify.
+
+package socket
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *SendReceiveTimeout) beforeSave() {}
+func (x *SendReceiveTimeout) save(m state.Map) {
+ x.beforeSave()
+ m.Save("send", &x.send)
+ m.Save("recv", &x.recv)
+}
+
+func (x *SendReceiveTimeout) afterLoad() {}
+func (x *SendReceiveTimeout) load(m state.Map) {
+ m.Load("send", &x.send)
+ m.Load("recv", &x.recv)
+}
+
+func init() {
+ state.Register("socket.SendReceiveTimeout", (*SendReceiveTimeout)(nil), state.Fns{Save: (*SendReceiveTimeout).save, Load: (*SendReceiveTimeout).load})
+}
diff --git a/pkg/sentry/socket/unix/device.go b/pkg/sentry/socket/unix/device.go
new file mode 100644
index 000000000..734d39ee6
--- /dev/null
+++ b/pkg/sentry/socket/unix/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// unixSocketDevice is the unix socket virtual device.
+var unixSocketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
new file mode 100644
index 000000000..5a1475ec2
--- /dev/null
+++ b/pkg/sentry/socket/unix/io.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// EndpointWriter implements safemem.Writer that writes to a transport.Endpoint.
+//
+// EndpointWriter is not thread-safe.
+type EndpointWriter struct {
+ // Endpoint is the transport.Endpoint to write to.
+ Endpoint transport.Endpoint
+
+ // Control is the control messages to send.
+ Control transport.ControlMessages
+
+ // To is the endpoint to send to. May be nil.
+ To transport.BoundEndpoint
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ return safemem.FromVecWriterFunc{func(bufs [][]byte) (int64, error) {
+ n, err := w.Endpoint.SendMsg(bufs, w.Control, w.To)
+ if err != nil {
+ return int64(n), err.ToError()
+ }
+ return int64(n), nil
+ }}.WriteFromBlocks(srcs)
+}
+
+// EndpointReader implements safemem.Reader that reads from a
+// transport.Endpoint.
+//
+// EndpointReader is not thread-safe.
+type EndpointReader struct {
+ // Endpoint is the transport.Endpoint to read from.
+ Endpoint transport.Endpoint
+
+ // Creds indicates if credential control messages are requested.
+ Creds bool
+
+ // NumRights is the number of SCM_RIGHTS FDs requested.
+ NumRights uintptr
+
+ // Peek indicates that the data should not be consumed from the
+ // endpoint.
+ Peek bool
+
+ // MsgSize is the size of the message that was read from. For stream
+ // sockets, it is the amount read.
+ MsgSize uintptr
+
+ // From, if not nil, will be set with the address read from.
+ From *tcpip.FullAddress
+
+ // Control contains the received control messages.
+ Control transport.ControlMessages
+
+ // ControlTrunc indicates that SCM_RIGHTS FDs were discarded based on
+ // the value of NumRights.
+ ControlTrunc bool
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
+ n, ms, c, ct, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return int64(n), err.ToError()
+ }
+ return int64(n), nil
+ }}.ReadToBlocks(dsts)
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
new file mode 100644
index 000000000..18e492862
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -0,0 +1,460 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// UniqueIDProvider generates a sequence of unique identifiers useful for,
+// among other things, lock ordering.
+type UniqueIDProvider interface {
+ // UniqueID returns a new unique identifier.
+ UniqueID() uint64
+}
+
+// A ConnectingEndpoint is a connectioned unix endpoint that is attempting to
+// establish a bidirectional connection with a BoundEndpoint.
+type ConnectingEndpoint interface {
+ // ID returns the endpoint's globally unique identifier. This identifier
+ // must be used to determine locking order if more than one endpoint is
+ // to be locked in the same codepath. The endpoint with the smaller
+ // identifier must be locked before endpoints with larger identifiers.
+ ID() uint64
+
+ // Passcred implements socket.Credentialer.Passcred.
+ Passcred() bool
+
+ // Type returns the socket type, typically either SockStream or
+ // SockSeqpacket. The connection attempt must be aborted if this
+ // value doesn't match the ConnectableEndpoint's type.
+ Type() SockType
+
+ // GetLocalAddress returns the bound path.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Locker protects the following methods. While locked, only the holder of
+ // the lock can change the return value of the protected methods.
+ sync.Locker
+
+ // Connected returns true iff the ConnectingEndpoint is in the connected
+ // state. ConnectingEndpoints can only be connected to a single endpoint,
+ // so the connection attempt must be aborted if this returns true.
+ Connected() bool
+
+ // Listening returns true iff the ConnectingEndpoint is in the listening
+ // state. ConnectingEndpoints cannot make connections while listening, so
+ // the connection attempt must be aborted if this returns true.
+ Listening() bool
+
+ // WaiterQueue returns a pointer to the endpoint's waiter queue.
+ WaiterQueue() *waiter.Queue
+}
+
+// connectionedEndpoint is a Unix-domain connected or connectable endpoint and implements
+// ConnectingEndpoint, ConnectableEndpoint and tcpip.Endpoint.
+//
+// connectionedEndpoints must be in connected state in order to transfer data.
+//
+// This implementation includes STREAM and SEQPACKET Unix sockets created with
+// socket(2), accept(2) or socketpair(2) and dgram unix sockets created with
+// socketpair(2). See unix_connectionless.go for the implementation of DGRAM
+// Unix sockets created with socket(2).
+//
+// The state is much simpler than a TCP endpoint, so it is not encoded
+// explicitly. Instead we enforce the following invariants:
+//
+// receiver != nil, connected != nil => connected.
+// path != "" && acceptedChan == nil => bound, not listening.
+// path != "" && acceptedChan != nil => bound and listening.
+//
+// Only one of these will be true at any moment.
+//
+// +stateify savable
+type connectionedEndpoint struct {
+ baseEndpoint
+
+ // id is the unique endpoint identifier. This is used exclusively for
+ // lock ordering within connect.
+ id uint64
+
+ // idGenerator is used to generate new unique endpoint identifiers.
+ idGenerator UniqueIDProvider
+
+ // stype is used by connecting sockets to ensure that they are the
+ // same type. The value is typically either tcpip.SockSeqpacket or
+ // tcpip.SockStream.
+ stype SockType
+
+ // acceptedChan is per the TCP endpoint implementation. Note that the
+ // sockets in this channel are _already in the connected state_, and
+ // have another associated connectionedEndpoint.
+ //
+ // If nil, then no listen call has been made.
+ acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"`
+}
+
+// NewConnectioned creates a new unbound connectionedEndpoint.
+func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint {
+ return &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+}
+
+// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
+func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
+ a := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+ b := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+
+ q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
+ q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
+
+ if stype == SockStream {
+ a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
+ b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}}
+ } else {
+ a.receiver = &queueReceiver{q1}
+ b.receiver = &queueReceiver{q2}
+ }
+
+ q2.IncRef()
+ a.connected = &connectedEndpoint{
+ endpoint: b,
+ writeQueue: q2,
+ }
+ q1.IncRef()
+ b.connected = &connectedEndpoint{
+ endpoint: a,
+ writeQueue: q1,
+ }
+
+ return a, b
+}
+
+// NewExternal creates a new externally backed Endpoint. It behaves like a
+// socketpair.
+func NewExternal(stype SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
+ return &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+}
+
+// ID implements ConnectingEndpoint.ID.
+func (e *connectionedEndpoint) ID() uint64 {
+ return e.id
+}
+
+// Type implements ConnectingEndpoint.Type and Endpoint.Type.
+func (e *connectionedEndpoint) Type() SockType {
+ return e.stype
+}
+
+// WaiterQueue implements ConnectingEndpoint.WaiterQueue.
+func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue {
+ return e.Queue
+}
+
+// isBound returns true iff the connectionedEndpoint is bound (but not
+// listening).
+func (e *connectionedEndpoint) isBound() bool {
+ return e.path != "" && e.acceptedChan == nil
+}
+
+// Listening implements ConnectingEndpoint.Listening.
+func (e *connectionedEndpoint) Listening() bool {
+ return e.acceptedChan != nil
+}
+
+// Close puts the connectionedEndpoint in a closed state and frees all
+// resources associated with it.
+//
+// The socket will be a fresh state after a call to close and may be reused.
+// That is, close may be used to "unbind" or "disconnect" the socket in error
+// paths.
+func (e *connectionedEndpoint) Close() {
+ e.Lock()
+ var c ConnectedEndpoint
+ var r Receiver
+ switch {
+ case e.Connected():
+ e.connected.CloseSend()
+ e.receiver.CloseRecv()
+ c = e.connected
+ r = e.receiver
+ e.connected = nil
+ e.receiver = nil
+ case e.isBound():
+ e.path = ""
+ case e.Listening():
+ close(e.acceptedChan)
+ for n := range e.acceptedChan {
+ n.Close()
+ }
+ e.acceptedChan = nil
+ e.path = ""
+ }
+ e.Unlock()
+ if c != nil {
+ c.CloseNotify()
+ c.Release()
+ }
+ if r != nil {
+ r.CloseNotify()
+ r.Release()
+ }
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+ if ce.Type() != e.stype {
+ return syserr.ErrConnectionRefused
+ }
+
+ // Check if ce is e to avoid a deadlock.
+ if ce, ok := ce.(*connectionedEndpoint); ok && ce == e {
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Do a dance to safely acquire locks on both endpoints.
+ if e.id < ce.ID() {
+ e.Lock()
+ ce.Lock()
+ } else {
+ ce.Lock()
+ e.Lock()
+ }
+
+ // Check connecting state.
+ if ce.Connected() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrAlreadyConnected
+ }
+ if ce.Listening() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Check bound state.
+ if !e.Listening() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+
+ // Create a newly bound connectionedEndpoint.
+ ne := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{
+ path: e.path,
+ Queue: &waiter.Queue{},
+ },
+ id: e.idGenerator.UniqueID(),
+ idGenerator: e.idGenerator,
+ stype: e.stype,
+ }
+
+ readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
+ ne.connected = &connectedEndpoint{
+ endpoint: ce,
+ writeQueue: readQueue,
+ }
+
+ writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
+ if e.stype == SockStream {
+ ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
+ } else {
+ ne.receiver = &queueReceiver{readQueue: writeQueue}
+ }
+
+ select {
+ case e.acceptedChan <- ne:
+ // Commit state.
+ writeQueue.IncRef()
+ connected := &connectedEndpoint{
+ endpoint: ne,
+ writeQueue: writeQueue,
+ }
+ readQueue.IncRef()
+ if e.stype == SockStream {
+ returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
+ } else {
+ returnConnect(&queueReceiver{readQueue: readQueue}, connected)
+ }
+
+ // Notify can deadlock if we are holding these locks.
+ e.Unlock()
+ ce.Unlock()
+
+ // Notify on both ends.
+ e.Notify(waiter.EventIn)
+ ce.WaiterQueue().Notify(waiter.EventOut)
+
+ return nil
+ default:
+ // Busy; return ECONNREFUSED per spec.
+ ne.Close()
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+}
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+func (e *connectionedEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) {
+ return nil, syserr.ErrConnectionRefused
+}
+
+// Connect attempts to directly connect to another Endpoint.
+// Implements Endpoint.Connect.
+func (e *connectionedEndpoint) Connect(server BoundEndpoint) *syserr.Error {
+ returnConnect := func(r Receiver, ce ConnectedEndpoint) {
+ e.receiver = r
+ e.connected = ce
+ }
+
+ return server.BidirectionalConnect(e, returnConnect)
+}
+
+// Listen starts listening on the connection.
+func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.Listening() {
+ // Adjust the size of the channel iff we can fix existing
+ // pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return syserr.ErrInvalidEndpointState
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
+ return nil
+ }
+ if !e.isBound() {
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Normal case.
+ e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+ return nil
+}
+
+// Accept accepts a new connection.
+func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
+ e.Lock()
+ defer e.Unlock()
+
+ if !e.Listening() {
+ return nil, syserr.ErrInvalidEndpointState
+ }
+
+ select {
+ case ne := <-e.acceptedChan:
+ return ne, nil
+
+ default:
+ // Nothing left.
+ return nil, syserr.ErrWouldBlock
+ }
+}
+
+// Bind binds the connection.
+//
+// For Unix connectionedEndpoints, this _only sets the address associated with
+// the socket_. Work associated with sockets in the filesystem or finding those
+// sockets must be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.isBound() || e.Listening() {
+ return syserr.ErrAlreadyBound
+ }
+ if addr.Addr == "" {
+ // The empty string is not permitted.
+ return syserr.ErrBadLocalAddress
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ return err
+ }
+ }
+
+ // Save the bound address.
+ e.path = string(addr.Addr)
+ return nil
+}
+
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+ // Stream sockets do not support specifying the endpoint. Seqpacket
+ // sockets ignore the passed endpoint.
+ if e.stype == SockStream && to != nil {
+ return 0, syserr.ErrNotSupported
+ }
+ return e.baseEndpoint.SendMsg(data, c, to)
+}
+
+// Readiness returns the current readiness of the connectionedEndpoint. For
+// example, if waiter.EventIn is set, the connectionedEndpoint is immediately
+// readable.
+func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.Lock()
+ defer e.Unlock()
+
+ ready := waiter.EventMask(0)
+ switch {
+ case e.Connected():
+ if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+ ready |= waiter.EventIn
+ }
+ if mask&waiter.EventOut != 0 && e.connected.Writable() {
+ ready |= waiter.EventOut
+ }
+ case e.Listening():
+ if mask&waiter.EventIn != 0 && len(e.acceptedChan) > 0 {
+ ready |= waiter.EventIn
+ }
+ }
+
+ return ready
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go
new file mode 100644
index 000000000..7e02a5db8
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectioned_state.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+// saveAcceptedChan is invoked by stateify.
+func (e *connectionedEndpoint) saveAcceptedChan() []*connectionedEndpoint {
+ // If acceptedChan is nil (i.e. we are not listening) then we will save nil.
+ // Otherwise we create a (possibly empty) slice of the values in acceptedChan and
+ // save that.
+ var acceptedSlice []*connectionedEndpoint
+ if e.acceptedChan != nil {
+ // Swap out acceptedChan with a new empty channel of the same capacity.
+ saveChan := e.acceptedChan
+ e.acceptedChan = make(chan *connectionedEndpoint, cap(saveChan))
+
+ // Create a new slice with the same len and capacity as the channel.
+ acceptedSlice = make([]*connectionedEndpoint, len(saveChan), cap(saveChan))
+ // Drain acceptedChan into saveSlice, and fill up the new acceptChan at the
+ // same time.
+ for i := range acceptedSlice {
+ ep := <-saveChan
+ acceptedSlice[i] = ep
+ e.acceptedChan <- ep
+ }
+ close(saveChan)
+ }
+ return acceptedSlice
+}
+
+// loadAcceptedChan is invoked by stateify.
+func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEndpoint) {
+ // If acceptedSlice is nil, then acceptedChan should also be nil.
+ if acceptedSlice != nil {
+ // Otherwise, create a new channel with the same capacity as acceptedSlice.
+ e.acceptedChan = make(chan *connectionedEndpoint, cap(acceptedSlice))
+ // Seed the channel with values from acceptedSlice.
+ for _, ep := range acceptedSlice {
+ e.acceptedChan <- ep
+ }
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
new file mode 100644
index 000000000..43ff875e4
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -0,0 +1,196 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// connectionlessEndpoint is a unix endpoint for unix sockets that support operating in
+// a conectionless fashon.
+//
+// Specifically, this means datagram unix sockets not created with
+// socketpair(2).
+//
+// +stateify savable
+type connectionlessEndpoint struct {
+ baseEndpoint
+}
+
+// NewConnectionless creates a new unbound dgram endpoint.
+func NewConnectionless() Endpoint {
+ ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
+ ep.receiver = &queueReceiver{readQueue: &queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}}
+ return ep
+}
+
+// isBound returns true iff the endpoint is bound.
+func (e *connectionlessEndpoint) isBound() bool {
+ return e.path != ""
+}
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it.
+//
+// The socket will be a fresh state after a call to close and may be reused.
+// That is, close may be used to "unbind" or "disconnect" the socket in error
+// paths.
+func (e *connectionlessEndpoint) Close() {
+ e.Lock()
+ var r Receiver
+ if e.Connected() {
+ e.receiver.CloseRecv()
+ r = e.receiver
+ e.receiver = nil
+
+ e.connected.Release()
+ e.connected = nil
+ }
+ if e.isBound() {
+ e.path = ""
+ }
+ e.Unlock()
+ if r != nil {
+ r.CloseNotify()
+ r.Release()
+ }
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+func (e *connectionlessEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+ return syserr.ErrConnectionRefused
+}
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) {
+ e.Lock()
+ r := e.receiver
+ e.Unlock()
+ if r == nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+ q := r.(*queueReceiver).readQueue
+ if !q.TryIncRef() {
+ return nil, syserr.ErrConnectionRefused
+ }
+ return &connectedEndpoint{
+ endpoint: e,
+ writeQueue: q,
+ }, nil
+}
+
+// SendMsg writes data and a control message to the specified endpoint.
+// This method does not block if the data cannot be written.
+func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+ if to == nil {
+ return e.baseEndpoint.SendMsg(data, c, nil)
+ }
+
+ connected, err := to.UnidirectionalConnect()
+ if err != nil {
+ return 0, syserr.ErrInvalidEndpointState
+ }
+ defer connected.Release()
+
+ e.Lock()
+ n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ e.Unlock()
+
+ if notify {
+ connected.SendNotify()
+ }
+
+ return n, err
+}
+
+// Type implements Endpoint.Type.
+func (e *connectionlessEndpoint) Type() SockType {
+ return SockDgram
+}
+
+// Connect attempts to connect directly to server.
+func (e *connectionlessEndpoint) Connect(server BoundEndpoint) *syserr.Error {
+ connected, err := server.UnidirectionalConnect()
+ if err != nil {
+ return err
+ }
+
+ e.Lock()
+ e.connected = connected
+ e.Unlock()
+
+ return nil
+}
+
+// Listen starts listening on the connection.
+func (e *connectionlessEndpoint) Listen(int) *syserr.Error {
+ return syserr.ErrNotSupported
+}
+
+// Accept accepts a new connection.
+func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) {
+ return nil, syserr.ErrNotSupported
+}
+
+// Bind binds the connection.
+//
+// For Unix endpoints, this _only sets the address associated with the socket_.
+// Work associated with sockets in the filesystem or finding those sockets must
+// be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionlessEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.isBound() {
+ return syserr.ErrAlreadyBound
+ }
+ if addr.Addr == "" {
+ // The empty string is not permitted.
+ return syserr.ErrBadLocalAddress
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ return err
+ }
+ }
+
+ // Save the bound address.
+ e.path = string(addr.Addr)
+ return nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.Lock()
+ defer e.Unlock()
+
+ ready := waiter.EventMask(0)
+ if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+ ready |= waiter.EventIn
+ }
+
+ if e.Connected() {
+ if mask&waiter.EventOut != 0 && e.connected.Writable() {
+ ready |= waiter.EventOut
+ }
+ }
+
+ return ready
+}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
new file mode 100644
index 000000000..b650caae7
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -0,0 +1,210 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// queue is a buffer queue.
+//
+// +stateify savable
+type queue struct {
+ refs.AtomicRefCount
+
+ ReaderQueue *waiter.Queue
+ WriterQueue *waiter.Queue
+
+ mu sync.Mutex `state:"nosave"`
+ closed bool
+ used int64
+ limit int64
+ dataList messageList
+}
+
+// Close closes q for reading and writing. It is immediately not writable and
+// will become unreadable when no more data is pending.
+//
+// Both the read and write queues must be notified after closing:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Close() {
+ q.mu.Lock()
+ q.closed = true
+ q.mu.Unlock()
+}
+
+// Reset empties the queue and Releases all of the Entries.
+//
+// Both the read and write queues must be notified after resetting:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Reset() {
+ q.mu.Lock()
+ for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
+ cur.Release()
+ }
+ q.dataList.Reset()
+ q.used = 0
+ q.mu.Unlock()
+}
+
+// DecRef implements RefCounter.DecRef with destructor q.Reset.
+func (q *queue) DecRef() {
+ q.DecRefWithDestructor(q.Reset)
+ // We don't need to notify after resetting because no one cares about
+ // this queue after all references have been dropped.
+}
+
+// IsReadable determines if q is currently readable.
+func (q *queue) IsReadable() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.closed || q.dataList.Front() != nil
+}
+
+// bufWritable returns true if there is space for writing.
+//
+// N.B. Linux only considers a unix socket "writable" if >75% of the buffer is
+// free.
+//
+// See net/unix/af_unix.c:unix_writeable.
+func (q *queue) bufWritable() bool {
+ return 4*q.used < q.limit
+}
+
+// IsWritable determines if q is currently writable.
+func (q *queue) IsWritable() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.closed || q.bufWritable()
+}
+
+// Enqueue adds an entry to the data queue if room is available.
+//
+// If truncate is true, Enqueue may truncate the message beforing enqueuing it.
+// Otherwise, the entire message must fit. If n < e.Length(), err indicates why.
+//
+// If notify is true, ReaderQueue.Notify must be called:
+// q.ReaderQueue.Notify(waiter.EventIn)
+func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *syserr.Error) {
+ q.mu.Lock()
+
+ if q.closed {
+ q.mu.Unlock()
+ return 0, false, syserr.ErrClosedForSend
+ }
+
+ free := q.limit - q.used
+
+ l = e.Length()
+
+ if l > free && truncate {
+ if free == 0 {
+ // Message can't fit right now.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrWouldBlock
+ }
+
+ e.Truncate(free)
+ l = e.Length()
+ err = syserr.ErrWouldBlock
+ }
+
+ if l > q.limit {
+ // Message is too big to ever fit.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrMessageTooLong
+ }
+
+ if l > free {
+ // Message can't fit right now.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrWouldBlock
+ }
+
+ notify = q.dataList.Front() == nil
+ q.used += l
+ q.dataList.PushBack(e)
+
+ q.mu.Unlock()
+
+ return l, notify, err
+}
+
+// Dequeue removes the first entry in the data queue, if one exists.
+//
+// If notify is true, WriterQueue.Notify must be called:
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) {
+ q.mu.Lock()
+
+ if q.dataList.Front() == nil {
+ err := syserr.ErrWouldBlock
+ if q.closed {
+ err = syserr.ErrClosedForReceive
+ }
+ q.mu.Unlock()
+
+ return nil, false, err
+ }
+
+ notify = !q.bufWritable()
+
+ e = q.dataList.Front()
+ q.dataList.Remove(e)
+ q.used -= e.Length()
+
+ notify = notify && q.bufWritable()
+
+ q.mu.Unlock()
+
+ return e, notify, nil
+}
+
+// Peek returns the first entry in the data queue, if one exists.
+func (q *queue) Peek() (*message, *syserr.Error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if q.dataList.Front() == nil {
+ err := syserr.ErrWouldBlock
+ if q.closed {
+ err = syserr.ErrClosedForReceive
+ }
+ return nil, err
+ }
+
+ return q.dataList.Front().Peek(), nil
+}
+
+// QueuedSize returns the number of bytes currently in the queue, that is, the
+// number of readable bytes.
+func (q *queue) QueuedSize() int64 {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ return q.used
+}
+
+// MaxQueueSize returns the maximum number of bytes storable in the queue.
+func (q *queue) MaxQueueSize() int64 {
+ return q.limit
+}
diff --git a/pkg/sentry/socket/unix/transport/transport_message_list.go b/pkg/sentry/socket/unix/transport/transport_message_list.go
new file mode 100755
index 000000000..6d394860c
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/transport_message_list.go
@@ -0,0 +1,173 @@
+package transport
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type messageElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (messageElementMapper) linkerFor(elem *message) *message { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type messageList struct {
+ head *message
+ tail *message
+}
+
+// Reset resets list l to the empty state.
+func (l *messageList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *messageList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *messageList) Front() *message {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *messageList) Back() *message {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *messageList) PushFront(e *message) {
+ messageElementMapper{}.linkerFor(e).SetNext(l.head)
+ messageElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ messageElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *messageList) PushBack(e *message) {
+ messageElementMapper{}.linkerFor(e).SetNext(nil)
+ messageElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ messageElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *messageList) PushBackList(m *messageList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ messageElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ messageElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *messageList) InsertAfter(b, e *message) {
+ a := messageElementMapper{}.linkerFor(b).Next()
+ messageElementMapper{}.linkerFor(e).SetNext(a)
+ messageElementMapper{}.linkerFor(e).SetPrev(b)
+ messageElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ messageElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *messageList) InsertBefore(a, e *message) {
+ b := messageElementMapper{}.linkerFor(a).Prev()
+ messageElementMapper{}.linkerFor(e).SetNext(a)
+ messageElementMapper{}.linkerFor(e).SetPrev(b)
+ messageElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ messageElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *messageList) Remove(e *message) {
+ prev := messageElementMapper{}.linkerFor(e).Prev()
+ next := messageElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ messageElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ messageElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type messageEntry struct {
+ next *message
+ prev *message
+}
+
+// Next returns the entry that follows e in the list.
+func (e *messageEntry) Next() *message {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *messageEntry) Prev() *message {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *messageEntry) SetNext(elem *message) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *messageEntry) SetPrev(elem *message) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/socket/unix/transport/transport_state_autogen.go b/pkg/sentry/socket/unix/transport/transport_state_autogen.go
new file mode 100755
index 000000000..8bb1b5b31
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/transport_state_autogen.go
@@ -0,0 +1,191 @@
+// automatically generated by stateify.
+
+package transport
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *connectionedEndpoint) beforeSave() {}
+func (x *connectionedEndpoint) save(m state.Map) {
+ x.beforeSave()
+ var acceptedChan []*connectionedEndpoint = x.saveAcceptedChan()
+ m.SaveValue("acceptedChan", acceptedChan)
+ m.Save("baseEndpoint", &x.baseEndpoint)
+ m.Save("id", &x.id)
+ m.Save("idGenerator", &x.idGenerator)
+ m.Save("stype", &x.stype)
+}
+
+func (x *connectionedEndpoint) afterLoad() {}
+func (x *connectionedEndpoint) load(m state.Map) {
+ m.Load("baseEndpoint", &x.baseEndpoint)
+ m.Load("id", &x.id)
+ m.Load("idGenerator", &x.idGenerator)
+ m.Load("stype", &x.stype)
+ m.LoadValue("acceptedChan", new([]*connectionedEndpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*connectionedEndpoint)) })
+}
+
+func (x *connectionlessEndpoint) beforeSave() {}
+func (x *connectionlessEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("baseEndpoint", &x.baseEndpoint)
+}
+
+func (x *connectionlessEndpoint) afterLoad() {}
+func (x *connectionlessEndpoint) load(m state.Map) {
+ m.Load("baseEndpoint", &x.baseEndpoint)
+}
+
+func (x *queue) beforeSave() {}
+func (x *queue) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("ReaderQueue", &x.ReaderQueue)
+ m.Save("WriterQueue", &x.WriterQueue)
+ m.Save("closed", &x.closed)
+ m.Save("used", &x.used)
+ m.Save("limit", &x.limit)
+ m.Save("dataList", &x.dataList)
+}
+
+func (x *queue) afterLoad() {}
+func (x *queue) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("ReaderQueue", &x.ReaderQueue)
+ m.Load("WriterQueue", &x.WriterQueue)
+ m.Load("closed", &x.closed)
+ m.Load("used", &x.used)
+ m.Load("limit", &x.limit)
+ m.Load("dataList", &x.dataList)
+}
+
+func (x *messageList) beforeSave() {}
+func (x *messageList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *messageList) afterLoad() {}
+func (x *messageList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *messageEntry) beforeSave() {}
+func (x *messageEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *messageEntry) afterLoad() {}
+func (x *messageEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *ControlMessages) beforeSave() {}
+func (x *ControlMessages) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Rights", &x.Rights)
+ m.Save("Credentials", &x.Credentials)
+}
+
+func (x *ControlMessages) afterLoad() {}
+func (x *ControlMessages) load(m state.Map) {
+ m.Load("Rights", &x.Rights)
+ m.Load("Credentials", &x.Credentials)
+}
+
+func (x *message) beforeSave() {}
+func (x *message) save(m state.Map) {
+ x.beforeSave()
+ m.Save("messageEntry", &x.messageEntry)
+ m.Save("Data", &x.Data)
+ m.Save("Control", &x.Control)
+ m.Save("Address", &x.Address)
+}
+
+func (x *message) afterLoad() {}
+func (x *message) load(m state.Map) {
+ m.Load("messageEntry", &x.messageEntry)
+ m.Load("Data", &x.Data)
+ m.Load("Control", &x.Control)
+ m.Load("Address", &x.Address)
+}
+
+func (x *queueReceiver) beforeSave() {}
+func (x *queueReceiver) save(m state.Map) {
+ x.beforeSave()
+ m.Save("readQueue", &x.readQueue)
+}
+
+func (x *queueReceiver) afterLoad() {}
+func (x *queueReceiver) load(m state.Map) {
+ m.Load("readQueue", &x.readQueue)
+}
+
+func (x *streamQueueReceiver) beforeSave() {}
+func (x *streamQueueReceiver) save(m state.Map) {
+ x.beforeSave()
+ m.Save("queueReceiver", &x.queueReceiver)
+ m.Save("buffer", &x.buffer)
+ m.Save("control", &x.control)
+ m.Save("addr", &x.addr)
+}
+
+func (x *streamQueueReceiver) afterLoad() {}
+func (x *streamQueueReceiver) load(m state.Map) {
+ m.Load("queueReceiver", &x.queueReceiver)
+ m.Load("buffer", &x.buffer)
+ m.Load("control", &x.control)
+ m.Load("addr", &x.addr)
+}
+
+func (x *connectedEndpoint) beforeSave() {}
+func (x *connectedEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("endpoint", &x.endpoint)
+ m.Save("writeQueue", &x.writeQueue)
+}
+
+func (x *connectedEndpoint) afterLoad() {}
+func (x *connectedEndpoint) load(m state.Map) {
+ m.Load("endpoint", &x.endpoint)
+ m.Load("writeQueue", &x.writeQueue)
+}
+
+func (x *baseEndpoint) beforeSave() {}
+func (x *baseEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Queue", &x.Queue)
+ m.Save("passcred", &x.passcred)
+ m.Save("receiver", &x.receiver)
+ m.Save("connected", &x.connected)
+ m.Save("path", &x.path)
+}
+
+func (x *baseEndpoint) afterLoad() {}
+func (x *baseEndpoint) load(m state.Map) {
+ m.Load("Queue", &x.Queue)
+ m.Load("passcred", &x.passcred)
+ m.Load("receiver", &x.receiver)
+ m.Load("connected", &x.connected)
+ m.Load("path", &x.path)
+}
+
+func init() {
+ state.Register("transport.connectionedEndpoint", (*connectionedEndpoint)(nil), state.Fns{Save: (*connectionedEndpoint).save, Load: (*connectionedEndpoint).load})
+ state.Register("transport.connectionlessEndpoint", (*connectionlessEndpoint)(nil), state.Fns{Save: (*connectionlessEndpoint).save, Load: (*connectionlessEndpoint).load})
+ state.Register("transport.queue", (*queue)(nil), state.Fns{Save: (*queue).save, Load: (*queue).load})
+ state.Register("transport.messageList", (*messageList)(nil), state.Fns{Save: (*messageList).save, Load: (*messageList).load})
+ state.Register("transport.messageEntry", (*messageEntry)(nil), state.Fns{Save: (*messageEntry).save, Load: (*messageEntry).load})
+ state.Register("transport.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load})
+ state.Register("transport.message", (*message)(nil), state.Fns{Save: (*message).save, Load: (*message).load})
+ state.Register("transport.queueReceiver", (*queueReceiver)(nil), state.Fns{Save: (*queueReceiver).save, Load: (*queueReceiver).load})
+ state.Register("transport.streamQueueReceiver", (*streamQueueReceiver)(nil), state.Fns{Save: (*streamQueueReceiver).save, Load: (*streamQueueReceiver).load})
+ state.Register("transport.connectedEndpoint", (*connectedEndpoint)(nil), state.Fns{Save: (*connectedEndpoint).save, Load: (*connectedEndpoint).load})
+ state.Register("transport.baseEndpoint", (*baseEndpoint)(nil), state.Fns{Save: (*baseEndpoint).save, Load: (*baseEndpoint).load})
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
new file mode 100644
index 000000000..b734b4c20
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -0,0 +1,973 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport contains the implementation of Unix endpoints.
+package transport
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// initialLimit is the starting limit for the socket buffers.
+const initialLimit = 16 * 1024
+
+// A SockType is a type (as opposed to family) of sockets. These are enumerated
+// in the syscall package as syscall.SOCK_* constants.
+type SockType int
+
+const (
+ // SockStream corresponds to syscall.SOCK_STREAM.
+ SockStream SockType = 1
+ // SockDgram corresponds to syscall.SOCK_DGRAM.
+ SockDgram SockType = 2
+ // SockRaw corresponds to syscall.SOCK_RAW.
+ SockRaw SockType = 3
+ // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET.
+ SockSeqpacket SockType = 5
+)
+
+// A RightsControlMessage is a control message containing FDs.
+type RightsControlMessage interface {
+ // Clone returns a copy of the RightsControlMessage.
+ Clone() RightsControlMessage
+
+ // Release releases any resources owned by the RightsControlMessage.
+ Release()
+}
+
+// A CredentialsControlMessage is a control message containing Unix credentials.
+type CredentialsControlMessage interface {
+ // Equals returns true iff the two messages are equal.
+ Equals(CredentialsControlMessage) bool
+}
+
+// A ControlMessages represents a collection of socket control messages.
+//
+// +stateify savable
+type ControlMessages struct {
+ // Rights is a control message containing FDs.
+ Rights RightsControlMessage
+
+ // Credentials is a control message containing Unix credentials.
+ Credentials CredentialsControlMessage
+}
+
+// Empty returns true iff the ControlMessages does not contain either
+// credentials or rights.
+func (c *ControlMessages) Empty() bool {
+ return c.Rights == nil && c.Credentials == nil
+}
+
+// Clone clones both the credentials and the rights.
+func (c *ControlMessages) Clone() ControlMessages {
+ cm := ControlMessages{}
+ if c.Rights != nil {
+ cm.Rights = c.Rights.Clone()
+ }
+ cm.Credentials = c.Credentials
+ return cm
+}
+
+// Release releases both the credentials and the rights.
+func (c *ControlMessages) Release() {
+ if c.Rights != nil {
+ c.Rights.Release()
+ }
+ *c = ControlMessages{}
+}
+
+// Endpoint is the interface implemented by Unix transport protocol
+// implementations that expose functionality like sendmsg, recvmsg, connect,
+// etc. to Unix socket implementations.
+type Endpoint interface {
+ Credentialer
+ waiter.Waitable
+
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it.
+ Close()
+
+ // RecvMsg reads data and a control message from the endpoint. This method
+ // does not block if there is no data pending.
+ //
+ // creds indicates if credential control messages are requested by the
+ // caller. This is useful for determining if control messages can be
+ // coalesced. creds is a hint and can be safely ignored by the
+ // implementation if no coalescing is possible. It is fine to return
+ // credential control messages when none were requested or to not return
+ // credential control messages when they were requested.
+ //
+ // numRights is the number of SCM_RIGHTS FDs requested by the caller. This
+ // is useful if one must allocate a buffer to receive a SCM_RIGHTS message
+ // or determine if control messages can be coalesced. numRights is a hint
+ // and can be safely ignored by the implementation if the number of
+ // available SCM_RIGHTS FDs is known and no coalescing is possible. It is
+ // fine for the returned number of SCM_RIGHTS FDs to be either higher or
+ // lower than the requested number.
+ //
+ // If peek is true, no data should be consumed from the Endpoint. Any and
+ // all data returned from a peek should be available in the next call to
+ // RecvMsg.
+ //
+ // recvLen is the number of bytes copied into data.
+ //
+ // msgLen is the length of the read message consumed for datagram Endpoints.
+ // msgLen is always the same as recvLen for stream Endpoints.
+ //
+ // CMTruncated indicates that the numRights hint was used to receive fewer
+ // than the total available SCM_RIGHTS FDs. Additional truncation may be
+ // required by the caller.
+ RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error)
+
+ // SendMsg writes data and a control message to the endpoint's peer.
+ // This method does not block if the data cannot be written.
+ //
+ // SendMsg does not take ownership of any of its arguments on error.
+ SendMsg([][]byte, ControlMessages, BoundEndpoint) (uintptr, *syserr.Error)
+
+ // Connect connects this endpoint directly to another.
+ //
+ // This should be called on the client endpoint, and the (bound)
+ // endpoint passed in as a parameter.
+ //
+ // The error codes are the same as Connect.
+ Connect(server BoundEndpoint) *syserr.Error
+
+ // Shutdown closes the read and/or write end of the endpoint connection
+ // to its peer.
+ Shutdown(flags tcpip.ShutdownFlags) *syserr.Error
+
+ // Listen puts the endpoint in "listen" mode, which allows it to accept
+ // new connections.
+ Listen(backlog int) *syserr.Error
+
+ // Accept returns a new endpoint if a peer has established a connection
+ // to an endpoint previously set to listen mode. This method does not
+ // block if no new connections are available.
+ //
+ // The returned Queue is the wait queue for the newly created endpoint.
+ Accept() (Endpoint, *syserr.Error)
+
+ // Bind binds the endpoint to a specific local address and port.
+ // Specifying a NIC is optional.
+ //
+ // An optional commit function will be executed atomically with respect
+ // to binding the endpoint. If this returns an error, the bind will not
+ // occur and the error will be propagated back to the caller.
+ Bind(address tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error
+
+ // Type return the socket type, typically either SockStream, SockDgram
+ // or SockSeqpacket.
+ Type() SockType
+
+ // GetLocalAddress returns the address to which the endpoint is bound.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // GetRemoteAddress returns the address to which the endpoint is
+ // connected.
+ GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option
+ // types.
+ SetSockOpt(opt interface{}) *tcpip.Error
+
+ // GetSockOpt gets a socket option. opt should be a pointer to one of the
+ // tcpip.*Option types.
+ GetSockOpt(opt interface{}) *tcpip.Error
+}
+
+// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
+// option.
+type Credentialer interface {
+ // Passcred returns whether or not the SO_PASSCRED socket option is
+ // enabled on this end.
+ Passcred() bool
+
+ // ConnectedPasscred returns whether or not the SO_PASSCRED socket option
+ // is enabled on the connected end.
+ ConnectedPasscred() bool
+}
+
+// A BoundEndpoint is a unix endpoint that can be connected to.
+type BoundEndpoint interface {
+ // BidirectionalConnect establishes a bi-directional connection between two
+ // unix endpoints in an all-or-nothing manner. If an error occurs during
+ // connecting, the state of neither endpoint should be modified.
+ //
+ // In order for an endpoint to establish such a bidirectional connection
+ // with a BoundEndpoint, the endpoint calls the BidirectionalConnect method
+ // on the BoundEndpoint and sends a representation of itself (the
+ // ConnectingEndpoint) and a callback (returnConnect) to receive the
+ // connection information (Receiver and ConnectedEndpoint) upon a
+ // successful connect. The callback should only be called on a successful
+ // connect.
+ //
+ // For a connection attempt to be successful, the ConnectingEndpoint must
+ // be unconnected and not listening and the BoundEndpoint whose
+ // BidirectionalConnect method is being called must be listening.
+ //
+ // This method will return syserr.ErrConnectionRefused on endpoints with a
+ // type that isn't SockStream or SockSeqpacket.
+ BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error
+
+ // UnidirectionalConnect establishes a write-only connection to a unix
+ // endpoint.
+ //
+ // An endpoint which calls UnidirectionalConnect and supports it itself must
+ // not hold its own lock when calling UnidirectionalConnect.
+ //
+ // This method will return syserr.ErrConnectionRefused on a non-SockDgram
+ // endpoint.
+ UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error)
+
+ // Release releases any resources held by the BoundEndpoint. It must be
+ // called before dropping all references to a BoundEndpoint returned by a
+ // function.
+ Release()
+}
+
+// message represents a message passed over a Unix domain socket.
+//
+// +stateify savable
+type message struct {
+ messageEntry
+
+ // Data is the Message payload.
+ Data buffer.View
+
+ // Control is auxiliary control message data that goes along with the
+ // data.
+ Control ControlMessages
+
+ // Address is the bound address of the endpoint that sent the message.
+ //
+ // If the endpoint that sent the message is not bound, the Address is
+ // the empty string.
+ Address tcpip.FullAddress
+}
+
+// Length returns number of bytes stored in the message.
+func (m *message) Length() int64 {
+ return int64(len(m.Data))
+}
+
+// Release releases any resources held by the message.
+func (m *message) Release() {
+ m.Control.Release()
+}
+
+// Peek returns a copy of the message.
+func (m *message) Peek() *message {
+ return &message{Data: m.Data, Control: m.Control.Clone(), Address: m.Address}
+}
+
+// Truncate reduces the length of the message payload to n bytes.
+//
+// Preconditions: n <= m.Length().
+func (m *message) Truncate(n int64) {
+ m.Data.CapLength(int(n))
+}
+
+// A Receiver can be used to receive Messages.
+type Receiver interface {
+ // Recv receives a single message. This method does not block.
+ //
+ // See Endpoint.RecvMsg for documentation on shared arguments.
+ //
+ // notify indicates if RecvNotify should be called.
+ Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
+
+ // RecvNotify notifies the Receiver of a successful Recv. This must not be
+ // called while holding any endpoint locks.
+ RecvNotify()
+
+ // CloseRecv prevents the receiving of additional Messages.
+ //
+ // After CloseRecv is called, CloseNotify must also be called.
+ CloseRecv()
+
+ // CloseNotify notifies the Receiver of recv being closed. This must not be
+ // called while holding any endpoint locks.
+ CloseNotify()
+
+ // Readable returns if messages should be attempted to be received. This
+ // includes when read has been shutdown.
+ Readable() bool
+
+ // RecvQueuedSize returns the total amount of data currently receivable.
+ // RecvQueuedSize should return -1 if the operation isn't supported.
+ RecvQueuedSize() int64
+
+ // RecvMaxQueueSize returns maximum value for RecvQueuedSize.
+ // RecvMaxQueueSize should return -1 if the operation isn't supported.
+ RecvMaxQueueSize() int64
+
+ // Release releases any resources owned by the Receiver. It should be
+ // called before droping all references to a Receiver.
+ Release()
+}
+
+// queueReceiver implements Receiver for datagram sockets.
+//
+// +stateify savable
+type queueReceiver struct {
+ readQueue *queue
+}
+
+// Recv implements Receiver.Recv.
+func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ var m *message
+ var notify bool
+ var err *syserr.Error
+ if peek {
+ m, err = q.readQueue.Peek()
+ } else {
+ m, notify, err = q.readQueue.Dequeue()
+ }
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
+ }
+ src := []byte(m.Data)
+ var copied uintptr
+ for i := 0; i < len(data) && len(src) > 0; i++ {
+ n := copy(data[i], src)
+ copied += uintptr(n)
+ src = src[n:]
+ }
+ return copied, uintptr(len(m.Data)), m.Control, false, m.Address, notify, nil
+}
+
+// RecvNotify implements Receiver.RecvNotify.
+func (q *queueReceiver) RecvNotify() {
+ q.readQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseNotify implements Receiver.CloseNotify.
+func (q *queueReceiver) CloseNotify() {
+ q.readQueue.ReaderQueue.Notify(waiter.EventIn)
+ q.readQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseRecv implements Receiver.CloseRecv.
+func (q *queueReceiver) CloseRecv() {
+ q.readQueue.Close()
+}
+
+// Readable implements Receiver.Readable.
+func (q *queueReceiver) Readable() bool {
+ return q.readQueue.IsReadable()
+}
+
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *queueReceiver) RecvQueuedSize() int64 {
+ return q.readQueue.QueuedSize()
+}
+
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *queueReceiver) RecvMaxQueueSize() int64 {
+ return q.readQueue.MaxQueueSize()
+}
+
+// Release implements Receiver.Release.
+func (q *queueReceiver) Release() {
+ q.readQueue.DecRef()
+}
+
+// streamQueueReceiver implements Receiver for stream sockets.
+//
+// +stateify savable
+type streamQueueReceiver struct {
+ queueReceiver
+
+ mu sync.Mutex `state:"nosave"`
+ buffer []byte
+ control ControlMessages
+ addr tcpip.FullAddress
+}
+
+func vecCopy(data [][]byte, buf []byte) (uintptr, [][]byte, []byte) {
+ var copied uintptr
+ for len(data) > 0 && len(buf) > 0 {
+ n := copy(data[0], buf)
+ copied += uintptr(n)
+ buf = buf[n:]
+ data[0] = data[0][n:]
+ if len(data[0]) == 0 {
+ data = data[1:]
+ }
+ }
+ return copied, data, buf
+}
+
+// Readable implements Receiver.Readable.
+func (q *streamQueueReceiver) Readable() bool {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ r := q.readQueue.IsReadable()
+ q.mu.Unlock()
+ // We're readable if we have data in our buffer or if the queue receiver is
+ // readable.
+ return bl > 0 || r
+}
+
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *streamQueueReceiver) RecvQueuedSize() int64 {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ qs := q.readQueue.QueuedSize()
+ q.mu.Unlock()
+ return int64(bl) + qs
+}
+
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
+ // The RecvMaxQueueSize() is the readQueue's MaxQueueSize() plus the largest
+ // message we can buffer which is also the largest message we can receive.
+ return 2 * q.readQueue.MaxQueueSize()
+}
+
+// Recv implements Receiver.Recv.
+func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ var notify bool
+
+ // If we have no data in the endpoint, we need to get some.
+ if len(q.buffer) == 0 {
+ // Load the next message into a buffer, even if we are peeking. Peeking
+ // won't consume the message, so it will be still available to be read
+ // the next time Recv() is called.
+ m, n, err := q.readQueue.Dequeue()
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
+ }
+ notify = n
+ q.buffer = []byte(m.Data)
+ q.control = m.Control
+ q.addr = m.Address
+ }
+
+ var copied uintptr
+ if peek {
+ // Don't consume control message if we are peeking.
+ c := q.control.Clone()
+
+ // Don't consume data since we are peeking.
+ copied, data, _ = vecCopy(data, q.buffer)
+
+ return copied, copied, c, false, q.addr, notify, nil
+ }
+
+ // Consume data and control message since we are not peeking.
+ copied, data, q.buffer = vecCopy(data, q.buffer)
+
+ // Save the original state of q.control.
+ c := q.control
+
+ // Remove rights from q.control and leave behind just the creds.
+ q.control.Rights = nil
+ if !wantCreds {
+ c.Credentials = nil
+ }
+
+ var cmTruncated bool
+ if c.Rights != nil && numRights == 0 {
+ c.Rights.Release()
+ c.Rights = nil
+ cmTruncated = true
+ }
+
+ haveRights := c.Rights != nil
+
+ // If we have more capacity for data and haven't received any usable
+ // rights.
+ //
+ // Linux never coalesces rights control messages.
+ for !haveRights && len(data) > 0 {
+ // Get a message from the readQueue.
+ m, n, err := q.readQueue.Dequeue()
+ if err != nil {
+ // We already got some data, so ignore this error. This will
+ // manifest as a short read to the user, which is what Linux
+ // does.
+ break
+ }
+ notify = notify || n
+ q.buffer = []byte(m.Data)
+ q.control = m.Control
+ q.addr = m.Address
+
+ if wantCreds {
+ if (q.control.Credentials == nil) != (c.Credentials == nil) {
+ // One message has credentials, the other does not.
+ break
+ }
+
+ if q.control.Credentials != nil && c.Credentials != nil && !q.control.Credentials.Equals(c.Credentials) {
+ // Both messages have credentials, but they don't match.
+ break
+ }
+ }
+
+ if numRights != 0 && c.Rights != nil && q.control.Rights != nil {
+ // Both messages have rights.
+ break
+ }
+
+ var cpd uintptr
+ cpd, data, q.buffer = vecCopy(data, q.buffer)
+ copied += cpd
+
+ if cpd == 0 {
+ // data was actually full.
+ break
+ }
+
+ if q.control.Rights != nil {
+ // Consume rights.
+ if numRights == 0 {
+ cmTruncated = true
+ q.control.Rights.Release()
+ } else {
+ c.Rights = q.control.Rights
+ haveRights = true
+ }
+ q.control.Rights = nil
+ }
+ }
+ return copied, copied, c, cmTruncated, q.addr, notify, nil
+}
+
+// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
+type ConnectedEndpoint interface {
+ // Passcred implements Endpoint.Passcred.
+ Passcred() bool
+
+ // GetLocalAddress implements Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Send sends a single message. This method does not block.
+ //
+ // notify indicates if SendNotify should be called.
+ //
+ // syserr.ErrWouldBlock can be returned along with a partial write if
+ // the caller should block to send the rest of the data.
+ Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err *syserr.Error)
+
+ // SendNotify notifies the ConnectedEndpoint of a successful Send. This
+ // must not be called while holding any endpoint locks.
+ SendNotify()
+
+ // CloseSend prevents the sending of additional Messages.
+ //
+ // After CloseSend is call, CloseNotify must also be called.
+ CloseSend()
+
+ // CloseNotify notifies the ConnectedEndpoint of send being closed. This
+ // must not be called while holding any endpoint locks.
+ CloseNotify()
+
+ // Writable returns if messages should be attempted to be sent. This
+ // includes when write has been shutdown.
+ Writable() bool
+
+ // EventUpdate lets the ConnectedEndpoint know that event registrations
+ // have changed.
+ EventUpdate()
+
+ // SendQueuedSize returns the total amount of data currently queued for
+ // sending. SendQueuedSize should return -1 if the operation isn't
+ // supported.
+ SendQueuedSize() int64
+
+ // SendMaxQueueSize returns maximum value for SendQueuedSize.
+ // SendMaxQueueSize should return -1 if the operation isn't supported.
+ SendMaxQueueSize() int64
+
+ // Release releases any resources owned by the ConnectedEndpoint. It should
+ // be called before droping all references to a ConnectedEndpoint.
+ Release()
+}
+
+// +stateify savable
+type connectedEndpoint struct {
+ // endpoint represents the subset of the Endpoint functionality needed by
+ // the connectedEndpoint. It is implemented by both connectionedEndpoint
+ // and connectionlessEndpoint and allows the use of types which don't
+ // fully implement Endpoint.
+ endpoint interface {
+ // Passcred implements Endpoint.Passcred.
+ Passcred() bool
+
+ // GetLocalAddress implements Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Type implements Endpoint.Type.
+ Type() SockType
+ }
+
+ writeQueue *queue
+}
+
+// Passcred implements ConnectedEndpoint.Passcred.
+func (e *connectedEndpoint) Passcred() bool {
+ return e.endpoint.Passcred()
+}
+
+// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress.
+func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return e.endpoint.GetLocalAddress()
+}
+
+// Send implements ConnectedEndpoint.Send.
+func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) {
+ var l int64
+ for _, d := range data {
+ l += int64(len(d))
+ }
+
+ truncate := false
+ if e.endpoint.Type() == SockStream {
+ // Since stream sockets don't preserve message boundaries, we
+ // can write only as much of the message as fits in the queue.
+ truncate = true
+
+ // Discard empty stream packets. Since stream sockets don't
+ // preserve message boundaries, sending zero bytes is a no-op.
+ // In Linux, the receiver actually uses a zero-length receive
+ // as an indication that the stream was closed.
+ if l == 0 {
+ controlMessages.Release()
+ return 0, false, nil
+ }
+ }
+
+ v := make([]byte, 0, l)
+ for _, d := range data {
+ v = append(v, d...)
+ }
+
+ l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate)
+ return uintptr(l), notify, err
+}
+
+// SendNotify implements ConnectedEndpoint.SendNotify.
+func (e *connectedEndpoint) SendNotify() {
+ e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
+}
+
+// CloseNotify implements ConnectedEndpoint.CloseNotify.
+func (e *connectedEndpoint) CloseNotify() {
+ e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
+ e.writeQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseSend implements ConnectedEndpoint.CloseSend.
+func (e *connectedEndpoint) CloseSend() {
+ e.writeQueue.Close()
+}
+
+// Writable implements ConnectedEndpoint.Writable.
+func (e *connectedEndpoint) Writable() bool {
+ return e.writeQueue.IsWritable()
+}
+
+// EventUpdate implements ConnectedEndpoint.EventUpdate.
+func (*connectedEndpoint) EventUpdate() {}
+
+// SendQueuedSize implements ConnectedEndpoint.SendQueuedSize.
+func (e *connectedEndpoint) SendQueuedSize() int64 {
+ return e.writeQueue.QueuedSize()
+}
+
+// SendMaxQueueSize implements ConnectedEndpoint.SendMaxQueueSize.
+func (e *connectedEndpoint) SendMaxQueueSize() int64 {
+ return e.writeQueue.MaxQueueSize()
+}
+
+// Release implements ConnectedEndpoint.Release.
+func (e *connectedEndpoint) Release() {
+ e.writeQueue.DecRef()
+}
+
+// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
+// unix domain socket Endpoint implementations.
+//
+// Not to be used on its own.
+//
+// +stateify savable
+type baseEndpoint struct {
+ *waiter.Queue
+
+ // passcred specifies whether SCM_CREDENTIALS socket control messages are
+ // enabled on this endpoint. Must be accessed atomically.
+ passcred int32
+
+ // Mutex protects the below fields.
+ sync.Mutex `state:"nosave"`
+
+ // receiver allows Messages to be received.
+ receiver Receiver
+
+ // connected allows messages to be sent and state information about the
+ // connected endpoint to be read.
+ connected ConnectedEndpoint
+
+ // path is not empty if the endpoint has been bound,
+ // or may be used if the endpoint is connected.
+ path string
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
+ e.Queue.EventRegister(we, mask)
+ e.Lock()
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
+ e.Queue.EventUnregister(we)
+ e.Lock()
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// Passcred implements Credentialer.Passcred.
+func (e *baseEndpoint) Passcred() bool {
+ return atomic.LoadInt32(&e.passcred) != 0
+}
+
+// ConnectedPasscred implements Credentialer.ConnectedPasscred.
+func (e *baseEndpoint) ConnectedPasscred() bool {
+ e.Lock()
+ defer e.Unlock()
+ return e.connected != nil && e.connected.Passcred()
+}
+
+func (e *baseEndpoint) setPasscred(pc bool) {
+ if pc {
+ atomic.StoreInt32(&e.passcred, 1)
+ } else {
+ atomic.StoreInt32(&e.passcred, 0)
+ }
+}
+
+// Connected implements ConnectingEndpoint.Connected.
+func (e *baseEndpoint) Connected() bool {
+ return e.receiver != nil && e.connected != nil
+}
+
+// RecvMsg reads data and a control message from the endpoint.
+func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) {
+ e.Lock()
+
+ if e.receiver == nil {
+ e.Unlock()
+ return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected
+ }
+
+ recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
+ e.Unlock()
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, err
+ }
+
+ if notify {
+ e.receiver.RecvNotify()
+ }
+
+ if addr != nil {
+ *addr = a
+ }
+ return recvLen, msgLen, cms, cmt, nil
+}
+
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return 0, syserr.ErrNotConnected
+ }
+ if to != nil {
+ e.Unlock()
+ return 0, syserr.ErrAlreadyConnected
+ }
+
+ n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ e.Unlock()
+
+ if notify {
+ e.connected.SendNotify()
+ }
+
+ return n, err
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.PasscredOption:
+ e.setPasscred(v != 0)
+ return nil
+ }
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendQueueSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.ReceiveQueueSizeOption(e.receiver.RecvQueuedSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
+
+ case *tcpip.PasscredOption:
+ if e.Passcred() {
+ *o = tcpip.PasscredOption(1)
+ } else {
+ *o = tcpip.PasscredOption(0)
+ }
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.Lock()
+ if e.receiver == nil {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer.
+func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error {
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return syserr.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.receiver.CloseRecv()
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ e.connected.CloseSend()
+ }
+
+ e.Unlock()
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.receiver.CloseNotify()
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ e.connected.CloseNotify()
+ }
+
+ return nil
+}
+
+// GetLocalAddress returns the bound path.
+func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ defer e.Unlock()
+ return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil
+}
+
+// GetRemoteAddress returns the local address of the connected endpoint (if
+// available).
+func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ c := e.connected
+ e.Unlock()
+ if c != nil {
+ return c.GetLocalAddress()
+ }
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Release implements BoundEndpoint.Release.
+func (*baseEndpoint) Release() {
+ // Binding a baseEndpoint doesn't take a reference.
+}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
new file mode 100644
index 000000000..1414be0c6
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix.go
@@ -0,0 +1,650 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package unix provides an implementation of the socket.Socket interface for
+// the AF_UNIX protocol family.
+package unix
+
+import (
+ "strings"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/epsocket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// SocketOperations is a Unix socket. It is similar to an epsocket, except it
+// is backed by a transport.Endpoint instead of a tcpip.Endpoint.
+//
+// +stateify savable
+type SocketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ isPacket bool
+}
+
+// New creates a new unix socket.
+func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
+ dirent := socket.NewDirent(ctx, unixSocketDevice)
+ defer dirent.DecRef()
+ return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
+}
+
+// NewWithDirent creates a new unix socket using an existing dirent.
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
+ return fs.NewFile(ctx, d, flags, &SocketOperations{
+ ep: ep,
+ isPacket: isPacket,
+ })
+}
+
+// DecRef implements RefCounter.DecRef.
+func (s *SocketOperations) DecRef() {
+ s.DecRefWithDestructor(func() {
+ s.ep.Close()
+ })
+}
+
+// Release implemements fs.FileOperations.Release.
+func (s *SocketOperations) Release() {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef()
+}
+
+// Endpoint extracts the transport.Endpoint.
+func (s *SocketOperations) Endpoint() transport.Endpoint {
+ return s.ep
+}
+
+// extractPath extracts and validates the address.
+func extractPath(sockaddr []byte) (string, *syserr.Error) {
+ addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr)
+ if err != nil {
+ return "", err
+ }
+
+ // The address is trimmed by GetAddress.
+ p := string(addr.Addr)
+ if p == "" {
+ // Not allowed.
+ return "", syserr.ErrInvalidArgument
+ }
+ if p[len(p)-1] == '/' {
+ // Weird, they tried to bind '/a/b/c/'?
+ return "", syserr.ErrIsDir
+ }
+
+ return p, nil
+}
+
+// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr, err := s.ep.GetRemoteAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr)
+ return a, l, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr, err := s.ep.GetLocalAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr)
+ return a, l, nil
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return epsocket.Ioctl(ctx, s.ep, io, args)
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) {
+ return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// Listen implements the linux syscall listen(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return s.ep.Listen(backlog)
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns := New(t, ep, s.isPacket)
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ flags := ns.Flags()
+ flags.NonBlocking = true
+ ns.SetFlags(flags.Settable())
+ }
+
+ var addr interface{}
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fdFlags := kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ }
+ fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocket(ns, linux.AF_UNIX)
+
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ // The parent and name.
+ var d *fs.Dirent
+ var name string
+
+ cwd := t.FSContext().WorkingDirectory()
+ defer cwd.DecRef()
+
+ // Is there no slash at all?
+ if !strings.Contains(p, "/") {
+ d = cwd
+ name = p
+ } else {
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+ // Find the last path component, we know that something follows
+ // that final slash, otherwise extractPath() would have failed.
+ lastSlash := strings.LastIndex(p, "/")
+ subPath := p[:lastSlash]
+ if subPath == "" {
+ // Fix up subpath in case file is in root.
+ subPath = "/"
+ }
+ var err error
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, err = t.MountNamespace().FindInode(t, root, cwd, subPath, &remainingTraversals)
+ if err != nil {
+ // No path available.
+ return syserr.ErrNoSuchFile
+ }
+ defer d.DecRef()
+ name = p[lastSlash+1:]
+ }
+
+ // Create the socket.
+ childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
+ if err != nil {
+ return syserr.ErrPortInUse
+ }
+ childDir.DecRef()
+ }
+
+ return nil
+ })
+}
+
+// extractEndpoint retrieves the transport.BoundEndpoint associated with a Unix
+// socket path. The Release must be called on the transport.BoundEndpoint when
+// the caller is done with it.
+func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, *syserr.Error) {
+ path, err := extractPath(sockaddr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Is it abstract?
+ if path[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ ep := t.AbstractSockets().BoundEndpoint(path[1:])
+ if ep == nil {
+ // No socket found.
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ return ep, nil
+ }
+
+ // Find the node in the filesystem.
+ root := t.FSContext().RootDirectory()
+ cwd := t.FSContext().WorkingDirectory()
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals)
+ cwd.DecRef()
+ root.DecRef()
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+
+ // Extract the endpoint if one is there.
+ ep := d.Inode.BoundEndpoint(path)
+ d.DecRef()
+ if ep == nil {
+ // No socket!
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ return ep, nil
+}
+
+// Connect implements the linux syscall connect(2) for unix sockets.
+func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ ep, err := extractEndpoint(t, sockaddr)
+ if err != nil {
+ return err
+ }
+ defer ep.Release()
+
+ // Connect the server endpoint.
+ return s.ep.Connect(ep)
+}
+
+// Writev implements fs.FileOperations.Write.
+func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg([][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ w := EndpointWriter{
+ Endpoint: s.ep,
+ Control: controlMessages.Unix,
+ To: nil,
+ }
+ if len(to) > 0 {
+ ep, err := extractEndpoint(t, to)
+ if err != nil {
+ return 0, err
+ }
+ defer ep.Release()
+ w.To = ep
+ }
+
+ n, err := src.CopyInTo(t, &w)
+ if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ return int(n), syserr.FromError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ n, err = src.CopyInTo(t, &w)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+
+ return int(total), syserr.FromError(err)
+}
+
+// Passcred implements transport.Credentialer.Passcred.
+func (s *SocketOperations) Passcred() bool {
+ return s.ep.Passcred()
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *SocketOperations) ConnectedPasscred() bool {
+ return s.ep.ConnectedPasscred()
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.ep.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.ep.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
+ s.ep.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// Shutdown implements the linux syscall shutdown(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ f, err := epsocket.ConvertShutdown(how)
+ if err != nil {
+ return err
+ }
+
+ // Issue shutdown request.
+ return s.ep.Shutdown(f)
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ trunc := flags&linux.MSG_TRUNC != 0
+ peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
+
+ // Calculate the number of FDs for which we have space and if we are
+ // requesting credentials.
+ var wantCreds bool
+ rightsLen := int(controlDataLen) - syscall.SizeofCmsghdr
+ if s.Passcred() {
+ // Credentials take priority if they are enabled and there is space.
+ wantCreds = rightsLen > 0
+ if !wantCreds {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+ credLen := syscall.CmsgSpace(syscall.SizeofUcred)
+ rightsLen -= credLen
+ }
+ // FDs are 32 bit (4 byte) ints.
+ numRights := rightsLen / 4
+ if numRights < 0 {
+ numRights = 0
+ }
+
+ r := EndpointReader{
+ Endpoint: s.ep,
+ Creds: wantCreds,
+ NumRights: uintptr(numRights),
+ Peek: peek,
+ }
+ if senderRequested {
+ r.From = &tcpip.FullAddress{}
+ }
+ var total int64
+ if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
+ var from interface{}
+ var fromLen uint32
+ if r.From != nil {
+ from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
+ }
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
+ if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ if s.isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+
+ return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
+ total += n
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ var from interface{}
+ var fromLen uint32
+ if r.From != nil {
+ from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
+ }
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
+ if trunc {
+ // n and r.MsgSize are the same for streams.
+ total += int64(r.MsgSize)
+ } else {
+ total += n
+ }
+
+ if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ if total > 0 {
+ err = nil
+ }
+ if s.isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if total > 0 {
+ err = nil
+ }
+ if err == syserror.ETIMEDOUT {
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// provider is a unix domain socket provider.
+type provider struct{}
+
+// Socket returns a new unix domain socket.
+func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ var isPacket bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ isPacket = true
+ ep = transport.NewConnectionless()
+ case linux.SOCK_SEQPACKET:
+ isPacket = true
+ fallthrough
+ case linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return New(t, ep, isPacket), nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ var isPacket bool
+ switch stype {
+ case linux.SOCK_STREAM:
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ isPacket = true
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(stype, t.Kernel())
+ s1 := New(t, ep1, isPacket)
+ s2 := New(t, ep2, isPacket)
+
+ return s1, s2, nil
+}
+
+func init() {
+ socket.RegisterProvider(linux.AF_UNIX, &provider{})
+}
diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go
new file mode 100755
index 000000000..6f8d24b44
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix_state_autogen.go
@@ -0,0 +1,28 @@
+// automatically generated by stateify.
+
+package unix
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *SocketOperations) beforeSave() {}
+func (x *SocketOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("SendReceiveTimeout", &x.SendReceiveTimeout)
+ m.Save("ep", &x.ep)
+ m.Save("isPacket", &x.isPacket)
+}
+
+func (x *SocketOperations) afterLoad() {}
+func (x *SocketOperations) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("SendReceiveTimeout", &x.SendReceiveTimeout)
+ m.Load("ep", &x.ep)
+ m.Load("isPacket", &x.isPacket)
+}
+
+func init() {
+ state.Register("unix.SocketOperations", (*SocketOperations)(nil), state.Fns{Save: (*SocketOperations).save, Load: (*SocketOperations).load})
+}
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
new file mode 100644
index 000000000..27fde505b
--- /dev/null
+++ b/pkg/sentry/state/state.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package state provides high-level state wrappers.
+package state
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/watchdog"
+ "gvisor.googlesource.com/gvisor/pkg/state/statefile"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+var previousMetadata map[string]string
+
+// ErrStateFile is returned when an error is encountered writing the statefile
+// (which may occur during open or close calls in addition to write).
+type ErrStateFile struct {
+ err error
+}
+
+// Error implements error.Error().
+func (e ErrStateFile) Error() string {
+ return fmt.Sprintf("statefile error: %v", e.err)
+}
+
+// SaveOpts contains save-related options.
+type SaveOpts struct {
+ // Destination is the save target.
+ Destination io.Writer
+
+ // Key is used for state integrity check.
+ Key []byte
+
+ // Metadata is save metadata.
+ Metadata map[string]string
+
+ // Callback is called prior to unpause, with any save error.
+ Callback func(err error)
+}
+
+// Save saves the system state.
+func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
+ log.Infof("Sandbox save started, pausing all tasks.")
+ k.Pause()
+ defer k.Unpause()
+ defer log.Infof("Tasks resumed after save.")
+
+ w.Stop()
+ defer w.Start()
+
+ // Supplement the metadata.
+ if opts.Metadata == nil {
+ opts.Metadata = make(map[string]string)
+ }
+ addSaveMetadata(opts.Metadata)
+
+ // Open the statefile.
+ wc, err := statefile.NewWriter(opts.Destination, opts.Key, opts.Metadata)
+ if err != nil {
+ err = ErrStateFile{err}
+ } else {
+ // Save the kernel.
+ err = k.SaveTo(wc)
+
+ // ENOSPC is a state file error. This error can only come from
+ // writing the state file, and not from fs.FileOperations.Fsync
+ // because we wrap those in kernel.TaskSet.flushWritesToFiles.
+ if err == syserror.ENOSPC {
+ err = ErrStateFile{err}
+ }
+
+ if closeErr := wc.Close(); err == nil && closeErr != nil {
+ err = ErrStateFile{closeErr}
+ }
+ }
+ opts.Callback(err)
+ return err
+}
+
+// LoadOpts contains load-related options.
+type LoadOpts struct {
+ // Destination is the load source.
+ Source io.Reader
+
+ // Key is used for state integrity check.
+ Key []byte
+}
+
+// Load loads the given kernel, setting the provided platform and stack.
+func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error {
+ // Open the file.
+ r, m, err := statefile.NewReader(opts.Source, opts.Key)
+ if err != nil {
+ return ErrStateFile{err}
+ }
+
+ previousMetadata = m
+
+ // Restore the Kernel object graph.
+ return k.LoadFrom(r, n)
+}
diff --git a/pkg/sentry/state/state_metadata.go b/pkg/sentry/state/state_metadata.go
new file mode 100644
index 000000000..b8e128c40
--- /dev/null
+++ b/pkg/sentry/state/state_metadata.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+)
+
+// The save metadata keys for timestamp.
+const (
+ cpuUsage = "cpu_usage"
+ metadataTimestamp = "timestamp"
+)
+
+func addSaveMetadata(m map[string]string) {
+ t, err := CPUTime()
+ if err != nil {
+ log.Warningf("Error getting cpu time: %v", err)
+ }
+ if previousMetadata != nil {
+ p, err := time.ParseDuration(previousMetadata[cpuUsage])
+ if err != nil {
+ log.Warningf("Error parsing previous runs' cpu time: %v", err)
+ }
+ t += p
+ }
+ m[cpuUsage] = t.String()
+
+ m[metadataTimestamp] = fmt.Sprintf("%v", time.Now())
+}
diff --git a/pkg/sentry/state/state_state_autogen.go b/pkg/sentry/state/state_state_autogen.go
new file mode 100755
index 000000000..6c0d9b7a7
--- /dev/null
+++ b/pkg/sentry/state/state_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package state
+
diff --git a/pkg/sentry/state/state_unsafe.go b/pkg/sentry/state/state_unsafe.go
new file mode 100644
index 000000000..7745b6ac6
--- /dev/null
+++ b/pkg/sentry/state/state_unsafe.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "fmt"
+ "syscall"
+ "time"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// CPUTime returns the CPU time usage by Sentry and app.
+func CPUTime() (time.Duration, error) {
+ var ts syscall.Timespec
+ _, _, errno := syscall.RawSyscall(syscall.SYS_CLOCK_GETTIME, uintptr(linux.CLOCK_PROCESS_CPUTIME_ID), uintptr(unsafe.Pointer(&ts)), 0)
+ if errno != 0 {
+ return 0, fmt.Errorf("failed calling clock_gettime(CLOCK_PROCESS_CPUTIME_ID): errno=%d", errno)
+ }
+ return time.Duration(ts.Nano()), nil
+}
diff --git a/pkg/sentry/strace/capability.go b/pkg/sentry/strace/capability.go
new file mode 100644
index 000000000..f85d6636e
--- /dev/null
+++ b/pkg/sentry/strace/capability.go
@@ -0,0 +1,176 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// CapabilityBitset is the set of capabilties in a bitset.
+var CapabilityBitset = abi.FlagSet{
+ {
+ Flag: 1 << uint32(linux.CAP_CHOWN),
+ Name: "CAP_CHOWN",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_DAC_OVERRIDE),
+ Name: "CAP_DAC_OVERRIDE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_DAC_READ_SEARCH),
+ Name: "CAP_DAC_READ_SEARCH",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_FOWNER),
+ Name: "CAP_FOWNER",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_FSETID),
+ Name: "CAP_FSETID",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_KILL),
+ Name: "CAP_KILL",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SETGID),
+ Name: "CAP_SETGID",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SETUID),
+ Name: "CAP_SETUID",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SETPCAP),
+ Name: "CAP_SETPCAP",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_LINUX_IMMUTABLE),
+ Name: "CAP_LINUX_IMMUTABLE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_NET_BIND_SERVICE),
+ Name: "CAP_NET_BIND_SERVICE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_NET_BROADCAST),
+ Name: "CAP_NET_BROADCAST",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_NET_ADMIN),
+ Name: "CAP_NET_ADMIN",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_NET_RAW),
+ Name: "CAP_NET_RAW",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_IPC_LOCK),
+ Name: "CAP_IPC_LOCK",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_IPC_OWNER),
+ Name: "CAP_IPC_OWNER",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_MODULE),
+ Name: "CAP_SYS_MODULE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_RAWIO),
+ Name: "CAP_SYS_RAWIO",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_CHROOT),
+ Name: "CAP_SYS_CHROOT",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_PTRACE),
+ Name: "CAP_SYS_PTRACE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_PACCT),
+ Name: "CAP_SYS_PACCT",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_ADMIN),
+ Name: "CAP_SYS_ADMIN",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_BOOT),
+ Name: "CAP_SYS_BOOT",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_NICE),
+ Name: "CAP_SYS_NICE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_RESOURCE),
+ Name: "CAP_SYS_RESOURCE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_TIME),
+ Name: "CAP_SYS_TIME",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYS_TTY_CONFIG),
+ Name: "CAP_SYS_TTY_CONFIG",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_MKNOD),
+ Name: "CAP_MKNOD",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_LEASE),
+ Name: "CAP_LEASE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_AUDIT_WRITE),
+ Name: "CAP_AUDIT_WRITE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_AUDIT_CONTROL),
+ Name: "CAP_AUDIT_CONTROL",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SETFCAP),
+ Name: "CAP_SETFCAP",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_MAC_OVERRIDE),
+ Name: "CAP_MAC_OVERRIDE",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_MAC_ADMIN),
+ Name: "CAP_MAC_ADMIN",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_SYSLOG),
+ Name: "CAP_SYSLOG",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_WAKE_ALARM),
+ Name: "CAP_WAKE_ALARM",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_BLOCK_SUSPEND),
+ Name: "CAP_BLOCK_SUSPEND",
+ },
+ {
+ Flag: 1 << uint32(linux.CAP_AUDIT_READ),
+ Name: "CAP_AUDIT_READ",
+ },
+}
diff --git a/pkg/sentry/strace/clone.go b/pkg/sentry/strace/clone.go
new file mode 100644
index 000000000..ff6a432c6
--- /dev/null
+++ b/pkg/sentry/strace/clone.go
@@ -0,0 +1,113 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+)
+
+// CloneFlagSet is the set of clone(2) flags.
+var CloneFlagSet = abi.FlagSet{
+ {
+ Flag: syscall.CLONE_VM,
+ Name: "CLONE_VM",
+ },
+ {
+ Flag: syscall.CLONE_FS,
+ Name: "CLONE_FS",
+ },
+ {
+ Flag: syscall.CLONE_FILES,
+ Name: "CLONE_FILES",
+ },
+ {
+ Flag: syscall.CLONE_SIGHAND,
+ Name: "CLONE_SIGHAND",
+ },
+ {
+ Flag: syscall.CLONE_PTRACE,
+ Name: "CLONE_PTRACE",
+ },
+ {
+ Flag: syscall.CLONE_VFORK,
+ Name: "CLONE_VFORK",
+ },
+ {
+ Flag: syscall.CLONE_PARENT,
+ Name: "CLONE_PARENT",
+ },
+ {
+ Flag: syscall.CLONE_THREAD,
+ Name: "CLONE_THREAD",
+ },
+ {
+ Flag: syscall.CLONE_NEWNS,
+ Name: "CLONE_NEWNS",
+ },
+ {
+ Flag: syscall.CLONE_SYSVSEM,
+ Name: "CLONE_SYSVSEM",
+ },
+ {
+ Flag: syscall.CLONE_SETTLS,
+ Name: "CLONE_SETTLS",
+ },
+ {
+ Flag: syscall.CLONE_PARENT_SETTID,
+ Name: "CLONE_PARENT_SETTID",
+ },
+ {
+ Flag: syscall.CLONE_CHILD_CLEARTID,
+ Name: "CLONE_CHILD_CLEARTID",
+ },
+ {
+ Flag: syscall.CLONE_DETACHED,
+ Name: "CLONE_DETACHED",
+ },
+ {
+ Flag: syscall.CLONE_UNTRACED,
+ Name: "CLONE_UNTRACED",
+ },
+ {
+ Flag: syscall.CLONE_CHILD_SETTID,
+ Name: "CLONE_CHILD_SETTID",
+ },
+ {
+ Flag: syscall.CLONE_NEWUTS,
+ Name: "CLONE_NEWUTS",
+ },
+ {
+ Flag: syscall.CLONE_NEWIPC,
+ Name: "CLONE_NEWIPC",
+ },
+ {
+ Flag: syscall.CLONE_NEWUSER,
+ Name: "CLONE_NEWUSER",
+ },
+ {
+ Flag: syscall.CLONE_NEWPID,
+ Name: "CLONE_NEWPID",
+ },
+ {
+ Flag: syscall.CLONE_NEWNET,
+ Name: "CLONE_NEWNET",
+ },
+ {
+ Flag: syscall.CLONE_IO,
+ Name: "CLONE_IO",
+ },
+}
diff --git a/pkg/sentry/strace/futex.go b/pkg/sentry/strace/futex.go
new file mode 100644
index 000000000..24301bda6
--- /dev/null
+++ b/pkg/sentry/strace/futex.go
@@ -0,0 +1,52 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// FutexCmd are the possible futex(2) commands.
+var FutexCmd = abi.ValueSet{
+ linux.FUTEX_WAIT: "FUTEX_WAIT",
+ linux.FUTEX_WAKE: "FUTEX_WAKE",
+ linux.FUTEX_FD: "FUTEX_FD",
+ linux.FUTEX_REQUEUE: "FUTEX_REQUEUE",
+ linux.FUTEX_CMP_REQUEUE: "FUTEX_CMP_REQUEUE",
+ linux.FUTEX_WAKE_OP: "FUTEX_WAKE_OP",
+ linux.FUTEX_LOCK_PI: "FUTEX_LOCK_PI",
+ linux.FUTEX_UNLOCK_PI: "FUTEX_UNLOCK_PI",
+ linux.FUTEX_TRYLOCK_PI: "FUTEX_TRYLOCK_PI",
+ linux.FUTEX_WAIT_BITSET: "FUTEX_WAIT_BITSET",
+ linux.FUTEX_WAKE_BITSET: "FUTEX_WAKE_BITSET",
+ linux.FUTEX_WAIT_REQUEUE_PI: "FUTEX_WAIT_REQUEUE_PI",
+ linux.FUTEX_CMP_REQUEUE_PI: "FUTEX_CMP_REQUEUE_PI",
+}
+
+func futex(op uint64) string {
+ cmd := op &^ (linux.FUTEX_PRIVATE_FLAG | linux.FUTEX_CLOCK_REALTIME)
+ clockRealtime := (op & linux.FUTEX_CLOCK_REALTIME) == linux.FUTEX_CLOCK_REALTIME
+ private := (op & linux.FUTEX_PRIVATE_FLAG) == linux.FUTEX_PRIVATE_FLAG
+
+ s := FutexCmd.Parse(cmd)
+ if clockRealtime {
+ s += "|FUTEX_CLOCK_REALTIME"
+ }
+ if private {
+ s += "|FUTEX_PRIVATE_FLAG"
+ }
+ return s
+}
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go
new file mode 100644
index 000000000..3650fd6e1
--- /dev/null
+++ b/pkg/sentry/strace/linux64.go
@@ -0,0 +1,338 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+// linuxAMD64 provides a mapping of the Linux amd64 syscalls and their argument
+// types for display / formatting.
+var linuxAMD64 = SyscallMap{
+ 0: makeSyscallInfo("read", FD, ReadBuffer, Hex),
+ 1: makeSyscallInfo("write", FD, WriteBuffer, Hex),
+ 2: makeSyscallInfo("open", Path, OpenFlags, Mode),
+ 3: makeSyscallInfo("close", FD),
+ 4: makeSyscallInfo("stat", Path, Stat),
+ 5: makeSyscallInfo("fstat", FD, Stat),
+ 6: makeSyscallInfo("lstat", Path, Stat),
+ 7: makeSyscallInfo("poll", PollFDs, Hex, Hex),
+ 8: makeSyscallInfo("lseek", Hex, Hex, Hex),
+ 9: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 10: makeSyscallInfo("mprotect", Hex, Hex, Hex),
+ 11: makeSyscallInfo("munmap", Hex, Hex),
+ 12: makeSyscallInfo("brk", Hex),
+ 13: makeSyscallInfo("rt_sigaction", Signal, SigAction, PostSigAction),
+ 14: makeSyscallInfo("rt_sigprocmask", SignalMaskAction, SigSet, PostSigSet, Hex),
+ 15: makeSyscallInfo("rt_sigreturn"),
+ 16: makeSyscallInfo("ioctl", FD, Hex, Hex),
+ 17: makeSyscallInfo("pread64", FD, ReadBuffer, Hex, Hex),
+ 18: makeSyscallInfo("pwrite64", FD, WriteBuffer, Hex, Hex),
+ 19: makeSyscallInfo("readv", FD, ReadIOVec, Hex),
+ 20: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
+ 21: makeSyscallInfo("access", Path, Oct),
+ 22: makeSyscallInfo("pipe", PipeFDs),
+ 23: makeSyscallInfo("select", Hex, Hex, Hex, Hex, Timeval),
+ 24: makeSyscallInfo("sched_yield"),
+ 25: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
+ 26: makeSyscallInfo("msync", Hex, Hex, Hex),
+ 27: makeSyscallInfo("mincore", Hex, Hex, Hex),
+ 28: makeSyscallInfo("madvise", Hex, Hex, Hex),
+ 29: makeSyscallInfo("shmget", Hex, Hex, Hex),
+ 30: makeSyscallInfo("shmat", Hex, Hex, Hex),
+ 31: makeSyscallInfo("shmctl", Hex, Hex, Hex),
+ 32: makeSyscallInfo("dup", FD),
+ 33: makeSyscallInfo("dup2", FD, FD),
+ 34: makeSyscallInfo("pause"),
+ 35: makeSyscallInfo("nanosleep", Timespec, PostTimespec),
+ 36: makeSyscallInfo("getitimer", ItimerType, PostItimerVal),
+ 37: makeSyscallInfo("alarm", Hex),
+ 38: makeSyscallInfo("setitimer", ItimerType, ItimerVal, PostItimerVal),
+ 39: makeSyscallInfo("getpid"),
+ 40: makeSyscallInfo("sendfile", FD, FD, Hex, Hex),
+ 41: makeSyscallInfo("socket", SockFamily, SockType, SockProtocol),
+ 42: makeSyscallInfo("connect", FD, SockAddr, Hex),
+ 43: makeSyscallInfo("accept", FD, PostSockAddr, SockLen),
+ 44: makeSyscallInfo("sendto", FD, Hex, Hex, Hex, SockAddr, Hex),
+ 45: makeSyscallInfo("recvfrom", FD, Hex, Hex, Hex, PostSockAddr, SockLen),
+ 46: makeSyscallInfo("sendmsg", FD, SendMsgHdr, Hex),
+ 47: makeSyscallInfo("recvmsg", FD, RecvMsgHdr, Hex),
+ 48: makeSyscallInfo("shutdown", FD, Hex),
+ 49: makeSyscallInfo("bind", FD, SockAddr, Hex),
+ 50: makeSyscallInfo("listen", FD, Hex),
+ 51: makeSyscallInfo("getsockname", FD, PostSockAddr, SockLen),
+ 52: makeSyscallInfo("getpeername", FD, PostSockAddr, SockLen),
+ 53: makeSyscallInfo("socketpair", SockFamily, SockType, SockProtocol, Hex),
+ 54: makeSyscallInfo("setsockopt", FD, Hex, Hex, Hex, Hex),
+ 55: makeSyscallInfo("getsockopt", FD, Hex, Hex, Hex, Hex),
+ 56: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
+ 57: makeSyscallInfo("fork"),
+ 58: makeSyscallInfo("vfork"),
+ 59: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector),
+ 60: makeSyscallInfo("exit", Hex),
+ 61: makeSyscallInfo("wait4", Hex, Hex, Hex, Rusage),
+ 62: makeSyscallInfo("kill", Hex, Signal),
+ 63: makeSyscallInfo("uname", Uname),
+ 64: makeSyscallInfo("semget", Hex, Hex, Hex),
+ 65: makeSyscallInfo("semop", Hex, Hex, Hex),
+ 66: makeSyscallInfo("semctl", Hex, Hex, Hex, Hex),
+ 67: makeSyscallInfo("shmdt", Hex),
+ 68: makeSyscallInfo("msgget", Hex, Hex),
+ 69: makeSyscallInfo("msgsnd", Hex, Hex, Hex, Hex),
+ 70: makeSyscallInfo("msgrcv", Hex, Hex, Hex, Hex, Hex),
+ 71: makeSyscallInfo("msgctl", Hex, Hex, Hex),
+ 72: makeSyscallInfo("fcntl", FD, Hex, Hex),
+ 73: makeSyscallInfo("flock", FD, Hex),
+ 74: makeSyscallInfo("fsync", FD),
+ 75: makeSyscallInfo("fdatasync", FD),
+ 76: makeSyscallInfo("truncate", Path, Hex),
+ 77: makeSyscallInfo("ftruncate", FD, Hex),
+ 78: makeSyscallInfo("getdents", FD, Hex, Hex),
+ 79: makeSyscallInfo("getcwd", PostPath, Hex),
+ 80: makeSyscallInfo("chdir", Path),
+ 81: makeSyscallInfo("fchdir", FD),
+ 82: makeSyscallInfo("rename", Path, Path),
+ 83: makeSyscallInfo("mkdir", Path, Oct),
+ 84: makeSyscallInfo("rmdir", Path),
+ 85: makeSyscallInfo("creat", Path, Oct),
+ 86: makeSyscallInfo("link", Path, Path),
+ 87: makeSyscallInfo("unlink", Path),
+ 88: makeSyscallInfo("symlink", Path, Path),
+ 89: makeSyscallInfo("readlink", Path, ReadBuffer, Hex),
+ 90: makeSyscallInfo("chmod", Path, Mode),
+ 91: makeSyscallInfo("fchmod", FD, Mode),
+ 92: makeSyscallInfo("chown", Path, Hex, Hex),
+ 93: makeSyscallInfo("fchown", FD, Hex, Hex),
+ 94: makeSyscallInfo("lchown", Path, Hex, Hex),
+ 95: makeSyscallInfo("umask", Hex),
+ 96: makeSyscallInfo("gettimeofday", Timeval, Hex),
+ 97: makeSyscallInfo("getrlimit", Hex, Hex),
+ 98: makeSyscallInfo("getrusage", Hex, Rusage),
+ 99: makeSyscallInfo("sysinfo", Hex),
+ 100: makeSyscallInfo("times", Hex),
+ 101: makeSyscallInfo("ptrace", PtraceRequest, Hex, Hex, Hex),
+ 102: makeSyscallInfo("getuid"),
+ 103: makeSyscallInfo("syslog", Hex, Hex, Hex),
+ 104: makeSyscallInfo("getgid"),
+ 105: makeSyscallInfo("setuid", Hex),
+ 106: makeSyscallInfo("setgid", Hex),
+ 107: makeSyscallInfo("geteuid"),
+ 108: makeSyscallInfo("getegid"),
+ 109: makeSyscallInfo("setpgid", Hex, Hex),
+ 110: makeSyscallInfo("getppid"),
+ 111: makeSyscallInfo("getpgrp"),
+ 112: makeSyscallInfo("setsid"),
+ 113: makeSyscallInfo("setreuid", Hex, Hex),
+ 114: makeSyscallInfo("setregid", Hex, Hex),
+ 115: makeSyscallInfo("getgroups", Hex, Hex),
+ 116: makeSyscallInfo("setgroups", Hex, Hex),
+ 117: makeSyscallInfo("setresuid", Hex, Hex, Hex),
+ 118: makeSyscallInfo("getresuid", Hex, Hex, Hex),
+ 119: makeSyscallInfo("setresgid", Hex, Hex, Hex),
+ 120: makeSyscallInfo("getresgid", Hex, Hex, Hex),
+ 121: makeSyscallInfo("getpgid", Hex),
+ 122: makeSyscallInfo("setfsuid", Hex),
+ 123: makeSyscallInfo("setfsgid", Hex),
+ 124: makeSyscallInfo("getsid", Hex),
+ 125: makeSyscallInfo("capget", CapHeader, PostCapData),
+ 126: makeSyscallInfo("capset", CapHeader, CapData),
+ 127: makeSyscallInfo("rt_sigpending", Hex),
+ 128: makeSyscallInfo("rt_sigtimedwait", SigSet, Hex, Timespec, Hex),
+ 129: makeSyscallInfo("rt_sigqueueinfo", Hex, Signal, Hex),
+ 130: makeSyscallInfo("rt_sigsuspend", Hex),
+ 131: makeSyscallInfo("sigaltstack", Hex, Hex),
+ 132: makeSyscallInfo("utime", Path, Utimbuf),
+ 133: makeSyscallInfo("mknod", Path, Mode, Hex),
+ 134: makeSyscallInfo("uselib", Hex),
+ 135: makeSyscallInfo("personality", Hex),
+ 136: makeSyscallInfo("ustat", Hex, Hex),
+ 137: makeSyscallInfo("statfs", Path, Hex),
+ 138: makeSyscallInfo("fstatfs", FD, Hex),
+ 139: makeSyscallInfo("sysfs", Hex, Hex, Hex),
+ 140: makeSyscallInfo("getpriority", Hex, Hex),
+ 141: makeSyscallInfo("setpriority", Hex, Hex, Hex),
+ 142: makeSyscallInfo("sched_setparam", Hex, Hex),
+ 143: makeSyscallInfo("sched_getparam", Hex, Hex),
+ 144: makeSyscallInfo("sched_setscheduler", Hex, Hex, Hex),
+ 145: makeSyscallInfo("sched_getscheduler", Hex),
+ 146: makeSyscallInfo("sched_get_priority_max", Hex),
+ 147: makeSyscallInfo("sched_get_priority_min", Hex),
+ 148: makeSyscallInfo("sched_rr_get_interval", Hex, Hex),
+ 149: makeSyscallInfo("mlock", Hex, Hex),
+ 150: makeSyscallInfo("munlock", Hex, Hex),
+ 151: makeSyscallInfo("mlockall", Hex),
+ 152: makeSyscallInfo("munlockall"),
+ 153: makeSyscallInfo("vhangup"),
+ 154: makeSyscallInfo("modify_ldt", Hex, Hex, Hex),
+ 155: makeSyscallInfo("pivot_root", Path, Path),
+ 156: makeSyscallInfo("_sysctl", Hex),
+ 157: makeSyscallInfo("prctl", Hex, Hex, Hex, Hex, Hex),
+ 158: makeSyscallInfo("arch_prctl", Hex, Hex),
+ 159: makeSyscallInfo("adjtimex", Hex),
+ 160: makeSyscallInfo("setrlimit", Hex, Hex),
+ 161: makeSyscallInfo("chroot", Path),
+ 162: makeSyscallInfo("sync"),
+ 163: makeSyscallInfo("acct", Hex),
+ 164: makeSyscallInfo("settimeofday", Timeval, Hex),
+ 165: makeSyscallInfo("mount", Path, Path, Path, Hex, Path),
+ 166: makeSyscallInfo("umount2", Path, Hex),
+ 167: makeSyscallInfo("swapon", Hex, Hex),
+ 168: makeSyscallInfo("swapoff", Hex),
+ 169: makeSyscallInfo("reboot", Hex, Hex, Hex, Hex),
+ 170: makeSyscallInfo("sethostname", Hex, Hex),
+ 171: makeSyscallInfo("setdomainname", Hex, Hex),
+ 172: makeSyscallInfo("iopl", Hex),
+ 173: makeSyscallInfo("ioperm", Hex, Hex, Hex),
+ 174: makeSyscallInfo("create_module", Path, Hex),
+ 175: makeSyscallInfo("init_module", Hex, Hex, Hex),
+ 176: makeSyscallInfo("delete_module", Hex, Hex),
+ 177: makeSyscallInfo("get_kernel_syms", Hex),
+ // 178: query_module (only present in Linux < 2.6)
+ 179: makeSyscallInfo("quotactl", Hex, Hex, Hex, Hex),
+ 180: makeSyscallInfo("nfsservctl", Hex, Hex, Hex),
+ // 181: getpmsg (not implemented in the Linux kernel)
+ // 182: putpmsg (not implemented in the Linux kernel)
+ // 183: afs_syscall (not implemented in the Linux kernel)
+ // 184: tuxcall (not implemented in the Linux kernel)
+ // 185: security (not implemented in the Linux kernel)
+ 186: makeSyscallInfo("gettid"),
+ 187: makeSyscallInfo("readahead", Hex, Hex, Hex),
+ 188: makeSyscallInfo("setxattr", Path, Path, Hex, Hex, Hex),
+ 189: makeSyscallInfo("lsetxattr", Path, Path, Hex, Hex, Hex),
+ 190: makeSyscallInfo("fsetxattr", FD, Path, Hex, Hex, Hex),
+ 191: makeSyscallInfo("getxattr", Path, Path, Hex, Hex),
+ 192: makeSyscallInfo("lgetxattr", Path, Path, Hex, Hex),
+ 193: makeSyscallInfo("fgetxattr", FD, Path, Hex, Hex),
+ 194: makeSyscallInfo("listxattr", Path, Path, Hex),
+ 195: makeSyscallInfo("llistxattr", Path, Path, Hex),
+ 196: makeSyscallInfo("flistxattr", FD, Path, Hex),
+ 197: makeSyscallInfo("removexattr", Path, Path),
+ 198: makeSyscallInfo("lremovexattr", Path, Path),
+ 199: makeSyscallInfo("fremovexattr", FD, Path),
+ 200: makeSyscallInfo("tkill", Hex, Signal),
+ 201: makeSyscallInfo("time", Hex),
+ 202: makeSyscallInfo("futex", Hex, FutexOp, Hex, Timespec, Hex, Hex),
+ 203: makeSyscallInfo("sched_setaffinity", Hex, Hex, Hex),
+ 204: makeSyscallInfo("sched_getaffinity", Hex, Hex, Hex),
+ 205: makeSyscallInfo("set_thread_area", Hex),
+ 206: makeSyscallInfo("io_setup", Hex, Hex),
+ 207: makeSyscallInfo("io_destroy", Hex),
+ 208: makeSyscallInfo("io_getevents", Hex, Hex, Hex, Hex, Timespec),
+ 209: makeSyscallInfo("io_submit", Hex, Hex, Hex),
+ 210: makeSyscallInfo("io_cancel", Hex, Hex, Hex),
+ 211: makeSyscallInfo("get_thread_area", Hex),
+ 212: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
+ 213: makeSyscallInfo("epoll_create", Hex),
+ // 214: epoll_ctl_old (not implemented in the Linux kernel)
+ // 215: epoll_wait_old (not implemented in the Linux kernel)
+ 216: makeSyscallInfo("remap_file_pages", Hex, Hex, Hex, Hex, Hex),
+ 217: makeSyscallInfo("getdents64", FD, Hex, Hex),
+ 218: makeSyscallInfo("set_tid_address", Hex),
+ 219: makeSyscallInfo("restart_syscall"),
+ 220: makeSyscallInfo("semtimedop", Hex, Hex, Hex, Hex),
+ 221: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex),
+ 222: makeSyscallInfo("timer_create", Hex, Hex, Hex),
+ 223: makeSyscallInfo("timer_settime", Hex, Hex, ItimerSpec, PostItimerSpec),
+ 224: makeSyscallInfo("timer_gettime", Hex, PostItimerSpec),
+ 225: makeSyscallInfo("timer_getoverrun", Hex),
+ 226: makeSyscallInfo("timer_delete", Hex),
+ 227: makeSyscallInfo("clock_settime", Hex, Timespec),
+ 228: makeSyscallInfo("clock_gettime", Hex, PostTimespec),
+ 229: makeSyscallInfo("clock_getres", Hex, PostTimespec),
+ 230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
+ 231: makeSyscallInfo("exit_group", Hex),
+ 232: makeSyscallInfo("epoll_wait", Hex, Hex, Hex, Hex),
+ 233: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 234: makeSyscallInfo("tgkill", Hex, Hex, Signal),
+ 235: makeSyscallInfo("utimes", Path, Timeval),
+ // 236: vserver (not implemented in the Linux kernel)
+ 237: makeSyscallInfo("mbind", Hex, Hex, Hex, Hex, Hex, Hex),
+ 238: makeSyscallInfo("set_mempolicy", Hex, Hex, Hex),
+ 239: makeSyscallInfo("get_mempolicy", Hex, Hex, Hex, Hex, Hex),
+ 240: makeSyscallInfo("mq_open", Hex, Hex, Hex, Hex),
+ 241: makeSyscallInfo("mq_unlink", Hex),
+ 242: makeSyscallInfo("mq_timedsend", Hex, Hex, Hex, Hex, Hex),
+ 243: makeSyscallInfo("mq_timedreceive", Hex, Hex, Hex, Hex, Hex),
+ 244: makeSyscallInfo("mq_notify", Hex, Hex),
+ 245: makeSyscallInfo("mq_getsetattr", Hex, Hex, Hex),
+ 246: makeSyscallInfo("kexec_load", Hex, Hex, Hex, Hex),
+ 247: makeSyscallInfo("waitid", Hex, Hex, Hex, Hex, Rusage),
+ 248: makeSyscallInfo("add_key", Hex, Hex, Hex, Hex, Hex),
+ 249: makeSyscallInfo("request_key", Hex, Hex, Hex, Hex),
+ 250: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex),
+ 251: makeSyscallInfo("ioprio_set", Hex, Hex, Hex),
+ 252: makeSyscallInfo("ioprio_get", Hex, Hex),
+ 253: makeSyscallInfo("inotify_init"),
+ 254: makeSyscallInfo("inotify_add_watch", Hex, Path, Hex),
+ 255: makeSyscallInfo("inotify_rm_watch", Hex, Hex),
+ 256: makeSyscallInfo("migrate_pages", Hex, Hex, Hex, Hex),
+ 257: makeSyscallInfo("openat", FD, Path, OpenFlags, Mode),
+ 258: makeSyscallInfo("mkdirat", FD, Path, Hex),
+ 259: makeSyscallInfo("mknodat", FD, Path, Mode, Hex),
+ 260: makeSyscallInfo("fchownat", FD, Path, Hex, Hex, Hex),
+ 261: makeSyscallInfo("futimesat", FD, Path, Hex),
+ 262: makeSyscallInfo("newfstatat", FD, Path, Stat, Hex),
+ 263: makeSyscallInfo("unlinkat", FD, Path, Hex),
+ 264: makeSyscallInfo("renameat", FD, Path, Hex, Path),
+ 265: makeSyscallInfo("linkat", FD, Path, Hex, Path, Hex),
+ 266: makeSyscallInfo("symlinkat", Path, Hex, Path),
+ 267: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
+ 268: makeSyscallInfo("fchmodat", FD, Path, Mode),
+ 269: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
+ 270: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex),
+ 271: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
+ 272: makeSyscallInfo("unshare", CloneFlags),
+ 273: makeSyscallInfo("set_robust_list", Hex, Hex),
+ 274: makeSyscallInfo("get_robust_list", Hex, Hex, Hex),
+ 275: makeSyscallInfo("splice", FD, Hex, FD, Hex, Hex, Hex),
+ 276: makeSyscallInfo("tee", FD, FD, Hex, Hex),
+ 277: makeSyscallInfo("sync_file_range", FD, Hex, Hex, Hex),
+ 278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
+ 279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
+ 280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
+ 281: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 282: makeSyscallInfo("signalfd", Hex, Hex, Hex),
+ 283: makeSyscallInfo("timerfd_create", Hex, Hex),
+ 284: makeSyscallInfo("eventfd", Hex),
+ 285: makeSyscallInfo("fallocate", FD, Hex, Hex, Hex),
+ 286: makeSyscallInfo("timerfd_settime", FD, Hex, ItimerSpec, PostItimerSpec),
+ 287: makeSyscallInfo("timerfd_gettime", FD, PostItimerSpec),
+ 288: makeSyscallInfo("accept4", FD, PostSockAddr, SockLen, SockFlags),
+ 289: makeSyscallInfo("signalfd4", Hex, Hex, Hex, Hex),
+ 290: makeSyscallInfo("eventfd2", Hex, Hex),
+ 291: makeSyscallInfo("epoll_create1", Hex),
+ 292: makeSyscallInfo("dup3", FD, FD, Hex),
+ 293: makeSyscallInfo("pipe2", PipeFDs, Hex),
+ 294: makeSyscallInfo("inotify_init1", Hex),
+ 295: makeSyscallInfo("preadv", FD, ReadIOVec, Hex, Hex),
+ 296: makeSyscallInfo("pwritev", FD, WriteIOVec, Hex, Hex),
+ 297: makeSyscallInfo("rt_tgsigqueueinfo", Hex, Hex, Signal, Hex),
+ 298: makeSyscallInfo("perf_event_open", Hex, Hex, Hex, Hex, Hex),
+ 299: makeSyscallInfo("recvmmsg", FD, Hex, Hex, Hex, Hex),
+ 300: makeSyscallInfo("fanotify_init", Hex, Hex),
+ 301: makeSyscallInfo("fanotify_mark", Hex, Hex, Hex, Hex, Hex),
+ 302: makeSyscallInfo("prlimit64", Hex, Hex, Hex, Hex),
+ 303: makeSyscallInfo("name_to_handle_at", FD, Hex, Hex, Hex, Hex),
+ 304: makeSyscallInfo("open_by_handle_at", FD, Hex, Hex),
+ 305: makeSyscallInfo("clock_adjtime", Hex, Hex),
+ 306: makeSyscallInfo("syncfs", FD),
+ 307: makeSyscallInfo("sendmmsg", FD, Hex, Hex, Hex),
+ 308: makeSyscallInfo("setns", FD, Hex),
+ 309: makeSyscallInfo("getcpu", Hex, Hex, Hex),
+ 310: makeSyscallInfo("process_vm_readv", Hex, ReadIOVec, Hex, IOVec, Hex, Hex),
+ 311: makeSyscallInfo("process_vm_writev", Hex, IOVec, Hex, WriteIOVec, Hex, Hex),
+ 312: makeSyscallInfo("kcmp", Hex, Hex, Hex, Hex, Hex),
+ 313: makeSyscallInfo("finit_module", Hex, Hex, Hex),
+ 314: makeSyscallInfo("sched_setattr", Hex, Hex, Hex),
+ 315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
+ 316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
+ 317: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+}
diff --git a/pkg/sentry/strace/open.go b/pkg/sentry/strace/open.go
new file mode 100644
index 000000000..140727b02
--- /dev/null
+++ b/pkg/sentry/strace/open.go
@@ -0,0 +1,96 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+)
+
+// OpenMode represents the mode to open(2) a file.
+var OpenMode = abi.ValueSet{
+ syscall.O_RDWR: "O_RDWR",
+ syscall.O_WRONLY: "O_WRONLY",
+ syscall.O_RDONLY: "O_RDONLY",
+}
+
+// OpenFlagSet is the set of open(2) flags.
+var OpenFlagSet = abi.FlagSet{
+ {
+ Flag: syscall.O_APPEND,
+ Name: "O_APPEND",
+ },
+ {
+ Flag: syscall.O_ASYNC,
+ Name: "O_ASYNC",
+ },
+ {
+ Flag: syscall.O_CLOEXEC,
+ Name: "O_CLOEXEC",
+ },
+ {
+ Flag: syscall.O_CREAT,
+ Name: "O_CREAT",
+ },
+ {
+ Flag: syscall.O_DIRECT,
+ Name: "O_DIRECT",
+ },
+ {
+ Flag: syscall.O_DIRECTORY,
+ Name: "O_DIRECTORY",
+ },
+ {
+ Flag: syscall.O_EXCL,
+ Name: "O_EXCL",
+ },
+ {
+ Flag: syscall.O_NOATIME,
+ Name: "O_NOATIME",
+ },
+ {
+ Flag: syscall.O_NOCTTY,
+ Name: "O_NOCTTY",
+ },
+ {
+ Flag: syscall.O_NOFOLLOW,
+ Name: "O_NOFOLLOW",
+ },
+ {
+ Flag: syscall.O_NONBLOCK,
+ Name: "O_NONBLOCK",
+ },
+ {
+ Flag: 0x200000, // O_PATH
+ Name: "O_PATH",
+ },
+ {
+ Flag: syscall.O_SYNC,
+ Name: "O_SYNC",
+ },
+ {
+ Flag: syscall.O_TRUNC,
+ Name: "O_TRUNC",
+ },
+}
+
+func open(val uint64) string {
+ s := OpenMode.Parse(val & syscall.O_ACCMODE)
+ if flags := OpenFlagSet.Parse(val &^ syscall.O_ACCMODE); flags != "" {
+ s += "|" + flags
+ }
+ return s
+}
diff --git a/pkg/sentry/strace/poll.go b/pkg/sentry/strace/poll.go
new file mode 100644
index 000000000..15605187d
--- /dev/null
+++ b/pkg/sentry/strace/poll.go
@@ -0,0 +1,72 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ slinux "gvisor.googlesource.com/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// PollEventSet is the set of poll(2) event flags.
+var PollEventSet = abi.FlagSet{
+ {Flag: linux.POLLIN, Name: "POLLIN"},
+ {Flag: linux.POLLPRI, Name: "POLLPRI"},
+ {Flag: linux.POLLOUT, Name: "POLLOUT"},
+ {Flag: linux.POLLERR, Name: "POLLERR"},
+ {Flag: linux.POLLHUP, Name: "POLLHUP"},
+ {Flag: linux.POLLNVAL, Name: "POLLNVAL"},
+ {Flag: linux.POLLRDNORM, Name: "POLLRDNORM"},
+ {Flag: linux.POLLRDBAND, Name: "POLLRDBAND"},
+ {Flag: linux.POLLWRNORM, Name: "POLLWRNORM"},
+ {Flag: linux.POLLWRBAND, Name: "POLLWRBAND"},
+ {Flag: linux.POLLMSG, Name: "POLLMSG"},
+ {Flag: linux.POLLREMOVE, Name: "POLLREMOVE"},
+ {Flag: linux.POLLRDHUP, Name: "POLLRDHUP"},
+ {Flag: linux.POLLFREE, Name: "POLLFREE"},
+ {Flag: linux.POLL_BUSY_LOOP, Name: "POLL_BUSY_LOOP"},
+}
+
+func pollFD(t *kernel.Task, pfd *linux.PollFD, post bool) string {
+ revents := "..."
+ if post {
+ revents = PollEventSet.Parse(uint64(pfd.REvents))
+ }
+ return fmt.Sprintf("{FD: %s, Events: %s, REvents: %s}", fd(t, kdefs.FD(pfd.FD)), PollEventSet.Parse(uint64(pfd.Events)), revents)
+}
+
+func pollFDs(t *kernel.Task, addr usermem.Addr, nfds uint, post bool) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ pfds, err := slinux.CopyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding pollfds: %s)", addr, err)
+ }
+
+ s := make([]string, 0, len(pfds))
+ for i := range pfds {
+ s = append(s, pollFD(t, &pfds[i], post))
+ }
+
+ return fmt.Sprintf("%#x [%s]", addr, strings.Join(s, ", "))
+}
diff --git a/pkg/sentry/strace/ptrace.go b/pkg/sentry/strace/ptrace.go
new file mode 100644
index 000000000..485aacb8a
--- /dev/null
+++ b/pkg/sentry/strace/ptrace.go
@@ -0,0 +1,62 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// PtraceRequestSet are the possible ptrace(2) requests.
+var PtraceRequestSet = abi.ValueSet{
+ linux.PTRACE_TRACEME: "PTRACE_TRACEME",
+ linux.PTRACE_PEEKTEXT: "PTRACE_PEEKTEXT",
+ linux.PTRACE_PEEKDATA: "PTRACE_PEEKDATA",
+ linux.PTRACE_PEEKUSR: "PTRACE_PEEKUSR",
+ linux.PTRACE_POKETEXT: "PTRACE_POKETEXT",
+ linux.PTRACE_POKEDATA: "PTRACE_POKEDATA",
+ linux.PTRACE_POKEUSR: "PTRACE_POKEUSR",
+ linux.PTRACE_CONT: "PTRACE_CONT",
+ linux.PTRACE_KILL: "PTRACE_KILL",
+ linux.PTRACE_SINGLESTEP: "PTRACE_SINGLESTEP",
+ linux.PTRACE_ATTACH: "PTRACE_ATTACH",
+ linux.PTRACE_DETACH: "PTRACE_DETACH",
+ linux.PTRACE_SYSCALL: "PTRACE_SYSCALL",
+ linux.PTRACE_SETOPTIONS: "PTRACE_SETOPTIONS",
+ linux.PTRACE_GETEVENTMSG: "PTRACE_GETEVENTMSG",
+ linux.PTRACE_GETSIGINFO: "PTRACE_GETSIGINFO",
+ linux.PTRACE_SETSIGINFO: "PTRACE_SETSIGINFO",
+ linux.PTRACE_GETREGSET: "PTRACE_GETREGSET",
+ linux.PTRACE_SETREGSET: "PTRACE_SETREGSET",
+ linux.PTRACE_SEIZE: "PTRACE_SEIZE",
+ linux.PTRACE_INTERRUPT: "PTRACE_INTERRUPT",
+ linux.PTRACE_LISTEN: "PTRACE_LISTEN",
+ linux.PTRACE_PEEKSIGINFO: "PTRACE_PEEKSIGINFO",
+ linux.PTRACE_GETSIGMASK: "PTRACE_GETSIGMASK",
+ linux.PTRACE_SETSIGMASK: "PTRACE_SETSIGMASK",
+ linux.PTRACE_GETREGS: "PTRACE_GETREGS",
+ linux.PTRACE_SETREGS: "PTRACE_SETREGS",
+ linux.PTRACE_GETFPREGS: "PTRACE_GETFPREGS",
+ linux.PTRACE_SETFPREGS: "PTRACE_SETFPREGS",
+ linux.PTRACE_GETFPXREGS: "PTRACE_GETFPXREGS",
+ linux.PTRACE_SETFPXREGS: "PTRACE_SETFPXREGS",
+ linux.PTRACE_OLDSETOPTIONS: "PTRACE_OLDSETOPTIONS",
+ linux.PTRACE_GET_THREAD_AREA: "PTRACE_GET_THREAD_AREA",
+ linux.PTRACE_SET_THREAD_AREA: "PTRACE_SET_THREAD_AREA",
+ linux.PTRACE_ARCH_PRCTL: "PTRACE_ARCH_PRCTL",
+ linux.PTRACE_SYSEMU: "PTRACE_SYSEMU",
+ linux.PTRACE_SYSEMU_SINGLESTEP: "PTRACE_SYSEMU_SINGLESTEP",
+ linux.PTRACE_SINGLEBLOCK: "PTRACE_SINGLEBLOCK",
+}
diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go
new file mode 100644
index 000000000..f82460e1c
--- /dev/null
+++ b/pkg/sentry/strace/signal.go
@@ -0,0 +1,148 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// signalNames contains the names of all named signals.
+var signalNames = abi.ValueSet{
+ uint64(linux.SIGABRT): "SIGABRT",
+ uint64(linux.SIGALRM): "SIGALRM",
+ uint64(linux.SIGBUS): "SIGBUS",
+ uint64(linux.SIGCHLD): "SIGCHLD",
+ uint64(linux.SIGCONT): "SIGCONT",
+ uint64(linux.SIGFPE): "SIGFPE",
+ uint64(linux.SIGHUP): "SIGHUP",
+ uint64(linux.SIGILL): "SIGILL",
+ uint64(linux.SIGINT): "SIGINT",
+ uint64(linux.SIGIO): "SIGIO",
+ uint64(linux.SIGKILL): "SIGKILL",
+ uint64(linux.SIGPIPE): "SIGPIPE",
+ uint64(linux.SIGPROF): "SIGPROF",
+ uint64(linux.SIGPWR): "SIGPWR",
+ uint64(linux.SIGQUIT): "SIGQUIT",
+ uint64(linux.SIGSEGV): "SIGSEGV",
+ uint64(linux.SIGSTKFLT): "SIGSTKFLT",
+ uint64(linux.SIGSTOP): "SIGSTOP",
+ uint64(linux.SIGSYS): "SIGSYS",
+ uint64(linux.SIGTERM): "SIGTERM",
+ uint64(linux.SIGTRAP): "SIGTRAP",
+ uint64(linux.SIGTSTP): "SIGTSTP",
+ uint64(linux.SIGTTIN): "SIGTTIN",
+ uint64(linux.SIGTTOU): "SIGTTOU",
+ uint64(linux.SIGURG): "SIGURG",
+ uint64(linux.SIGUSR1): "SIGUSR1",
+ uint64(linux.SIGUSR2): "SIGUSR2",
+ uint64(linux.SIGVTALRM): "SIGVTALRM",
+ uint64(linux.SIGWINCH): "SIGWINCH",
+ uint64(linux.SIGXCPU): "SIGXCPU",
+ uint64(linux.SIGXFSZ): "SIGXFSZ",
+}
+
+var signalMaskActions = abi.ValueSet{
+ linux.SIG_BLOCK: "SIG_BLOCK",
+ linux.SIG_UNBLOCK: "SIG_UNBLOCK",
+ linux.SIG_SETMASK: "SIG_SETMASK",
+}
+
+var sigActionFlags = abi.FlagSet{
+ {
+ Flag: linux.SA_NOCLDSTOP,
+ Name: "SA_NOCLDSTOP",
+ },
+ {
+ Flag: linux.SA_NOCLDWAIT,
+ Name: "SA_NOCLDWAIT",
+ },
+ {
+ Flag: linux.SA_SIGINFO,
+ Name: "SA_SIGINFO",
+ },
+ {
+ Flag: linux.SA_RESTORER,
+ Name: "SA_RESTORER",
+ },
+ {
+ Flag: linux.SA_ONSTACK,
+ Name: "SA_ONSTACK",
+ },
+ {
+ Flag: linux.SA_RESTART,
+ Name: "SA_RESTART",
+ },
+ {
+ Flag: linux.SA_NODEFER,
+ Name: "SA_NODEFER",
+ },
+ {
+ Flag: linux.SA_RESETHAND,
+ Name: "SA_RESETHAND",
+ },
+}
+
+func sigSet(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var b [linux.SignalSetSize]byte
+ if _, err := t.CopyInBytes(addr, b[:]); err != nil {
+ return fmt.Sprintf("%#x (error copying sigset: %v)", addr, err)
+ }
+
+ set := linux.SignalSet(usermem.ByteOrder.Uint64(b[:]))
+
+ return fmt.Sprintf("%#x %s", addr, formatSigSet(set))
+}
+
+func formatSigSet(set linux.SignalSet) string {
+ var signals []string
+ linux.ForEachSignal(set, func(sig linux.Signal) {
+ signals = append(signals, signalNames.ParseDecimal(uint64(sig)))
+ })
+
+ return fmt.Sprintf("[%v]", strings.Join(signals, " "))
+}
+
+func sigAction(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ sa, err := t.CopyInSignalAct(addr)
+ if err != nil {
+ return fmt.Sprintf("%#x (error copying sigaction: %v)", addr, err)
+ }
+
+ var handler string
+ switch sa.Handler {
+ case linux.SIG_IGN:
+ handler = "SIG_IGN"
+ case linux.SIG_DFL:
+ handler = "SIG_DFL"
+ default:
+ handler = fmt.Sprintf("%#x", sa.Handler)
+ }
+
+ return fmt.Sprintf("%#x {Handler: %s, Flags: %s, Restorer: %#x, Mask: %s}", addr, handler, sigActionFlags.Parse(sa.Flags), sa.Restorer, formatSigSet(sa.Mask))
+}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
new file mode 100644
index 000000000..dbe53b9a2
--- /dev/null
+++ b/pkg/sentry/strace/socket.go
@@ -0,0 +1,412 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/epsocket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink"
+ slinux "gvisor.googlesource.com/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// SocketFamily are the possible socket(2) families.
+var SocketFamily = abi.ValueSet{
+ linux.AF_UNSPEC: "AF_UNSPEC",
+ linux.AF_UNIX: "AF_UNIX",
+ linux.AF_INET: "AF_INET",
+ linux.AF_AX25: "AF_AX25",
+ linux.AF_IPX: "AF_IPX",
+ linux.AF_APPLETALK: "AF_APPLETALK",
+ linux.AF_NETROM: "AF_NETROM",
+ linux.AF_BRIDGE: "AF_BRIDGE",
+ linux.AF_ATMPVC: "AF_ATMPVC",
+ linux.AF_X25: "AF_X25",
+ linux.AF_INET6: "AF_INET6",
+ linux.AF_ROSE: "AF_ROSE",
+ linux.AF_DECnet: "AF_DECnet",
+ linux.AF_NETBEUI: "AF_NETBEUI",
+ linux.AF_SECURITY: "AF_SECURITY",
+ linux.AF_KEY: "AF_KEY",
+ linux.AF_NETLINK: "AF_NETLINK",
+ linux.AF_PACKET: "AF_PACKET",
+ linux.AF_ASH: "AF_ASH",
+ linux.AF_ECONET: "AF_ECONET",
+ linux.AF_ATMSVC: "AF_ATMSVC",
+ linux.AF_RDS: "AF_RDS",
+ linux.AF_SNA: "AF_SNA",
+ linux.AF_IRDA: "AF_IRDA",
+ linux.AF_PPPOX: "AF_PPPOX",
+ linux.AF_WANPIPE: "AF_WANPIPE",
+ linux.AF_LLC: "AF_LLC",
+ linux.AF_IB: "AF_IB",
+ linux.AF_MPLS: "AF_MPLS",
+ linux.AF_CAN: "AF_CAN",
+ linux.AF_TIPC: "AF_TIPC",
+ linux.AF_BLUETOOTH: "AF_BLUETOOTH",
+ linux.AF_IUCV: "AF_IUCV",
+ linux.AF_RXRPC: "AF_RXRPC",
+ linux.AF_ISDN: "AF_ISDN",
+ linux.AF_PHONET: "AF_PHONET",
+ linux.AF_IEEE802154: "AF_IEEE802154",
+ linux.AF_CAIF: "AF_CAIF",
+ linux.AF_ALG: "AF_ALG",
+ linux.AF_NFC: "AF_NFC",
+ linux.AF_VSOCK: "AF_VSOCK",
+}
+
+// SocketType are the possible socket(2) types.
+var SocketType = abi.ValueSet{
+ linux.SOCK_STREAM: "SOCK_STREAM",
+ linux.SOCK_DGRAM: "SOCK_DGRAM",
+ linux.SOCK_RAW: "SOCK_RAW",
+ linux.SOCK_RDM: "SOCK_RDM",
+ linux.SOCK_SEQPACKET: "SOCK_SEQPACKET",
+ linux.SOCK_DCCP: "SOCK_DCCP",
+ linux.SOCK_PACKET: "SOCK_PACKET",
+}
+
+// SocketFlagSet are the possible socket(2) flags.
+var SocketFlagSet = abi.FlagSet{
+ {
+ Flag: linux.SOCK_CLOEXEC,
+ Name: "SOCK_CLOEXEC",
+ },
+ {
+ Flag: linux.SOCK_NONBLOCK,
+ Name: "SOCK_NONBLOCK",
+ },
+}
+
+// ipProtocol are the possible socket(2) types for INET and INET6 sockets.
+var ipProtocol = abi.ValueSet{
+ linux.IPPROTO_IP: "IPPROTO_IP",
+ linux.IPPROTO_ICMP: "IPPROTO_ICMP",
+ linux.IPPROTO_IGMP: "IPPROTO_IGMP",
+ linux.IPPROTO_IPIP: "IPPROTO_IPIP",
+ linux.IPPROTO_TCP: "IPPROTO_TCP",
+ linux.IPPROTO_EGP: "IPPROTO_EGP",
+ linux.IPPROTO_PUP: "IPPROTO_PUP",
+ linux.IPPROTO_UDP: "IPPROTO_UDP",
+ linux.IPPROTO_IDP: "IPPROTO_IDP",
+ linux.IPPROTO_TP: "IPPROTO_TP",
+ linux.IPPROTO_DCCP: "IPPROTO_DCCP",
+ linux.IPPROTO_IPV6: "IPPROTO_IPV6",
+ linux.IPPROTO_RSVP: "IPPROTO_RSVP",
+ linux.IPPROTO_GRE: "IPPROTO_GRE",
+ linux.IPPROTO_ESP: "IPPROTO_ESP",
+ linux.IPPROTO_AH: "IPPROTO_AH",
+ linux.IPPROTO_MTP: "IPPROTO_MTP",
+ linux.IPPROTO_BEETPH: "IPPROTO_BEETPH",
+ linux.IPPROTO_ENCAP: "IPPROTO_ENCAP",
+ linux.IPPROTO_PIM: "IPPROTO_PIM",
+ linux.IPPROTO_COMP: "IPPROTO_COMP",
+ linux.IPPROTO_SCTP: "IPPROTO_SCTP",
+ linux.IPPROTO_UDPLITE: "IPPROTO_UDPLITE",
+ linux.IPPROTO_MPLS: "IPPROTO_MPLS",
+ linux.IPPROTO_RAW: "IPPROTO_RAW",
+}
+
+// SocketProtocol are the possible socket(2) protocols for each protocol family.
+var SocketProtocol = map[int32]abi.ValueSet{
+ linux.AF_INET: ipProtocol,
+ linux.AF_INET6: ipProtocol,
+ linux.AF_NETLINK: {
+ linux.NETLINK_ROUTE: "NETLINK_ROUTE",
+ linux.NETLINK_UNUSED: "NETLINK_UNUSED",
+ linux.NETLINK_USERSOCK: "NETLINK_USERSOCK",
+ linux.NETLINK_FIREWALL: "NETLINK_FIREWALL",
+ linux.NETLINK_SOCK_DIAG: "NETLINK_SOCK_DIAG",
+ linux.NETLINK_NFLOG: "NETLINK_NFLOG",
+ linux.NETLINK_XFRM: "NETLINK_XFRM",
+ linux.NETLINK_SELINUX: "NETLINK_SELINUX",
+ linux.NETLINK_ISCSI: "NETLINK_ISCSI",
+ linux.NETLINK_AUDIT: "NETLINK_AUDIT",
+ linux.NETLINK_FIB_LOOKUP: "NETLINK_FIB_LOOKUP",
+ linux.NETLINK_CONNECTOR: "NETLINK_CONNECTOR",
+ linux.NETLINK_NETFILTER: "NETLINK_NETFILTER",
+ linux.NETLINK_IP6_FW: "NETLINK_IP6_FW",
+ linux.NETLINK_DNRTMSG: "NETLINK_DNRTMSG",
+ linux.NETLINK_KOBJECT_UEVENT: "NETLINK_KOBJECT_UEVENT",
+ linux.NETLINK_GENERIC: "NETLINK_GENERIC",
+ linux.NETLINK_SCSITRANSPORT: "NETLINK_SCSITRANSPORT",
+ linux.NETLINK_ECRYPTFS: "NETLINK_ECRYPTFS",
+ linux.NETLINK_RDMA: "NETLINK_RDMA",
+ linux.NETLINK_CRYPTO: "NETLINK_CRYPTO",
+ },
+}
+
+var controlMessageType = map[int32]string{
+ linux.SCM_RIGHTS: "SCM_RIGHTS",
+ linux.SCM_CREDENTIALS: "SCM_CREDENTIALS",
+ linux.SO_TIMESTAMP: "SO_TIMESTAMP",
+}
+
+func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) string {
+ if length > maxBytes {
+ return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length)
+ }
+
+ buf := make([]byte, length)
+ if _, err := t.CopyIn(addr, &buf); err != nil {
+ return fmt.Sprintf("%#x (error decoding control: %v)", addr, err)
+ }
+
+ var strs []string
+
+ for i := 0; i < len(buf); {
+ if i+linux.SizeOfControlMessageHeader > len(buf) {
+ strs = append(strs, "{invalid control message (too short)}")
+ break
+ }
+
+ var h linux.ControlMessageHeader
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
+
+ var skipData bool
+ level := "SOL_SOCKET"
+ if h.Level != linux.SOL_SOCKET {
+ skipData = true
+ level = fmt.Sprint(h.Level)
+ }
+
+ typ, ok := controlMessageType[h.Type]
+ if !ok {
+ skipData = true
+ typ = fmt.Sprint(h.Type)
+ }
+
+ if h.Length > uint64(len(buf)-i) {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content extends beyond buffer}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
+
+ i += linux.SizeOfControlMessageHeader
+ width := t.Arch().Width()
+ length := int(h.Length) - linux.SizeOfControlMessageHeader
+
+ if skipData {
+ strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
+ i += control.AlignUp(length, width)
+ continue
+ }
+
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := control.AlignDown(length, linux.SizeOfControlMessageRight)
+
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+ fds := make(linux.ControlMessageRights, numRights)
+ binary.Unmarshal(buf[i:i+rightsSize], usermem.ByteOrder, &fds)
+
+ rights := make([]string, 0, len(fds))
+ for _, fd := range fds {
+ rights = append(rights, fmt.Sprint(fd))
+ }
+
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content: %s}",
+ level,
+ typ,
+ h.Length,
+ strings.Join(rights, ","),
+ ))
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
+
+ var creds linux.ControlMessageCredentials
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}",
+ level,
+ typ,
+ h.Length,
+ creds.PID,
+ creds.UID,
+ creds.GID,
+ ))
+
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
+
+ var tv linux.Timeval
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &tv)
+
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}",
+ level,
+ typ,
+ h.Length,
+ tv.Sec,
+ tv.Usec,
+ ))
+
+ default:
+ panic("unreachable")
+ }
+ i += control.AlignUp(length, width)
+ }
+
+ return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))
+}
+
+func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint64) string {
+ var msg slinux.MessageHeader64
+ if err := slinux.CopyInMessageHeader64(t, addr, &msg); err != nil {
+ return fmt.Sprintf("%#x (error decoding msghdr: %v)", addr, err)
+ }
+ s := fmt.Sprintf(
+ "%#x {name=%#x, namelen=%d, iovecs=%s",
+ addr,
+ msg.Name,
+ msg.NameLen,
+ iovecs(t, usermem.Addr(msg.Iov), int(msg.IovLen), printContent, maxBytes),
+ )
+ if printContent {
+ s = fmt.Sprintf("%s, control={%s}", s, cmsghdr(t, usermem.Addr(msg.Control), msg.ControlLen, maxBytes))
+ } else {
+ s = fmt.Sprintf("%s, control=%#x, control_len=%d", s, msg.Control, msg.ControlLen)
+ }
+ return fmt.Sprintf("%s, flags=%d}", s, msg.Flags)
+}
+
+func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ b, err := slinux.CaptureAddress(t, addr, length)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading address: %v}", addr, err)
+ }
+
+ // Extract address family.
+ if len(b) < 2 {
+ return fmt.Sprintf("%#x {address too short: %d bytes}", addr, len(b))
+ }
+ family := usermem.ByteOrder.Uint16(b)
+
+ familyStr := SocketFamily.Parse(uint64(family))
+
+ switch family {
+ case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
+ fa, err := epsocket.GetAddress(int(family), b)
+ if err != nil {
+ return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
+ }
+
+ if family == linux.AF_UNIX {
+ return fmt.Sprintf("%#x {Family: %s, Addr: %q}", addr, familyStr, string(fa.Addr))
+ }
+
+ return fmt.Sprintf("%#x {Family: %s, Addr: %v, Port: %d}", addr, familyStr, fa.Addr, fa.Port)
+ case linux.AF_NETLINK:
+ sa, err := netlink.ExtractSockAddr(b)
+ if err != nil {
+ return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
+ }
+ return fmt.Sprintf("%#x {Family: %s, PortID: %d, Groups: %d}", addr, familyStr, sa.PortID, sa.Groups)
+ default:
+ return fmt.Sprintf("%#x {Family: %s, family addr format unknown}", addr, familyStr)
+ }
+}
+
+func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ if lengthPtr == 0 {
+ return fmt.Sprintf("%#x {length null}", addr)
+ }
+
+ l, err := copySockLen(t, lengthPtr)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading length: %v}", addr, err)
+ }
+
+ return sockAddr(t, addr, l)
+}
+
+func copySockLen(t *kernel.Task, addr usermem.Addr) (uint32, error) {
+ // socklen_t is 32-bits.
+ var l uint32
+ _, err := t.CopyIn(addr, &l)
+ return l, err
+}
+
+func sockLenPointer(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+ l, err := copySockLen(t, addr)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading length: %v}", addr, err)
+ }
+ return fmt.Sprintf("%#x {length=%v}", addr, l)
+}
+
+func sockType(stype int32) string {
+ s := SocketType.Parse(uint64(stype & linux.SOCK_TYPE_MASK))
+ if flags := SocketFlagSet.Parse(uint64(stype &^ linux.SOCK_TYPE_MASK)); flags != "" {
+ s += "|" + flags
+ }
+ return s
+}
+
+func sockProtocol(family, protocol int32) string {
+ protocols, ok := SocketProtocol[family]
+ if !ok {
+ return fmt.Sprintf("%#x", protocol)
+ }
+ return protocols.Parse(uint64(protocol))
+}
+
+func sockFlags(flags int32) string {
+ if flags == 0 {
+ return "0"
+ }
+ return SocketFlagSet.Parse(uint64(flags))
+}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
new file mode 100644
index 000000000..f4c1be4ce
--- /dev/null
+++ b/pkg/sentry/strace/strace.go
@@ -0,0 +1,820 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package strace implements the logic to print out the input and the return value
+// of each traced syscall.
+package strace
+
+import (
+ "encoding/binary"
+ "fmt"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/bits"
+ "gvisor.googlesource.com/gvisor/pkg/eventchannel"
+ "gvisor.googlesource.com/gvisor/pkg/seccomp"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ pb "gvisor.googlesource.com/gvisor/pkg/sentry/strace/strace_go_proto"
+ slinux "gvisor.googlesource.com/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// DefaultLogMaximumSize is the default LogMaximumSize.
+const DefaultLogMaximumSize = 1024
+
+// LogMaximumSize determines the maximum display size for data blobs (read,
+// write, etc.).
+var LogMaximumSize uint = DefaultLogMaximumSize
+
+// EventMaximumSize determines the maximum size for data blobs (read, write,
+// etc.) sent over the event channel. Default is 0 because most clients cannot
+// do anything useful with binary text dump of byte array arguments.
+var EventMaximumSize uint
+
+// ItimerTypes are the possible itimer types.
+var ItimerTypes = abi.ValueSet{
+ linux.ITIMER_REAL: "ITIMER_REAL",
+ linux.ITIMER_VIRTUAL: "ITIMER_VIRTUAL",
+ linux.ITIMER_PROF: "ITIMER_PROF",
+}
+
+func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, maxBytes uint64) string {
+ if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV {
+ return fmt.Sprintf("%#x (error decoding iovecs: invalid iovcnt)", addr)
+ }
+ ars, err := t.CopyInIovecs(addr, iovcnt)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding iovecs: %v)", addr, err)
+ }
+
+ var totalBytes uint64
+ var truncated bool
+ iovs := make([]string, iovcnt)
+ for i := 0; !ars.IsEmpty(); i, ars = i+1, ars.Tail() {
+ ar := ars.Head()
+ if ar.Length() == 0 || !printContent {
+ iovs[i] = fmt.Sprintf("{base=%#x, len=%d}", ar.Start, ar.Length())
+ continue
+ }
+
+ size := uint64(ar.Length())
+ if truncated || totalBytes+size > maxBytes {
+ truncated = true
+ size = maxBytes - totalBytes
+ } else {
+ totalBytes += uint64(ar.Length())
+ }
+
+ b := make([]byte, size)
+ amt, err := t.CopyIn(ar.Start, b)
+ if err != nil {
+ iovs[i] = fmt.Sprintf("{base=%#x, len=%d, %q..., error decoding string: %v}", ar.Start, ar.Length(), b[:amt], err)
+ continue
+ }
+
+ dot := ""
+ if truncated {
+ // Indicate truncation.
+ dot = "..."
+ }
+ iovs[i] = fmt.Sprintf("{base=%#x, len=%d, %q%s}", ar.Start, ar.Length(), b[:amt], dot)
+ }
+
+ return fmt.Sprintf("%#x %s", addr, strings.Join(iovs, ", "))
+}
+
+func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) string {
+ origSize := size
+ if size > maximumBlobSize {
+ size = maximumBlobSize
+ }
+ if size == 0 {
+ return ""
+ }
+
+ b := make([]byte, size)
+ amt, err := t.CopyIn(addr, b)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding string: %s)", addr, err)
+ }
+
+ dot := ""
+ if uint(amt) < origSize {
+ // ... if we truncated the dump.
+ dot = "..."
+ }
+
+ return fmt.Sprintf("%#x %q%s", addr, b[:amt], dot)
+}
+
+func path(t *kernel.Task, addr usermem.Addr) string {
+ path, err := t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding path: %s)", addr, err)
+ }
+ return fmt.Sprintf("%#x %s", addr, path)
+}
+
+func fd(t *kernel.Task, fd kdefs.FD) string {
+ root := t.FSContext().RootDirectory()
+ if root != nil {
+ defer root.DecRef()
+ }
+
+ if fd == linux.AT_FDCWD {
+ wd := t.FSContext().WorkingDirectory()
+ var name string
+ if wd != nil {
+ defer wd.DecRef()
+ name, _ = wd.FullName(root)
+ } else {
+ name = "(unknown cwd)"
+ }
+ return fmt.Sprintf("AT_FDCWD %s", name)
+ }
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ // Cast FD to uint64 to avoid printing negative hex.
+ return fmt.Sprintf("%#x (bad FD)", uint64(fd))
+ }
+ defer file.DecRef()
+
+ name, _ := file.Dirent.FullName(root)
+ return fmt.Sprintf("%#x %s", fd, name)
+}
+
+func fdpair(t *kernel.Task, addr usermem.Addr) string {
+ var fds [2]int32
+ _, err := t.CopyIn(addr, &fds)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding fds: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x [%d %d]", addr, fds[0], fds[1])
+}
+
+func uname(t *kernel.Task, addr usermem.Addr) string {
+ var u linux.UtsName
+ if _, err := t.CopyIn(addr, &u); err != nil {
+ return fmt.Sprintf("%#x (error decoding utsname: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x %s", addr, u)
+}
+
+func utimensTimespec(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var tim linux.Timespec
+ if _, err := t.CopyIn(addr, &tim); err != nil {
+ return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err)
+ }
+
+ var ns string
+ switch tim.Nsec {
+ case linux.UTIME_NOW:
+ ns = "UTIME_NOW"
+ case linux.UTIME_OMIT:
+ ns = "UTIME_OMIT"
+ default:
+ ns = fmt.Sprintf("%v", tim.Nsec)
+ }
+ return fmt.Sprintf("%#x {sec=%v nsec=%s}", addr, tim.Sec, ns)
+}
+
+func timespec(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var tim linux.Timespec
+ if _, err := t.CopyIn(addr, &tim); err != nil {
+ return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err)
+ }
+ return fmt.Sprintf("%#x {sec=%v nsec=%v}", addr, tim.Sec, tim.Nsec)
+}
+
+func timeval(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var tim linux.Timeval
+ if _, err := t.CopyIn(addr, &tim); err != nil {
+ return fmt.Sprintf("%#x (error decoding timeval: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x {sec=%v usec=%v}", addr, tim.Sec, tim.Usec)
+}
+
+func utimbuf(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var utim syscall.Utimbuf
+ if _, err := t.CopyIn(addr, &utim); err != nil {
+ return fmt.Sprintf("%#x (error decoding utimbuf: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x {actime=%v, modtime=%v}", addr, utim.Actime, utim.Modtime)
+}
+
+func stat(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var stat linux.Stat
+ if _, err := t.CopyIn(addr, &stat); err != nil {
+ return fmt.Sprintf("%#x (error decoding stat: %s)", addr, err)
+ }
+ return fmt.Sprintf("%#x {dev=%d, ino=%d, mode=%s, nlink=%d, uid=%d, gid=%d, rdev=%d, size=%d, blksize=%d, blocks=%d, atime=%s, mtime=%s, ctime=%s}", addr, stat.Dev, stat.Ino, linux.FileMode(stat.Mode), stat.Nlink, stat.UID, stat.GID, stat.Rdev, stat.Size, stat.Blksize, stat.Blocks, time.Unix(stat.ATime.Sec, stat.ATime.Nsec), time.Unix(stat.MTime.Sec, stat.MTime.Nsec), time.Unix(stat.CTime.Sec, stat.CTime.Nsec))
+}
+
+func itimerval(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ interval := timeval(t, addr)
+ value := timeval(t, addr+usermem.Addr(binary.Size(linux.Timeval{})))
+ return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
+}
+
+func itimerspec(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ interval := timespec(t, addr)
+ value := timespec(t, addr+usermem.Addr(binary.Size(linux.Timespec{})))
+ return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
+}
+
+func stringVector(t *kernel.Task, addr usermem.Addr) string {
+ vec, err := t.CopyInVector(addr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return fmt.Sprintf("%#x {error copying vector: %v}", addr, err)
+ }
+ s := fmt.Sprintf("%#x [", addr)
+ for i, v := range vec {
+ if i != 0 {
+ s += ", "
+ }
+ s += fmt.Sprintf("%q", v)
+ }
+ s += "]"
+ return s
+}
+
+func rusage(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var ru linux.Rusage
+ if _, err := t.CopyIn(addr, &ru); err != nil {
+ return fmt.Sprintf("%#x (error decoding rusage: %s)", addr, err)
+ }
+ return fmt.Sprintf("%#x %+v", addr, ru)
+}
+
+func capHeader(t *kernel.Task, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ var hdr linux.CapUserHeader
+ if _, err := t.CopyIn(addr, &hdr); err != nil {
+ return fmt.Sprintf("%#x (error decoding header: %s)", addr, err)
+ }
+
+ var version string
+ switch hdr.Version {
+ case linux.LINUX_CAPABILITY_VERSION_1:
+ version = "1"
+ case linux.LINUX_CAPABILITY_VERSION_2:
+ version = "2"
+ case linux.LINUX_CAPABILITY_VERSION_3:
+ version = "3"
+ default:
+ version = strconv.FormatUint(uint64(hdr.Version), 16)
+ }
+
+ return fmt.Sprintf("%#x {Version: %s, Pid: %d}", addr, version, hdr.Pid)
+}
+
+func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string {
+ if dataAddr == 0 {
+ return "null"
+ }
+
+ var hdr linux.CapUserHeader
+ if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ return fmt.Sprintf("%#x (error decoding header: %v)", dataAddr, err)
+ }
+
+ var p, i, e uint64
+
+ switch hdr.Version {
+ case linux.LINUX_CAPABILITY_VERSION_1:
+ var data linux.CapUserData
+ if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err)
+ }
+ p = uint64(data.Permitted)
+ i = uint64(data.Inheritable)
+ e = uint64(data.Effective)
+ case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
+ var data [2]linux.CapUserData
+ if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err)
+ }
+ p = uint64(data[0].Permitted) | (uint64(data[1].Permitted) << 32)
+ i = uint64(data[0].Inheritable) | (uint64(data[1].Inheritable) << 32)
+ e = uint64(data[0].Effective) | (uint64(data[1].Effective) << 32)
+ default:
+ return fmt.Sprintf("%#x (unknown version %d)", dataAddr, hdr.Version)
+ }
+
+ return fmt.Sprintf("%#x {Permitted: %s, Inheritable: %s, Effective: %s}", dataAddr, CapabilityBitset.Parse(p), CapabilityBitset.Parse(i), CapabilityBitset.Parse(e))
+}
+
+// pre fills in the pre-execution arguments for a system call. If an argument
+// cannot be interpreted before the system call is executed, then a hex value
+// will be used. Note that a full output slice will always be provided, that is
+// len(return) == len(args).
+func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlobSize uint) []string {
+ var output []string
+
+ for arg := range args {
+ if arg >= len(i.format) {
+ break
+ }
+ switch i.format[arg] {
+ case FD:
+ output = append(output, fd(t, kdefs.FD(args[arg].Int())))
+ case WriteBuffer:
+ output = append(output, dump(t, args[arg].Pointer(), args[arg+1].SizeT(), maximumBlobSize))
+ case WriteIOVec:
+ output = append(output, iovecs(t, args[arg].Pointer(), int(args[arg+1].Int()), true /* content */, uint64(maximumBlobSize)))
+ case IOVec:
+ output = append(output, iovecs(t, args[arg].Pointer(), int(args[arg+1].Int()), false /* content */, uint64(maximumBlobSize)))
+ case SendMsgHdr:
+ output = append(output, msghdr(t, args[arg].Pointer(), true /* content */, uint64(maximumBlobSize)))
+ case RecvMsgHdr:
+ output = append(output, msghdr(t, args[arg].Pointer(), false /* content */, uint64(maximumBlobSize)))
+ case Path:
+ output = append(output, path(t, args[arg].Pointer()))
+ case ExecveStringVector:
+ output = append(output, stringVector(t, args[arg].Pointer()))
+ case SockAddr:
+ output = append(output, sockAddr(t, args[arg].Pointer(), uint32(args[arg+1].Uint64())))
+ case SockLen:
+ output = append(output, sockLenPointer(t, args[arg].Pointer()))
+ case SockFamily:
+ output = append(output, SocketFamily.Parse(uint64(args[arg].Int())))
+ case SockType:
+ output = append(output, sockType(args[arg].Int()))
+ case SockProtocol:
+ output = append(output, sockProtocol(args[arg-2].Int(), args[arg].Int()))
+ case SockFlags:
+ output = append(output, sockFlags(args[arg].Int()))
+ case Timespec:
+ output = append(output, timespec(t, args[arg].Pointer()))
+ case UTimeTimespec:
+ output = append(output, utimensTimespec(t, args[arg].Pointer()))
+ case ItimerVal:
+ output = append(output, itimerval(t, args[arg].Pointer()))
+ case ItimerSpec:
+ output = append(output, itimerspec(t, args[arg].Pointer()))
+ case Timeval:
+ output = append(output, timeval(t, args[arg].Pointer()))
+ case Utimbuf:
+ output = append(output, utimbuf(t, args[arg].Pointer()))
+ case CloneFlags:
+ output = append(output, CloneFlagSet.Parse(uint64(args[arg].Uint())))
+ case OpenFlags:
+ output = append(output, open(uint64(args[arg].Uint())))
+ case Mode:
+ output = append(output, linux.FileMode(args[arg].ModeT()).String())
+ case FutexOp:
+ output = append(output, futex(uint64(args[arg].Uint())))
+ case PtraceRequest:
+ output = append(output, PtraceRequestSet.Parse(args[arg].Uint64()))
+ case ItimerType:
+ output = append(output, ItimerTypes.Parse(uint64(args[arg].Int())))
+ case Signal:
+ output = append(output, signalNames.ParseDecimal(args[arg].Uint64()))
+ case SignalMaskAction:
+ output = append(output, signalMaskActions.Parse(uint64(args[arg].Int())))
+ case SigSet:
+ output = append(output, sigSet(t, args[arg].Pointer()))
+ case SigAction:
+ output = append(output, sigAction(t, args[arg].Pointer()))
+ case CapHeader:
+ output = append(output, capHeader(t, args[arg].Pointer()))
+ case CapData:
+ output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
+ case PollFDs:
+ output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case Oct:
+ output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8))
+ case Hex:
+ fallthrough
+ default:
+ output = append(output, "0x"+strconv.FormatUint(args[arg].Uint64(), 16))
+ }
+ }
+
+ return output
+}
+
+// post fills in the post-execution arguments for a system call. This modifies
+// the given output slice in place with arguments that may only be interpreted
+// after the system call has been executed.
+func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uintptr, output []string, maximumBlobSize uint) {
+ for arg := range output {
+ if arg >= len(i.format) {
+ break
+ }
+ switch i.format[arg] {
+ case ReadBuffer:
+ output[arg] = dump(t, args[arg].Pointer(), uint(rval), maximumBlobSize)
+ case ReadIOVec:
+ printLength := uint64(rval)
+ if printLength > uint64(maximumBlobSize) {
+ printLength = uint64(maximumBlobSize)
+ }
+ output[arg] = iovecs(t, args[arg].Pointer(), int(args[arg+1].Int()), true /* content */, printLength)
+ case WriteIOVec, IOVec, WriteBuffer:
+ // We already have a big blast from write.
+ output[arg] = "..."
+ case SendMsgHdr:
+ output[arg] = msghdr(t, args[arg].Pointer(), false /* content */, uint64(maximumBlobSize))
+ case RecvMsgHdr:
+ output[arg] = msghdr(t, args[arg].Pointer(), true /* content */, uint64(maximumBlobSize))
+ case PostPath:
+ output[arg] = path(t, args[arg].Pointer())
+ case PipeFDs:
+ output[arg] = fdpair(t, args[arg].Pointer())
+ case Uname:
+ output[arg] = uname(t, args[arg].Pointer())
+ case Stat:
+ output[arg] = stat(t, args[arg].Pointer())
+ case PostSockAddr:
+ output[arg] = postSockAddr(t, args[arg].Pointer(), args[arg+1].Pointer())
+ case SockLen:
+ output[arg] = sockLenPointer(t, args[arg].Pointer())
+ case PostTimespec:
+ output[arg] = timespec(t, args[arg].Pointer())
+ case PostItimerVal:
+ output[arg] = itimerval(t, args[arg].Pointer())
+ case PostItimerSpec:
+ output[arg] = itimerspec(t, args[arg].Pointer())
+ case Timeval:
+ output[arg] = timeval(t, args[arg].Pointer())
+ case Rusage:
+ output[arg] = rusage(t, args[arg].Pointer())
+ case PostSigSet:
+ output[arg] = sigSet(t, args[arg].Pointer())
+ case PostSigAction:
+ output[arg] = sigAction(t, args[arg].Pointer())
+ case PostCapData:
+ output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer())
+ case PollFDs:
+ output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true)
+ }
+ }
+}
+
+// printEntry prints the given system call entry.
+func (i *SyscallInfo) printEnter(t *kernel.Task, args arch.SyscallArguments) []string {
+ output := i.pre(t, args, LogMaximumSize)
+
+ switch len(output) {
+ case 0:
+ t.Infof("%s E %s()", t.Name(), i.name)
+ case 1:
+ t.Infof("%s E %s(%s)", t.Name(), i.name,
+ output[0])
+ case 2:
+ t.Infof("%s E %s(%s, %s)", t.Name(), i.name,
+ output[0], output[1])
+ case 3:
+ t.Infof("%s E %s(%s, %s, %s)", t.Name(), i.name,
+ output[0], output[1], output[2])
+ case 4:
+ t.Infof("%s E %s(%s, %s, %s, %s)", t.Name(), i.name,
+ output[0], output[1], output[2], output[3])
+ case 5:
+ t.Infof("%s E %s(%s, %s, %s, %s, %s)", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], output[4])
+ case 6:
+ t.Infof("%s E %s(%s, %s, %s, %s, %s, %s)", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], output[4], output[5])
+ }
+
+ return output
+}
+
+// printExit prints the given system call exit.
+func (i *SyscallInfo) printExit(t *kernel.Task, elapsed time.Duration, output []string, args arch.SyscallArguments, retval uintptr, err error, errno int) {
+ var rval string
+ if err == nil {
+ // Fill in the output after successful execution.
+ i.post(t, args, retval, output, LogMaximumSize)
+ rval = fmt.Sprintf("%#x (%v)", retval, elapsed)
+ } else {
+ rval = fmt.Sprintf("%#x errno=%d (%s) (%v)", retval, errno, err, elapsed)
+ }
+
+ switch len(output) {
+ case 0:
+ t.Infof("%s X %s() = %s", t.Name(), i.name,
+ rval)
+ case 1:
+ t.Infof("%s X %s(%s) = %s", t.Name(), i.name,
+ output[0], rval)
+ case 2:
+ t.Infof("%s X %s(%s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], rval)
+ case 3:
+ t.Infof("%s X %s(%s, %s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], output[2], rval)
+ case 4:
+ t.Infof("%s X %s(%s, %s, %s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], rval)
+ case 5:
+ t.Infof("%s X %s(%s, %s, %s, %s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], output[4], rval)
+ case 6:
+ t.Infof("%s X %s(%s, %s, %s, %s, %s, %s) = %s", t.Name(), i.name,
+ output[0], output[1], output[2], output[3], output[4], output[5], rval)
+ }
+}
+
+// sendEnter sends the syscall enter to event log.
+func (i *SyscallInfo) sendEnter(t *kernel.Task, args arch.SyscallArguments) []string {
+ output := i.pre(t, args, EventMaximumSize)
+
+ event := pb.Strace{
+ Process: t.Name(),
+ Function: i.name,
+ Info: &pb.Strace_Enter{
+ Enter: &pb.StraceEnter{},
+ },
+ }
+ for _, arg := range output {
+ event.Args = append(event.Args, arg)
+ }
+ eventchannel.Emit(&event)
+
+ return output
+}
+
+// sendExit sends the syscall exit to event log.
+func (i *SyscallInfo) sendExit(t *kernel.Task, elapsed time.Duration, output []string, args arch.SyscallArguments, rval uintptr, err error, errno int) {
+ if err == nil {
+ // Fill in the output after successful execution.
+ i.post(t, args, rval, output, EventMaximumSize)
+ }
+
+ exit := &pb.StraceExit{
+ Return: fmt.Sprintf("%#x", rval),
+ ElapsedNs: elapsed.Nanoseconds(),
+ }
+ if err != nil {
+ exit.Error = err.Error()
+ exit.ErrNo = int64(errno)
+ }
+ event := pb.Strace{
+ Process: t.Name(),
+ Function: i.name,
+ Info: &pb.Strace_Exit{Exit: exit},
+ }
+ for _, arg := range output {
+ event.Args = append(event.Args, arg)
+ }
+ eventchannel.Emit(&event)
+}
+
+type syscallContext struct {
+ info SyscallInfo
+ args arch.SyscallArguments
+ start time.Time
+ logOutput []string
+ eventOutput []string
+ flags uint32
+}
+
+// SyscallEnter implements kernel.Stracer.SyscallEnter. It logs the syscall
+// entry trace.
+func (s SyscallMap) SyscallEnter(t *kernel.Task, sysno uintptr, args arch.SyscallArguments, flags uint32) interface{} {
+ info, ok := s[sysno]
+ if !ok {
+ info = SyscallInfo{
+ name: fmt.Sprintf("sys_%d", sysno),
+ format: defaultFormat,
+ }
+ }
+
+ var output, eventOutput []string
+ if bits.IsOn32(flags, kernel.StraceEnableLog) {
+ output = info.printEnter(t, args)
+ }
+ if bits.IsOn32(flags, kernel.StraceEnableEvent) {
+ eventOutput = info.sendEnter(t, args)
+ }
+
+ return &syscallContext{
+ info: info,
+ args: args,
+ start: time.Now(),
+ logOutput: output,
+ eventOutput: eventOutput,
+ flags: flags,
+ }
+}
+
+// SyscallExit implements kernel.Stracer.SyscallExit. It logs the syscall
+// exit trace.
+func (s SyscallMap) SyscallExit(context interface{}, t *kernel.Task, sysno, rval uintptr, err error) {
+ errno := t.ExtractErrno(err, int(sysno))
+ c := context.(*syscallContext)
+
+ elapsed := time.Since(c.start)
+ if bits.IsOn32(c.flags, kernel.StraceEnableLog) {
+ c.info.printExit(t, elapsed, c.logOutput, c.args, rval, err, errno)
+ }
+ if bits.IsOn32(c.flags, kernel.StraceEnableEvent) {
+ c.info.sendExit(t, elapsed, c.eventOutput, c.args, rval, err, errno)
+ }
+}
+
+// ConvertToSysnoMap converts the names to a map keyed on the syscall number
+// and value set to true.
+//
+// The map is in a convenient format to pass to SyscallFlagsTable.Enable().
+func (s SyscallMap) ConvertToSysnoMap(syscalls []string) (map[uintptr]bool, error) {
+ if syscalls == nil {
+ // Sentinel: no list.
+ return nil, nil
+ }
+
+ l := make(map[uintptr]bool)
+ for _, sc := range syscalls {
+ // Try to match this system call.
+ sysno, ok := s.ConvertToSysno(sc)
+ if !ok {
+ return nil, fmt.Errorf("syscall %q not found", sc)
+ }
+ l[sysno] = true
+ }
+
+ // Success.
+ return l, nil
+}
+
+// ConvertToSysno converts the name to system call number. Returns false
+// if syscall with same name is not found.
+func (s SyscallMap) ConvertToSysno(syscall string) (uintptr, bool) {
+ for sysno, info := range s {
+ if info.name != "" && info.name == syscall {
+ return sysno, true
+ }
+ }
+ return 0, false
+}
+
+// Name returns the syscall name.
+func (s SyscallMap) Name(sysno uintptr) string {
+ if info, ok := s[sysno]; ok {
+ return info.name
+ }
+ return fmt.Sprintf("sys_%d", sysno)
+}
+
+// Initialize prepares all syscall tables for use by this package.
+//
+// N.B. This is not in an init function because we can't be sure all syscall
+// tables are registered with the kernel when init runs.
+//
+// TODO(gvisor.dev/issue/155): remove kernel package dependencies from this
+// package and have the kernel package self-initialize all syscall tables.
+func Initialize() {
+ for _, table := range kernel.SyscallTables() {
+ // Is this known?
+ sys, ok := Lookup(table.OS, table.Arch)
+ if !ok {
+ continue
+ }
+
+ table.Stracer = sys
+ }
+}
+
+// SinkType defines where to send straces to.
+type SinkType uint32
+
+const (
+ // SinkTypeLog sends straces to text log
+ SinkTypeLog SinkType = 1 << iota
+
+ // SinkTypeEvent sends strace to event log
+ SinkTypeEvent
+)
+
+func convertToSyscallFlag(sinks SinkType) uint32 {
+ ret := uint32(0)
+ if bits.IsOn32(uint32(sinks), uint32(SinkTypeLog)) {
+ ret |= kernel.StraceEnableLog
+ }
+ if bits.IsOn32(uint32(sinks), uint32(SinkTypeEvent)) {
+ ret |= kernel.StraceEnableEvent
+ }
+ return ret
+}
+
+// Enable enables the syscalls in whitelist in all syscall tables.
+//
+// Preconditions: Initialize has been called.
+func Enable(whitelist []string, sinks SinkType) error {
+ flags := convertToSyscallFlag(sinks)
+ for _, table := range kernel.SyscallTables() {
+ // Is this known?
+ sys, ok := Lookup(table.OS, table.Arch)
+ if !ok {
+ continue
+ }
+
+ // Convert to a set of system calls numbers.
+ wl, err := sys.ConvertToSysnoMap(whitelist)
+ if err != nil {
+ return err
+ }
+
+ table.FeatureEnable.Enable(flags, wl, true)
+ }
+
+ // Done.
+ return nil
+}
+
+// Disable will disable Strace for all system calls and missing syscalls.
+//
+// Preconditions: Initialize has been called.
+func Disable(sinks SinkType) {
+ flags := convertToSyscallFlag(sinks)
+ for _, table := range kernel.SyscallTables() {
+ // Strace will be disabled for all syscalls including missing.
+ table.FeatureEnable.Enable(flags, nil, false)
+ }
+}
+
+// EnableAll enables all syscalls in all syscall tables.
+//
+// Preconditions: Initialize has been called.
+func EnableAll(sinks SinkType) {
+ flags := convertToSyscallFlag(sinks)
+ for _, table := range kernel.SyscallTables() {
+ // Is this known?
+ if _, ok := Lookup(table.OS, table.Arch); !ok {
+ continue
+ }
+
+ table.FeatureEnable.EnableAll(flags)
+ }
+}
+
+func init() {
+ t, ok := Lookup(abi.Host, arch.Host)
+ if ok {
+ // Provide the native table as the lookup for seccomp
+ // debugging. This is best-effort. This is provided this way to
+ // avoid dependencies from seccomp to this package.
+ seccomp.SyscallName = t.Name
+ }
+}
diff --git a/pkg/sentry/strace/strace_go_proto/strace.pb.go b/pkg/sentry/strace/strace_go_proto/strace.pb.go
new file mode 100755
index 000000000..ef45661bc
--- /dev/null
+++ b/pkg/sentry/strace/strace_go_proto/strace.pb.go
@@ -0,0 +1,247 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: pkg/sentry/strace/strace.proto
+
+package gvisor
+
+import (
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type Strace struct {
+ Process string `protobuf:"bytes,1,opt,name=process,proto3" json:"process,omitempty"`
+ Function string `protobuf:"bytes,2,opt,name=function,proto3" json:"function,omitempty"`
+ Args []string `protobuf:"bytes,3,rep,name=args,proto3" json:"args,omitempty"`
+ // Types that are valid to be assigned to Info:
+ // *Strace_Enter
+ // *Strace_Exit
+ Info isStrace_Info `protobuf_oneof:"info"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Strace) Reset() { *m = Strace{} }
+func (m *Strace) String() string { return proto.CompactTextString(m) }
+func (*Strace) ProtoMessage() {}
+func (*Strace) Descriptor() ([]byte, []int) {
+ return fileDescriptor_50c4b43677c82b5f, []int{0}
+}
+
+func (m *Strace) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Strace.Unmarshal(m, b)
+}
+func (m *Strace) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Strace.Marshal(b, m, deterministic)
+}
+func (m *Strace) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Strace.Merge(m, src)
+}
+func (m *Strace) XXX_Size() int {
+ return xxx_messageInfo_Strace.Size(m)
+}
+func (m *Strace) XXX_DiscardUnknown() {
+ xxx_messageInfo_Strace.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Strace proto.InternalMessageInfo
+
+func (m *Strace) GetProcess() string {
+ if m != nil {
+ return m.Process
+ }
+ return ""
+}
+
+func (m *Strace) GetFunction() string {
+ if m != nil {
+ return m.Function
+ }
+ return ""
+}
+
+func (m *Strace) GetArgs() []string {
+ if m != nil {
+ return m.Args
+ }
+ return nil
+}
+
+type isStrace_Info interface {
+ isStrace_Info()
+}
+
+type Strace_Enter struct {
+ Enter *StraceEnter `protobuf:"bytes,4,opt,name=enter,proto3,oneof"`
+}
+
+type Strace_Exit struct {
+ Exit *StraceExit `protobuf:"bytes,5,opt,name=exit,proto3,oneof"`
+}
+
+func (*Strace_Enter) isStrace_Info() {}
+
+func (*Strace_Exit) isStrace_Info() {}
+
+func (m *Strace) GetInfo() isStrace_Info {
+ if m != nil {
+ return m.Info
+ }
+ return nil
+}
+
+func (m *Strace) GetEnter() *StraceEnter {
+ if x, ok := m.GetInfo().(*Strace_Enter); ok {
+ return x.Enter
+ }
+ return nil
+}
+
+func (m *Strace) GetExit() *StraceExit {
+ if x, ok := m.GetInfo().(*Strace_Exit); ok {
+ return x.Exit
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*Strace) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*Strace_Enter)(nil),
+ (*Strace_Exit)(nil),
+ }
+}
+
+type StraceEnter struct {
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *StraceEnter) Reset() { *m = StraceEnter{} }
+func (m *StraceEnter) String() string { return proto.CompactTextString(m) }
+func (*StraceEnter) ProtoMessage() {}
+func (*StraceEnter) Descriptor() ([]byte, []int) {
+ return fileDescriptor_50c4b43677c82b5f, []int{1}
+}
+
+func (m *StraceEnter) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_StraceEnter.Unmarshal(m, b)
+}
+func (m *StraceEnter) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_StraceEnter.Marshal(b, m, deterministic)
+}
+func (m *StraceEnter) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_StraceEnter.Merge(m, src)
+}
+func (m *StraceEnter) XXX_Size() int {
+ return xxx_messageInfo_StraceEnter.Size(m)
+}
+func (m *StraceEnter) XXX_DiscardUnknown() {
+ xxx_messageInfo_StraceEnter.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_StraceEnter proto.InternalMessageInfo
+
+type StraceExit struct {
+ Return string `protobuf:"bytes,1,opt,name=return,proto3" json:"return,omitempty"`
+ Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"`
+ ErrNo int64 `protobuf:"varint,3,opt,name=err_no,json=errNo,proto3" json:"err_no,omitempty"`
+ ElapsedNs int64 `protobuf:"varint,4,opt,name=elapsed_ns,json=elapsedNs,proto3" json:"elapsed_ns,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *StraceExit) Reset() { *m = StraceExit{} }
+func (m *StraceExit) String() string { return proto.CompactTextString(m) }
+func (*StraceExit) ProtoMessage() {}
+func (*StraceExit) Descriptor() ([]byte, []int) {
+ return fileDescriptor_50c4b43677c82b5f, []int{2}
+}
+
+func (m *StraceExit) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_StraceExit.Unmarshal(m, b)
+}
+func (m *StraceExit) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_StraceExit.Marshal(b, m, deterministic)
+}
+func (m *StraceExit) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_StraceExit.Merge(m, src)
+}
+func (m *StraceExit) XXX_Size() int {
+ return xxx_messageInfo_StraceExit.Size(m)
+}
+func (m *StraceExit) XXX_DiscardUnknown() {
+ xxx_messageInfo_StraceExit.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_StraceExit proto.InternalMessageInfo
+
+func (m *StraceExit) GetReturn() string {
+ if m != nil {
+ return m.Return
+ }
+ return ""
+}
+
+func (m *StraceExit) GetError() string {
+ if m != nil {
+ return m.Error
+ }
+ return ""
+}
+
+func (m *StraceExit) GetErrNo() int64 {
+ if m != nil {
+ return m.ErrNo
+ }
+ return 0
+}
+
+func (m *StraceExit) GetElapsedNs() int64 {
+ if m != nil {
+ return m.ElapsedNs
+ }
+ return 0
+}
+
+func init() {
+ proto.RegisterType((*Strace)(nil), "gvisor.Strace")
+ proto.RegisterType((*StraceEnter)(nil), "gvisor.StraceEnter")
+ proto.RegisterType((*StraceExit)(nil), "gvisor.StraceExit")
+}
+
+func init() { proto.RegisterFile("pkg/sentry/strace/strace.proto", fileDescriptor_50c4b43677c82b5f) }
+
+var fileDescriptor_50c4b43677c82b5f = []byte{
+ // 255 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x90, 0xdd, 0x4a, 0xf4, 0x30,
+ 0x10, 0x86, 0xb7, 0x5f, 0xdb, 0x7c, 0x76, 0x16, 0x4f, 0xc6, 0x1f, 0x82, 0xa0, 0x94, 0x1e, 0x05,
+ 0x84, 0x2e, 0xe8, 0x1d, 0x08, 0xc2, 0x1e, 0xed, 0x41, 0xbc, 0x80, 0xa5, 0xd6, 0xd9, 0x12, 0x94,
+ 0x24, 0x4c, 0xb2, 0xb2, 0x5e, 0x96, 0x77, 0x28, 0xa6, 0xf1, 0x07, 0x8f, 0x92, 0x67, 0xde, 0x87,
+ 0x0c, 0x6f, 0xe0, 0xca, 0x3f, 0x4f, 0xab, 0x40, 0x36, 0xf2, 0xdb, 0x2a, 0x44, 0x1e, 0x46, 0xca,
+ 0x47, 0xef, 0xd9, 0x45, 0x87, 0x62, 0x7a, 0x35, 0xc1, 0x71, 0xf7, 0x5e, 0x80, 0x78, 0x48, 0x01,
+ 0x4a, 0xf8, 0xef, 0xd9, 0x8d, 0x14, 0x82, 0x2c, 0xda, 0x42, 0x35, 0xfa, 0x0b, 0xf1, 0x02, 0x8e,
+ 0x76, 0x7b, 0x3b, 0x46, 0xe3, 0xac, 0xfc, 0x97, 0xa2, 0x6f, 0x46, 0x84, 0x6a, 0xe0, 0x29, 0xc8,
+ 0xb2, 0x2d, 0x55, 0xa3, 0xd3, 0x1d, 0xaf, 0xa1, 0x26, 0x1b, 0x89, 0x65, 0xd5, 0x16, 0x6a, 0x79,
+ 0x73, 0xd2, 0xcf, 0xcb, 0xfa, 0x79, 0xd1, 0xfd, 0x67, 0xb4, 0x5e, 0xe8, 0xd9, 0x41, 0x05, 0x15,
+ 0x1d, 0x4c, 0x94, 0x75, 0x72, 0xf1, 0x8f, 0x7b, 0x30, 0x71, 0xbd, 0xd0, 0xc9, 0xb8, 0x13, 0x50,
+ 0x19, 0xbb, 0x73, 0xdd, 0x31, 0x2c, 0x7f, 0xbd, 0xd4, 0x79, 0x80, 0x1f, 0x19, 0xcf, 0x41, 0x30,
+ 0xc5, 0x3d, 0xdb, 0x5c, 0x22, 0x13, 0x9e, 0x42, 0x4d, 0xcc, 0x8e, 0x73, 0x81, 0x19, 0xf0, 0x0c,
+ 0x04, 0x31, 0x6f, 0xad, 0x93, 0x65, 0x5b, 0xa8, 0x32, 0x8d, 0x37, 0x0e, 0x2f, 0x01, 0xe8, 0x65,
+ 0xf0, 0x81, 0x9e, 0xb6, 0x36, 0xa4, 0x16, 0xa5, 0x6e, 0xf2, 0x64, 0x13, 0x1e, 0x45, 0xfa, 0xc3,
+ 0xdb, 0x8f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x42, 0x9a, 0xbc, 0x81, 0x65, 0x01, 0x00, 0x00,
+}
diff --git a/pkg/sentry/strace/strace_state_autogen.go b/pkg/sentry/strace/strace_state_autogen.go
new file mode 100755
index 000000000..9dc697ed6
--- /dev/null
+++ b/pkg/sentry/strace/strace_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package strace
+
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
new file mode 100644
index 000000000..eae2d6c12
--- /dev/null
+++ b/pkg/sentry/strace/syscalls.go
@@ -0,0 +1,267 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+)
+
+// FormatSpecifier values describe how an individual syscall argument should be
+// formatted.
+type FormatSpecifier int
+
+// Valid FormatSpecifiers.
+//
+// Unless otherwise specified, values are formatted before syscall execution
+// and not updated after syscall execution (the same value is output).
+const (
+ // Hex is just a hexadecimal number.
+ Hex FormatSpecifier = iota
+
+ // Oct is just an octal number.
+ Oct
+
+ // FD is a file descriptor.
+ FD
+
+ // ReadBuffer is a buffer for a read-style call. The syscall return
+ // value is used for the length.
+ //
+ // Formatted after syscall execution.
+ ReadBuffer
+
+ // WriteBuffer is a buffer for a write-style call. The following arg is
+ // used for the length.
+ //
+ // Contents omitted after syscall execution.
+ WriteBuffer
+
+ // ReadIOVec is a pointer to a struct iovec for a writev-style call.
+ // The following arg is used for the length. The return value is used
+ // for the total length.
+ //
+ // Complete contents only formatted after syscall execution.
+ ReadIOVec
+
+ // WriteIOVec is a pointer to a struct iovec for a writev-style call.
+ // The following arg is used for the length.
+ //
+ // Complete contents only formatted before syscall execution, omitted
+ // after.
+ WriteIOVec
+
+ // IOVec is a generic pointer to a struct iovec. Contents are not dumped.
+ IOVec
+
+ // SendMsgHdr is a pointer to a struct msghdr for a sendmsg-style call.
+ // Contents formatted only before syscall execution, omitted after.
+ SendMsgHdr
+
+ // RecvMsgHdr is a pointer to a struct msghdr for a recvmsg-style call.
+ // Contents formatted only after syscall execution.
+ RecvMsgHdr
+
+ // Path is a pointer to a char* path.
+ Path
+
+ // PostPath is a pointer to a char* path, formatted after syscall
+ // execution.
+ PostPath
+
+ // ExecveStringVector is a NULL-terminated array of strings. Enforces
+ // the maximum execve array length.
+ ExecveStringVector
+
+ // PipeFDs is an array of two FDs, formatted after syscall execution.
+ PipeFDs
+
+ // Uname is a pointer to a struct uname, formatted after syscall execution.
+ Uname
+
+ // Stat is a pointer to a struct stat, formatted after syscall execution.
+ Stat
+
+ // SockAddr is a pointer to a struct sockaddr. The following arg is
+ // used for length.
+ SockAddr
+
+ // PostSockAddr is a pointer to a struct sockaddr, formatted after
+ // syscall execution. The following arg is a pointer to the socklen_t
+ // length.
+ PostSockAddr
+
+ // SockLen is a pointer to a socklen_t, formatted before and after
+ // syscall execution.
+ SockLen
+
+ // SockFamily is a socket protocol family value.
+ SockFamily
+
+ // SockType is a socket type and flags value.
+ SockType
+
+ // SockProtocol is a socket protocol value. Argument n-2 is the socket
+ // protocol family.
+ SockProtocol
+
+ // SockFlags are socket flags.
+ SockFlags
+
+ // Timespec is a pointer to a struct timespec.
+ Timespec
+
+ // PostTimespec is a pointer to a struct timespec, formatted after
+ // syscall execution.
+ PostTimespec
+
+ // UTimeTimespec is a pointer to a struct timespec. Formatting includes
+ // UTIME_NOW and UTIME_OMIT.
+ UTimeTimespec
+
+ // ItimerVal is a pointer to a struct itimerval.
+ ItimerVal
+
+ // PostItimerVal is a pointer to a struct itimerval, formatted after
+ // syscall execution.
+ PostItimerVal
+
+ // ItimerSpec is a pointer to a struct itimerspec.
+ ItimerSpec
+
+ // PostItimerSpec is a pointer to a struct itimerspec, formatted after
+ // syscall execution.
+ PostItimerSpec
+
+ // Timeval is a pointer to a struct timeval, formatted before and after
+ // syscall execution.
+ Timeval
+
+ // Utimbuf is a pointer to a struct utimbuf.
+ Utimbuf
+
+ // Rusage is a struct rusage, formatted after syscall execution.
+ Rusage
+
+ // CloneFlags are clone(2) flags.
+ CloneFlags
+
+ // OpenFlags are open(2) flags.
+ OpenFlags
+
+ // Mode is a mode_t.
+ Mode
+
+ // FutexOp is the futex(2) operation.
+ FutexOp
+
+ // PtraceRequest is the ptrace(2) request.
+ PtraceRequest
+
+ // ItimerType is an itimer type (ITIMER_REAL, etc).
+ ItimerType
+
+ // Signal is a signal number.
+ Signal
+
+ // SignalMaskAction is a signal mask action passed to rt_sigprocmask(2).
+ SignalMaskAction
+
+ // SigSet is a signal set.
+ SigSet
+
+ // PostSigSet is a signal set, formatted after syscall execution.
+ PostSigSet
+
+ // SigAction is a struct sigaction.
+ SigAction
+
+ // PostSigAction is a struct sigaction, formatted after syscall execution.
+ PostSigAction
+
+ // CapHeader is a cap_user_header_t.
+ CapHeader
+
+ // CapData is the data argument to capget(2)/capset(2). The previous
+ // argument must be CapHeader.
+ CapData
+
+ // PostCapData is the data argument to capget(2)/capset(2), formatted
+ // after syscall execution. The previous argument must be CapHeader.
+ PostCapData
+
+ // PollFDs is an array of struct pollfd. The number of entries in the
+ // array is in the next argument.
+ PollFDs
+)
+
+// defaultFormat is the syscall argument format to use if the actual format is
+// not known. It formats all six arguments as hex.
+var defaultFormat = []FormatSpecifier{Hex, Hex, Hex, Hex, Hex, Hex}
+
+// SyscallInfo captures the name and printing format of a syscall.
+type SyscallInfo struct {
+ // name is the name of the syscall.
+ name string
+
+ // format contains the format specifiers for each argument.
+ //
+ // Syscall calls can have up to six arguments. Arguments without a
+ // corresponding entry in format will not be printed.
+ format []FormatSpecifier
+}
+
+// makeSyscallInfo returns a SyscallInfo for a syscall.
+func makeSyscallInfo(name string, f ...FormatSpecifier) SyscallInfo {
+ return SyscallInfo{name: name, format: f}
+}
+
+// SyscallMap maps syscalls into names and printing formats.
+type SyscallMap map[uintptr]SyscallInfo
+
+var _ kernel.Stracer = (SyscallMap)(nil)
+
+// syscallTable contains the syscalls for a specific OS/Arch.
+type syscallTable struct {
+ // os is the operating system this table targets.
+ os abi.OS
+
+ // arch is the architecture this table targets.
+ arch arch.Arch
+
+ // syscalls contains the syscall mappings.
+ syscalls SyscallMap
+}
+
+// syscallTables contains all syscall tables.
+var syscallTables = []syscallTable{
+ {
+ os: abi.Linux,
+ arch: arch.AMD64,
+ syscalls: linuxAMD64,
+ },
+}
+
+// Lookup returns the SyscallMap for the OS/Arch combination. The returned map
+// must not be changed.
+func Lookup(os abi.OS, a arch.Arch) (SyscallMap, bool) {
+ for _, s := range syscallTables {
+ if s.os == os && s.arch == a {
+ return s.syscalls, true
+ }
+ }
+ return nil, false
+}
diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go
new file mode 100644
index 000000000..ec1eab331
--- /dev/null
+++ b/pkg/sentry/syscalls/epoll.go
@@ -0,0 +1,174 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syscalls
+
+import (
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/epoll"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// CreateEpoll implements the epoll_create(2) linux syscall.
+func CreateEpoll(t *kernel.Task, closeOnExec bool) (kdefs.FD, error) {
+ file := epoll.NewEventPoll(t)
+ defer file.DecRef()
+
+ flags := kernel.FDFlags{
+ CloseOnExec: closeOnExec,
+ }
+ fd, err := t.FDMap().NewFDFrom(0, file, flags, t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, err
+ }
+
+ return fd, nil
+}
+
+// AddEpoll implements the epoll_ctl(2) linux syscall when op is EPOLL_CTL_ADD.
+func AddEpoll(t *kernel.Task, epfd kdefs.FD, fd kdefs.FD, flags epoll.EntryFlags, mask waiter.EventMask, userData [2]int32) error {
+ // Get epoll from the file descriptor.
+ epollfile := t.FDMap().GetFile(epfd)
+ if epollfile == nil {
+ return syscall.EBADF
+ }
+ defer epollfile.DecRef()
+
+ // Get the target file id.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the epollPoll operations.
+ e, ok := epollfile.FileOperations.(*epoll.EventPoll)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ // Try to add the entry.
+ return e.AddEntry(epoll.FileIdentifier{file, fd}, flags, mask, userData)
+}
+
+// UpdateEpoll implements the epoll_ctl(2) linux syscall when op is EPOLL_CTL_MOD.
+func UpdateEpoll(t *kernel.Task, epfd kdefs.FD, fd kdefs.FD, flags epoll.EntryFlags, mask waiter.EventMask, userData [2]int32) error {
+ // Get epoll from the file descriptor.
+ epollfile := t.FDMap().GetFile(epfd)
+ if epollfile == nil {
+ return syscall.EBADF
+ }
+ defer epollfile.DecRef()
+
+ // Get the target file id.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the epollPoll operations.
+ e, ok := epollfile.FileOperations.(*epoll.EventPoll)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ // Try to update the entry.
+ return e.UpdateEntry(epoll.FileIdentifier{file, fd}, flags, mask, userData)
+}
+
+// RemoveEpoll implements the epoll_ctl(2) linux syscall when op is EPOLL_CTL_DEL.
+func RemoveEpoll(t *kernel.Task, epfd kdefs.FD, fd kdefs.FD) error {
+ // Get epoll from the file descriptor.
+ epollfile := t.FDMap().GetFile(epfd)
+ if epollfile == nil {
+ return syscall.EBADF
+ }
+ defer epollfile.DecRef()
+
+ // Get the target file id.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the epollPoll operations.
+ e, ok := epollfile.FileOperations.(*epoll.EventPoll)
+ if !ok {
+ return syscall.EBADF
+ }
+
+ // Try to remove the entry.
+ return e.RemoveEntry(epoll.FileIdentifier{file, fd})
+}
+
+// WaitEpoll implements the epoll_wait(2) linux syscall.
+func WaitEpoll(t *kernel.Task, fd kdefs.FD, max int, timeout int) ([]epoll.Event, error) {
+ // Get epoll from the file descriptor.
+ epollfile := t.FDMap().GetFile(fd)
+ if epollfile == nil {
+ return nil, syscall.EBADF
+ }
+ defer epollfile.DecRef()
+
+ // Extract the epollPoll operations.
+ e, ok := epollfile.FileOperations.(*epoll.EventPoll)
+ if !ok {
+ return nil, syscall.EBADF
+ }
+
+ // Try to read events and return right away if we got them or if the
+ // caller requested a non-blocking "wait".
+ r := e.ReadEvents(max)
+ if len(r) != 0 || timeout == 0 {
+ return r, nil
+ }
+
+ // We'll have to wait. Set up the timer if a timeout was specified and
+ // and register with the epoll object for readability events.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if timeout > 0 {
+ timeoutDur := time.Duration(timeout) * time.Millisecond
+ deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
+ haveDeadline = true
+ }
+
+ w, ch := waiter.NewChannelEntry(nil)
+ e.EventRegister(&w, waiter.EventIn)
+ defer e.EventUnregister(&w)
+
+ // Try to read the events again until we succeed, timeout or get
+ // interrupted.
+ for {
+ r = e.ReadEvents(max)
+ if len(r) != 0 {
+ return r, nil
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syscall.ETIMEDOUT {
+ return nil, nil
+ }
+
+ return nil, err
+ }
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
new file mode 100644
index 000000000..1ba3695fb
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -0,0 +1,114 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "io"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/metric"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+var (
+ partialResultMetric = metric.MustCreateNewUint64Metric("/syscalls/partial_result", true /* sync */, "Whether or not a partial result has occurred for this sandbox.")
+ partialResultOnce sync.Once
+)
+
+// handleIOError handles special error cases for partial results. For some
+// errors, we may consume the error and return only the partial read/write.
+//
+// op and f are used only for panics.
+func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op string, f *fs.File) error {
+ switch err {
+ case nil:
+ // Typical successful syscall.
+ return nil
+ case io.EOF:
+ // EOF is always consumed. If this is a partial read/write
+ // (result != 0), the application will see that, otherwise
+ // they will see 0.
+ return nil
+ case syserror.ErrExceedsFileSizeLimit:
+ // Ignore partialResult because this error only applies to
+ // normal files, and for those files we cannot accumulate
+ // write results.
+ //
+ // Do not consume the error and return it as EFBIG.
+ // Simultaneously send a SIGXFSZ per setrlimit(2).
+ t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
+ return syscall.EFBIG
+ case syserror.ErrInterrupted:
+ // The syscall was interrupted. Return nil if it completed
+ // partially, otherwise return the error code that the syscall
+ // needs (to indicate to the kernel what it should do).
+ if partialResult {
+ return nil
+ }
+ return intr
+ }
+
+ if !partialResult {
+ // Typical syscall error.
+ return err
+ }
+
+ switch err {
+ case syserror.EINTR:
+ // Syscall interrupted, but completed a partial
+ // read/write. Like ErrWouldBlock, since we have a
+ // partial read/write, we consume the error and return
+ // the partial result.
+ return nil
+ case syserror.EFAULT:
+ // EFAULT is only shown the user if nothing was
+ // read/written. If we read something (this case), they see
+ // a partial read/write. They will then presumably try again
+ // with an incremented buffer, which will EFAULT with
+ // result == 0.
+ return nil
+ case syserror.EPIPE:
+ // Writes to a pipe or socket will return EPIPE if the other
+ // side is gone. The partial write is returned. EPIPE will be
+ // returned on the next call.
+ //
+ // TODO(gvisor.dev/issue/161): In some cases SIGPIPE should
+ // also be sent to the application.
+ return nil
+ case syserror.ErrWouldBlock:
+ // Syscall would block, but completed a partial read/write.
+ // This case should only be returned by IssueIO for nonblocking
+ // files. Since we have a partial read/write, we consume
+ // ErrWouldBlock, returning the partial result.
+ return nil
+ }
+
+ switch err.(type) {
+ case kernel.SyscallRestartErrno:
+ // Identical to the EINTR case.
+ return nil
+ }
+
+ // An unknown error is encountered with a partial read/write.
+ name, _ := f.Dirent.FullName(nil /* ignore chroot */)
+ log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, err, err, op, name, f.FileOperations)
+ partialResultOnce.Do(partialResultMetric.Increment)
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/flags.go b/pkg/sentry/syscalls/linux/flags.go
new file mode 100644
index 000000000..d83e12971
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/flags.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+)
+
+// flagsToPermissions returns a Permissions object from Linux flags.
+// This includes truncate permission if O_TRUNC is set in the mask.
+func flagsToPermissions(mask uint) (p fs.PermMask) {
+ if mask&linux.O_TRUNC != 0 {
+ p.Write = true
+ }
+ switch mask & linux.O_ACCMODE {
+ case linux.O_WRONLY:
+ p.Write = true
+ case linux.O_RDWR:
+ p.Write = true
+ p.Read = true
+ case linux.O_RDONLY:
+ p.Read = true
+ }
+ return
+}
+
+// linuxToFlags converts Linux file flags to a FileFlags object.
+func linuxToFlags(mask uint) fs.FileFlags {
+ return fs.FileFlags{
+ Direct: mask&linux.O_DIRECT != 0,
+ Sync: mask&linux.O_SYNC != 0,
+ NonBlocking: mask&linux.O_NONBLOCK != 0,
+ Read: (mask & linux.O_ACCMODE) != linux.O_WRONLY,
+ Write: (mask & linux.O_ACCMODE) != linux.O_RDONLY,
+ Append: mask&linux.O_APPEND != 0,
+ Directory: mask&linux.O_DIRECTORY != 0,
+ Async: mask&linux.O_ASYNC != 0,
+ LargeFile: mask&linux.O_LARGEFILE != 0,
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
new file mode 100644
index 000000000..3e4d312af
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -0,0 +1,487 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package linux provides syscall tables for amd64 Linux.
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/syscalls"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// AUDIT_ARCH_X86_64 identifies the Linux syscall API on AMD64, and is taken
+// from <linux/audit.h>.
+const _AUDIT_ARCH_X86_64 = 0xc000003e
+
+// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall
+// numbers from Linux 4.4. The entries commented out are those syscalls we
+// don't currently support.
+//
+// Syscall support is documented as annotations in Go comments of the form:
+// @Syscall(<name>, <key:value>, ...)
+//
+// Supported args and values are:
+//
+// - arg: A syscall option. This entry only applies to the syscall when given
+// this option.
+// - support: Indicates support level
+// - UNIMPLEMENTED: Unimplemented (default, implies returns:ENOSYS)
+// - PARTIAL: Partial support. Details should be provided in note.
+// - FULL: Full support
+// - returns: Indicates a known return value. Values are syscall errors. This
+// is treated as a string so you can use something like
+// "returns:EPERM or ENOSYS".
+// - issue: A Github issue number.
+// - note: A note
+//
+// Example:
+// // @Syscall(mmap, arg:MAP_PRIVATE, support:FULL, note:Private memory fully supported)
+// // @Syscall(mmap, arg:MAP_SHARED, issue:123, note:Shared memory not supported)
+// // @Syscall(setxattr, returns:ENOTSUP, note:Requires file system support)
+//
+// Annotations should be placed as close to their implementation as possible
+// (preferrably as part of a supporting function's Godoc) and should be
+// updated as syscall support changes. Unimplemented syscalls are documented
+// here due to their lack of a supporting function or method.
+var AMD64 = &kernel.SyscallTable{
+ OS: abi.Linux,
+ Arch: arch.AMD64,
+ Version: kernel.Version{
+ // Version 4.4 is chosen as a stable, longterm version of Linux, which
+ // guides the interface provided by this syscall table. The build
+ // version is that for a clean build with default kernel config, at 5
+ // minutes after v4.4 was tagged.
+ Sysname: "Linux",
+ Release: "4.4",
+ Version: "#1 SMP Sun Jan 10 15:06:54 PST 2016",
+ },
+ AuditNumber: _AUDIT_ARCH_X86_64,
+ Table: map[uintptr]kernel.SyscallFn{
+ 0: Read,
+ 1: Write,
+ 2: Open,
+ 3: Close,
+ 4: Stat,
+ 5: Fstat,
+ 6: Lstat,
+ 7: Poll,
+ 8: Lseek,
+ 9: Mmap,
+ 10: Mprotect,
+ 11: Munmap,
+ 12: Brk,
+ 13: RtSigaction,
+ 14: RtSigprocmask,
+ 15: RtSigreturn,
+ 16: Ioctl,
+ 17: Pread64,
+ 18: Pwrite64,
+ 19: Readv,
+ 20: Writev,
+ 21: Access,
+ 22: Pipe,
+ 23: Select,
+ 24: SchedYield,
+ 25: Mremap,
+ 26: Msync,
+ 27: Mincore,
+ 28: Madvise,
+ 29: Shmget,
+ 30: Shmat,
+ 31: Shmctl,
+ 32: Dup,
+ 33: Dup2,
+ 34: Pause,
+ 35: Nanosleep,
+ 36: Getitimer,
+ 37: Alarm,
+ 38: Setitimer,
+ 39: Getpid,
+ 40: Sendfile,
+ 41: Socket,
+ 42: Connect,
+ 43: Accept,
+ 44: SendTo,
+ 45: RecvFrom,
+ 46: SendMsg,
+ 47: RecvMsg,
+ 48: Shutdown,
+ 49: Bind,
+ 50: Listen,
+ 51: GetSockName,
+ 52: GetPeerName,
+ 53: SocketPair,
+ 54: SetSockOpt,
+ 55: GetSockOpt,
+ 56: Clone,
+ 57: Fork,
+ 58: Vfork,
+ 59: Execve,
+ 60: Exit,
+ 61: Wait4,
+ 62: Kill,
+ 63: Uname,
+ 64: Semget,
+ 65: Semop,
+ 66: Semctl,
+ 67: Shmdt,
+ // 68: @Syscall(Msgget), TODO(b/29354921)
+ // 69: @Syscall(Msgsnd), TODO(b/29354921)
+ // 70: @Syscall(Msgrcv), TODO(b/29354921)
+ // 71: @Syscall(Msgctl), TODO(b/29354921)
+ 72: Fcntl,
+ 73: Flock,
+ 74: Fsync,
+ 75: Fdatasync,
+ 76: Truncate,
+ 77: Ftruncate,
+ 78: Getdents,
+ 79: Getcwd,
+ 80: Chdir,
+ 81: Fchdir,
+ 82: Rename,
+ 83: Mkdir,
+ 84: Rmdir,
+ 85: Creat,
+ 86: Link,
+ 87: Unlink,
+ 88: Symlink,
+ 89: Readlink,
+ 90: Chmod,
+ 91: Fchmod,
+ 92: Chown,
+ 93: Fchown,
+ 94: Lchown,
+ 95: Umask,
+ 96: Gettimeofday,
+ 97: Getrlimit,
+ 98: Getrusage,
+ 99: Sysinfo,
+ 100: Times,
+ 101: Ptrace,
+ 102: Getuid,
+ 103: Syslog,
+ 104: Getgid,
+ 105: Setuid,
+ 106: Setgid,
+ 107: Geteuid,
+ 108: Getegid,
+ 109: Setpgid,
+ 110: Getppid,
+ 111: Getpgrp,
+ 112: Setsid,
+ 113: Setreuid,
+ 114: Setregid,
+ 115: Getgroups,
+ 116: Setgroups,
+ 117: Setresuid,
+ 118: Getresuid,
+ 119: Setresgid,
+ 120: Getresgid,
+ 121: Getpgid,
+ // 122: @Syscall(Setfsuid), TODO(b/112851702)
+ // 123: @Syscall(Setfsgid), TODO(b/112851702)
+ 124: Getsid,
+ 125: Capget,
+ 126: Capset,
+ 127: RtSigpending,
+ 128: RtSigtimedwait,
+ 129: RtSigqueueinfo,
+ 130: RtSigsuspend,
+ 131: Sigaltstack,
+ 132: Utime,
+ 133: Mknod,
+ // @Syscall(Uselib, note:Obsolete)
+ 134: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(SetPersonality, returns:EINVAL, note:Unable to change personality)
+ 135: syscalls.ErrorWithEvent(syscall.EINVAL),
+ // @Syscall(Ustat, note:Needs filesystem support)
+ 136: syscalls.ErrorWithEvent(syscall.ENOSYS),
+ 137: Statfs,
+ 138: Fstatfs,
+ // 139: @Syscall(Sysfs), TODO(gvisor.dev/issue/165)
+ 140: Getpriority,
+ 141: Setpriority,
+ // @Syscall(SchedSetparam, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise)
+ 142: syscalls.CapError(linux.CAP_SYS_NICE), // requires cap_sys_nice
+ 143: SchedGetparam,
+ 144: SchedSetscheduler,
+ 145: SchedGetscheduler,
+ 146: SchedGetPriorityMax,
+ 147: SchedGetPriorityMin,
+ // @Syscall(SchedRrGetInterval, returns:EPERM)
+ 148: syscalls.ErrorWithEvent(syscall.EPERM),
+ 149: Mlock,
+ 150: Munlock,
+ 151: Mlockall,
+ 152: Munlockall,
+ // @Syscall(Vhangup, returns:EPERM)
+ 153: syscalls.CapError(linux.CAP_SYS_TTY_CONFIG),
+ // @Syscall(ModifyLdt, returns:EPERM)
+ 154: syscalls.Error(syscall.EPERM),
+ // @Syscall(PivotRoot, returns:EPERM)
+ 155: syscalls.Error(syscall.EPERM),
+ // @Syscall(Sysctl, returns:EPERM)
+ 156: syscalls.Error(syscall.EPERM), // syscall is "worthless"
+ 157: Prctl,
+ 158: ArchPrctl,
+ // @Syscall(Adjtimex, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_time; ENOSYS otherwise)
+ 159: syscalls.CapError(linux.CAP_SYS_TIME), // requires cap_sys_time
+ 160: Setrlimit,
+ 161: Chroot,
+ 162: Sync,
+ // @Syscall(Acct, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_pacct; ENOSYS otherwise)
+ 163: syscalls.CapError(linux.CAP_SYS_PACCT), // requires cap_sys_pacct
+ // @Syscall(Settimeofday, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_time; ENOSYS otherwise)
+ 164: syscalls.CapError(linux.CAP_SYS_TIME), // requires cap_sys_time
+ 165: Mount,
+ 166: Umount2,
+ // @Syscall(Swapon, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
+ 167: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin
+ // @Syscall(Swapoff, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
+ 168: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin
+ // @Syscall(Reboot, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_boot; ENOSYS otherwise)
+ 169: syscalls.CapError(linux.CAP_SYS_BOOT), // requires cap_sys_boot
+ 170: Sethostname,
+ 171: Setdomainname,
+ // @Syscall(Iopl, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_rawio; ENOSYS otherwise)
+ 172: syscalls.CapError(linux.CAP_SYS_RAWIO), // requires cap_sys_rawio
+ // @Syscall(Ioperm, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_rawio; ENOSYS otherwise)
+ 173: syscalls.CapError(linux.CAP_SYS_RAWIO), // requires cap_sys_rawio
+ // @Syscall(CreateModule, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
+ 174: syscalls.CapError(linux.CAP_SYS_MODULE), // CreateModule, requires cap_sys_module
+ // @Syscall(InitModule, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
+ 175: syscalls.CapError(linux.CAP_SYS_MODULE), // requires cap_sys_module
+ // @Syscall(DeleteModule, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
+ 176: syscalls.CapError(linux.CAP_SYS_MODULE), // requires cap_sys_module
+ // @Syscall(GetKernelSyms, note:Not supported in > 2.6)
+ 177: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(QueryModule, note:Not supported in > 2.6)
+ 178: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(Quotactl, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
+ 179: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin (most operations)
+ // @Syscall(Nfsservctl, note:Does not exist > 3.1)
+ 180: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(Getpmsg, note:Not implemented in Linux)
+ 181: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(Putpmsg, note:Not implemented in Linux)
+ 182: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(AfsSyscall, note:Not implemented in Linux)
+ 183: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(Tuxcall, note:Not implemented in Linux)
+ 184: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(Security, note:Not implemented in Linux)
+ 185: syscalls.Error(syscall.ENOSYS),
+ 186: Gettid,
+ 187: nil, // @Syscall(Readahead), TODO(b/29351341)
+ // @Syscall(Setxattr, returns:ENOTSUP, note:Requires filesystem support)
+ 188: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Lsetxattr, returns:ENOTSUP, note:Requires filesystem support)
+ 189: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Fsetxattr, returns:ENOTSUP, note:Requires filesystem support)
+ 190: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Getxattr, returns:ENOTSUP, note:Requires filesystem support)
+ 191: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Lgetxattr, returns:ENOTSUP, note:Requires filesystem support)
+ 192: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Fgetxattr, returns:ENOTSUP, note:Requires filesystem support)
+ 193: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Listxattr, returns:ENOTSUP, note:Requires filesystem support)
+ 194: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Llistxattr, returns:ENOTSUP, note:Requires filesystem support)
+ 195: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Flistxattr, returns:ENOTSUP, note:Requires filesystem support)
+ 196: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Removexattr, returns:ENOTSUP, note:Requires filesystem support)
+ 197: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Lremovexattr, returns:ENOTSUP, note:Requires filesystem support)
+ 198: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ // @Syscall(Fremovexattr, returns:ENOTSUP, note:Requires filesystem support)
+ 199: syscalls.ErrorWithEvent(syscall.ENOTSUP),
+ 200: Tkill,
+ 201: Time,
+ 202: Futex,
+ 203: SchedSetaffinity,
+ 204: SchedGetaffinity,
+ // @Syscall(SetThreadArea, note:Expected to return ENOSYS on 64-bit)
+ 205: syscalls.Error(syscall.ENOSYS),
+ 206: IoSetup,
+ 207: IoDestroy,
+ 208: IoGetevents,
+ 209: IoSubmit,
+ 210: IoCancel,
+ // @Syscall(GetThreadArea, note:Expected to return ENOSYS on 64-bit)
+ 211: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(LookupDcookie, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
+ 212: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin
+ 213: EpollCreate,
+ // @Syscall(EpollCtlOld, note:Deprecated)
+ 214: syscalls.ErrorWithEvent(syscall.ENOSYS), // deprecated (afaik, unused)
+ // @Syscall(EpollWaitOld, note:Deprecated)
+ 215: syscalls.ErrorWithEvent(syscall.ENOSYS), // deprecated (afaik, unused)
+ // @Syscall(RemapFilePages, note:Deprecated)
+ 216: syscalls.ErrorWithEvent(syscall.ENOSYS), // deprecated since 3.16
+ 217: Getdents64,
+ 218: SetTidAddress,
+ 219: RestartSyscall,
+ // 220: @Syscall(Semtimedop), TODO(b/29354920)
+ 221: Fadvise64,
+ 222: TimerCreate,
+ 223: TimerSettime,
+ 224: TimerGettime,
+ 225: TimerGetoverrun,
+ 226: TimerDelete,
+ 227: ClockSettime,
+ 228: ClockGettime,
+ 229: ClockGetres,
+ 230: ClockNanosleep,
+ 231: ExitGroup,
+ 232: EpollWait,
+ 233: EpollCtl,
+ 234: Tgkill,
+ 235: Utimes,
+ // @Syscall(Vserver, note:Not implemented by Linux)
+ 236: syscalls.Error(syscall.ENOSYS), // Vserver, not implemented by Linux
+ // @Syscall(Mbind, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise), TODO(b/117792295)
+ 237: syscalls.CapError(linux.CAP_SYS_NICE), // may require cap_sys_nice
+ 238: SetMempolicy,
+ 239: GetMempolicy,
+ // 240: @Syscall(MqOpen), TODO(b/29354921)
+ // 241: @Syscall(MqUnlink), TODO(b/29354921)
+ // 242: @Syscall(MqTimedsend), TODO(b/29354921)
+ // 243: @Syscall(MqTimedreceive), TODO(b/29354921)
+ // 244: @Syscall(MqNotify), TODO(b/29354921)
+ // 245: @Syscall(MqGetsetattr), TODO(b/29354921)
+ 246: syscalls.CapError(linux.CAP_SYS_BOOT), // kexec_load, requires cap_sys_boot
+ 247: Waitid,
+ // @Syscall(AddKey, returns:EACCES, note:Not available to user)
+ 248: syscalls.Error(syscall.EACCES),
+ // @Syscall(RequestKey, returns:EACCES, note:Not available to user)
+ 249: syscalls.Error(syscall.EACCES),
+ // @Syscall(Keyctl, returns:EACCES, note:Not available to user)
+ 250: syscalls.Error(syscall.EACCES),
+ // @Syscall(IoprioSet, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
+ 251: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_nice or cap_sys_admin (depending)
+ // @Syscall(IoprioGet, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
+ 252: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_nice or cap_sys_admin (depending)
+ 253: InotifyInit,
+ 254: InotifyAddWatch,
+ 255: InotifyRmWatch,
+ // @Syscall(MigratePages, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise)
+ 256: syscalls.CapError(linux.CAP_SYS_NICE),
+ 257: Openat,
+ 258: Mkdirat,
+ 259: Mknodat,
+ 260: Fchownat,
+ 261: Futimesat,
+ 262: Fstatat,
+ 263: Unlinkat,
+ 264: Renameat,
+ 265: Linkat,
+ 266: Symlinkat,
+ 267: Readlinkat,
+ 268: Fchmodat,
+ 269: Faccessat,
+ 270: Pselect,
+ 271: Ppoll,
+ 272: Unshare,
+ // @Syscall(SetRobustList, note:Obsolete)
+ 273: syscalls.Error(syscall.ENOSYS),
+ // @Syscall(GetRobustList, note:Obsolete)
+ 274: syscalls.Error(syscall.ENOSYS),
+ 275: Splice,
+ // 276: @Syscall(Tee), TODO(b/29354098)
+ 277: SyncFileRange,
+ // 278: @Syscall(Vmsplice), TODO(b/29354098)
+ // @Syscall(MovePages, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise)
+ 279: syscalls.CapError(linux.CAP_SYS_NICE), // requires cap_sys_nice (mostly)
+ 280: Utimensat,
+ 281: EpollPwait,
+ // 282: @Syscall(Signalfd), TODO(b/19846426)
+ 283: TimerfdCreate,
+ 284: Eventfd,
+ 285: Fallocate,
+ 286: TimerfdSettime,
+ 287: TimerfdGettime,
+ 288: Accept4,
+ // 289: @Syscall(Signalfd4), TODO(b/19846426)
+ 290: Eventfd2,
+ 291: EpollCreate1,
+ 292: Dup3,
+ 293: Pipe2,
+ 294: InotifyInit1,
+ 295: Preadv,
+ 296: Pwritev,
+ 297: RtTgsigqueueinfo,
+ // @Syscall(PerfEventOpen, returns:ENODEV, note:No support for perf counters)
+ 298: syscalls.ErrorWithEvent(syscall.ENODEV),
+ 299: RecvMMsg,
+ // @Syscall(FanotifyInit, note:Needs CONFIG_FANOTIFY)
+ 300: syscalls.ErrorWithEvent(syscall.ENOSYS),
+ // @Syscall(FanotifyMark, note:Needs CONFIG_FANOTIFY)
+ 301: syscalls.ErrorWithEvent(syscall.ENOSYS),
+ 302: Prlimit64,
+ // @Syscall(NameToHandleAt, returns:EOPNOTSUPP, note:Needs filesystem support)
+ 303: syscalls.ErrorWithEvent(syscall.EOPNOTSUPP),
+ // @Syscall(OpenByHandleAt, returns:EOPNOTSUPP, note:Needs filesystem support)
+ 304: syscalls.ErrorWithEvent(syscall.EOPNOTSUPP),
+ // @Syscall(ClockAdjtime, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
+ 305: syscalls.CapError(linux.CAP_SYS_TIME), // requires cap_sys_time
+ 306: Syncfs,
+ 307: SendMMsg,
+ // 308: @Syscall(Setns), TODO(b/29354995)
+ 309: Getcpu,
+ // 310: @Syscall(ProcessVmReadv), TODO(gvisor.dev/issue/158) may require cap_sys_ptrace
+ // 311: @Syscall(ProcessVmWritev), TODO(gvisor.dev/issue/158) may require cap_sys_ptrace
+ // @Syscall(Kcmp, returns:EPERM or ENOSYS, note:Requires cap_sys_ptrace)
+ 312: syscalls.CapError(linux.CAP_SYS_PTRACE),
+ // @Syscall(FinitModule, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
+ 313: syscalls.CapError(linux.CAP_SYS_MODULE),
+ // 314: @Syscall(SchedSetattr), TODO(b/118902272), we have no scheduler
+ // 315: @Syscall(SchedGetattr), TODO(b/118902272), we have no scheduler
+ // 316: @Syscall(Renameat2), TODO(b/118902772)
+ 317: Seccomp,
+ 318: GetRandom,
+ 319: MemfdCreate,
+ // @Syscall(KexecFileLoad, EPERM or ENOSYS, note:Infeasible to support. Returns EPERM if the process does not have cap_sys_boot; ENOSYS otherwise)
+ 320: syscalls.CapError(linux.CAP_SYS_BOOT),
+ // @Syscall(Bpf, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_boot; ENOSYS otherwise)
+ 321: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin for all commands
+ // 322: @Syscall(Execveat), TODO(b/118901836)
+ // 323: @Syscall(Userfaultfd), TODO(b/118906345)
+ // 324: @Syscall(Membarrier), TODO(b/118904897)
+ 325: Mlock2,
+ // Syscalls after 325 are "backports" from versions of Linux after 4.4.
+ // 326: @Syscall(CopyFileRange),
+ 327: Preadv2,
+ 328: Pwritev2,
+ },
+
+ Emulate: map[usermem.Addr]uintptr{
+ 0xffffffffff600000: 96, // vsyscall gettimeofday(2)
+ 0xffffffffff600400: 201, // vsyscall time(2)
+ 0xffffffffff600800: 309, // vsyscall getcpu(2)
+ },
+ Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, syserror.ENOSYS
+ },
+}
diff --git a/pkg/sentry/syscalls/linux/linux_state_autogen.go b/pkg/sentry/syscalls/linux/linux_state_autogen.go
new file mode 100755
index 000000000..0a747952b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/linux_state_autogen.go
@@ -0,0 +1,80 @@
+// automatically generated by stateify.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *ioEvent) beforeSave() {}
+func (x *ioEvent) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Data", &x.Data)
+ m.Save("Obj", &x.Obj)
+ m.Save("Result", &x.Result)
+ m.Save("Result2", &x.Result2)
+}
+
+func (x *ioEvent) afterLoad() {}
+func (x *ioEvent) load(m state.Map) {
+ m.Load("Data", &x.Data)
+ m.Load("Obj", &x.Obj)
+ m.Load("Result", &x.Result)
+ m.Load("Result2", &x.Result2)
+}
+
+func (x *futexWaitRestartBlock) beforeSave() {}
+func (x *futexWaitRestartBlock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("duration", &x.duration)
+ m.Save("addr", &x.addr)
+ m.Save("private", &x.private)
+ m.Save("val", &x.val)
+ m.Save("mask", &x.mask)
+}
+
+func (x *futexWaitRestartBlock) afterLoad() {}
+func (x *futexWaitRestartBlock) load(m state.Map) {
+ m.Load("duration", &x.duration)
+ m.Load("addr", &x.addr)
+ m.Load("private", &x.private)
+ m.Load("val", &x.val)
+ m.Load("mask", &x.mask)
+}
+
+func (x *pollRestartBlock) beforeSave() {}
+func (x *pollRestartBlock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("pfdAddr", &x.pfdAddr)
+ m.Save("nfds", &x.nfds)
+ m.Save("timeout", &x.timeout)
+}
+
+func (x *pollRestartBlock) afterLoad() {}
+func (x *pollRestartBlock) load(m state.Map) {
+ m.Load("pfdAddr", &x.pfdAddr)
+ m.Load("nfds", &x.nfds)
+ m.Load("timeout", &x.timeout)
+}
+
+func (x *clockNanosleepRestartBlock) beforeSave() {}
+func (x *clockNanosleepRestartBlock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("c", &x.c)
+ m.Save("duration", &x.duration)
+ m.Save("rem", &x.rem)
+}
+
+func (x *clockNanosleepRestartBlock) afterLoad() {}
+func (x *clockNanosleepRestartBlock) load(m state.Map) {
+ m.Load("c", &x.c)
+ m.Load("duration", &x.duration)
+ m.Load("rem", &x.rem)
+}
+
+func init() {
+ state.Register("linux.ioEvent", (*ioEvent)(nil), state.Fns{Save: (*ioEvent).save, Load: (*ioEvent).load})
+ state.Register("linux.futexWaitRestartBlock", (*futexWaitRestartBlock)(nil), state.Fns{Save: (*futexWaitRestartBlock).save, Load: (*futexWaitRestartBlock).load})
+ state.Register("linux.pollRestartBlock", (*pollRestartBlock)(nil), state.Fns{Save: (*pollRestartBlock).save, Load: (*pollRestartBlock).load})
+ state.Register("linux.clockNanosleepRestartBlock", (*clockNanosleepRestartBlock)(nil), state.Fns{Save: (*clockNanosleepRestartBlock).save, Load: (*clockNanosleepRestartBlock).load})
+}
diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go
new file mode 100644
index 000000000..5438b664b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sigset.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// copyInSigSet copies in a sigset_t, checks its size, and ensures that KILL and
+// STOP are clear.
+func copyInSigSet(t *kernel.Task, sigSetAddr usermem.Addr, size uint) (linux.SignalSet, error) {
+ if size != linux.SignalSetSize {
+ return 0, syscall.EINVAL
+ }
+ b := t.CopyScratchBuffer(8)
+ if _, err := t.CopyInBytes(sigSetAddr, b); err != nil {
+ return 0, err
+ }
+ mask := usermem.ByteOrder.Uint64(b[:])
+ return linux.SignalSet(mask) &^ kernel.UnblockableSignals, nil
+}
+
+// copyOutSigSet copies out a sigset_t.
+func copyOutSigSet(t *kernel.Task, sigSetAddr usermem.Addr, mask linux.SignalSet) error {
+ b := t.CopyScratchBuffer(8)
+ usermem.ByteOrder.PutUint64(b, uint64(mask))
+ _, err := t.CopyOutBytes(sigSetAddr, b)
+ return err
+}
+
+// copyInSigSetWithSize copies in a structure as below
+//
+// struct {
+// sigset_t* sigset_addr;
+// size_t sizeof_sigset;
+// };
+//
+// and returns sigset_addr and size.
+func copyInSigSetWithSize(t *kernel.Task, addr usermem.Addr) (usermem.Addr, uint, error) {
+ switch t.Arch().Width() {
+ case 8:
+ in := t.CopyScratchBuffer(16)
+ if _, err := t.CopyInBytes(addr, in); err != nil {
+ return 0, 0, err
+ }
+ maskAddr := usermem.Addr(usermem.ByteOrder.Uint64(in[0:]))
+ maskSize := uint(usermem.ByteOrder.Uint64(in[8:]))
+ return maskAddr, maskSize, nil
+ default:
+ return 0, 0, syserror.ENOSYS
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
new file mode 100644
index 000000000..1b27b2415
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -0,0 +1,416 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/eventfd"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// I/O commands.
+const (
+ _IOCB_CMD_PREAD = 0
+ _IOCB_CMD_PWRITE = 1
+ _IOCB_CMD_FSYNC = 2
+ _IOCB_CMD_FDSYNC = 3
+ _IOCB_CMD_NOOP = 6
+ _IOCB_CMD_PREADV = 7
+ _IOCB_CMD_PWRITEV = 8
+)
+
+// I/O flags.
+const (
+ _IOCB_FLAG_RESFD = 1
+)
+
+// ioCallback describes an I/O request.
+//
+// The priority field is currently ignored in the implementation below. Also
+// note that the IOCB_FLAG_RESFD feature is not supported.
+type ioCallback struct {
+ Data uint64
+ Key uint32
+ Reserved1 uint32
+
+ OpCode uint16
+ ReqPrio int16
+ FD uint32
+
+ Buf uint64
+ Bytes uint64
+ Offset int64
+
+ Reserved2 uint64
+ Flags uint32
+
+ // eventfd to signal if IOCB_FLAG_RESFD is set in flags.
+ ResFD uint32
+}
+
+// ioEvent describes an I/O result.
+//
+// +stateify savable
+type ioEvent struct {
+ Data uint64
+ Obj uint64
+ Result int64
+ Result2 int64
+}
+
+// ioEventSize is the size of an ioEvent encoded.
+var ioEventSize = binary.Size(ioEvent{})
+
+// IoSetup implements linux syscall io_setup(2).
+func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nrEvents := args[0].Int()
+ idAddr := args[1].Pointer()
+
+ // Linux uses the native long as the aio ID.
+ //
+ // The context pointer _must_ be zero initially.
+ var idIn uint64
+ if _, err := t.CopyIn(idAddr, &idIn); err != nil {
+ return 0, nil, err
+ }
+ if idIn != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ id, err := t.MemoryManager().NewAIOContext(t, uint32(nrEvents))
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Copy out the new ID.
+ if _, err := t.CopyOut(idAddr, &id); err != nil {
+ t.MemoryManager().DestroyAIOContext(t, id)
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// IoDestroy implements linux syscall io_destroy(2).
+func IoDestroy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Uint64()
+
+ // Destroy the given context.
+ if !t.MemoryManager().DestroyAIOContext(t, id) {
+ // Does not exist.
+ return 0, nil, syserror.EINVAL
+ }
+ // FIXME(fvoznika): Linux blocks until all AIO to the destroyed context is
+ // done.
+ return 0, nil, nil
+}
+
+// IoGetevents implements linux syscall io_getevents(2).
+func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Uint64()
+ minEvents := args[1].Int()
+ events := args[2].Int()
+ eventsAddr := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+
+ // Sanity check arguments.
+ if minEvents < 0 || minEvents > events {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ctx, ok := t.MemoryManager().LookupAIOContext(t, id)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Setup the timeout.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if timespecAddr != 0 {
+ d, err := copyTimespecIn(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if !d.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+ deadline = t.Kernel().MonotonicClock().Now().Add(d.ToDuration())
+ haveDeadline = true
+ }
+
+ // Loop over all requests.
+ for count := int32(0); count < events; count++ {
+ // Get a request, per semantics.
+ var v interface{}
+ if count >= minEvents {
+ var ok bool
+ v, ok = ctx.PopRequest()
+ if !ok {
+ return uintptr(count), nil, nil
+ }
+ } else {
+ var err error
+ v, err = waitForRequest(ctx, t, haveDeadline, deadline)
+ if err != nil {
+ if count > 0 || err == syserror.ETIMEDOUT {
+ return uintptr(count), nil, nil
+ }
+ return 0, nil, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+ }
+
+ ev := v.(*ioEvent)
+
+ // Copy out the result.
+ if _, err := t.CopyOut(eventsAddr, ev); err != nil {
+ if count > 0 {
+ return uintptr(count), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Keep rolling.
+ eventsAddr += usermem.Addr(ioEventSize)
+ }
+
+ // Everything finished.
+ return uintptr(events), nil, nil
+}
+
+func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadline ktime.Time) (interface{}, error) {
+ for {
+ if v, ok := ctx.PopRequest(); ok {
+ // Request was readly available. Just return it.
+ return v, nil
+ }
+
+ // Need to wait for request completion.
+ done, active := ctx.WaitChannel()
+ if !active {
+ // Context has been destroyed.
+ return nil, syserror.EINVAL
+ }
+ if err := t.BlockWithDeadline(done, haveDeadline, deadline); err != nil {
+ return nil, err
+ }
+ }
+}
+
+// memoryFor returns appropriate memory for the given callback.
+func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
+ bytes := int(cb.Bytes)
+ if bytes < 0 {
+ // Linux also requires that this field fit in ssize_t.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+
+ // Since this I/O will be asynchronous with respect to t's task goroutine,
+ // we have no guarantee that t's AddressSpace will be active during the
+ // I/O.
+ switch cb.OpCode {
+ case _IOCB_CMD_PREAD, _IOCB_CMD_PWRITE:
+ return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case _IOCB_CMD_PREADV, _IOCB_CMD_PWRITEV:
+ return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case _IOCB_CMD_FSYNC, _IOCB_CMD_FDSYNC, _IOCB_CMD_NOOP:
+ return usermem.IOSequence{}, nil
+
+ default:
+ // Not a supported command.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+}
+
+func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioCallback, ioseq usermem.IOSequence, ctx *mm.AIOContext, eventFile *fs.File) {
+ ev := &ioEvent{
+ Data: cb.Data,
+ Obj: uint64(cbAddr),
+ }
+
+ // Construct a context.Context that will not be interrupted if t is
+ // interrupted.
+ c := t.AsyncContext()
+
+ var err error
+ switch cb.OpCode {
+ case _IOCB_CMD_PREAD, _IOCB_CMD_PREADV:
+ ev.Result, err = file.Preadv(c, ioseq, cb.Offset)
+ case _IOCB_CMD_PWRITE, _IOCB_CMD_PWRITEV:
+ ev.Result, err = file.Pwritev(c, ioseq, cb.Offset)
+ case _IOCB_CMD_FSYNC:
+ err = file.Fsync(c, 0, fs.FileMaxOffset, fs.SyncAll)
+ case _IOCB_CMD_FDSYNC:
+ err = file.Fsync(c, 0, fs.FileMaxOffset, fs.SyncData)
+ }
+
+ // Update the result.
+ if err != nil {
+ err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file)
+ ev.Result = -int64(t.ExtractErrno(err, 0))
+ }
+
+ file.DecRef()
+
+ // Queue the result for delivery.
+ ctx.FinishRequest(ev)
+
+ // Notify the event file if one was specified. This needs to happen
+ // *after* queueing the result to avoid racing with the thread we may
+ // wake up.
+ if eventFile != nil {
+ eventFile.FileOperations.(*eventfd.EventOperations).Signal(1)
+ eventFile.DecRef()
+ }
+}
+
+// submitCallback processes a single callback.
+func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Addr) error {
+ file := t.FDMap().GetFile(kdefs.FD(cb.FD))
+ if file == nil {
+ // File not found.
+ return syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Was there an eventFD? Extract it.
+ var eventFile *fs.File
+ if cb.Flags&_IOCB_FLAG_RESFD != 0 {
+ eventFile = t.FDMap().GetFile(kdefs.FD(cb.ResFD))
+ if eventFile == nil {
+ // Bad FD.
+ return syserror.EBADF
+ }
+ defer eventFile.DecRef()
+
+ // Check that it is an eventfd.
+ if _, ok := eventFile.FileOperations.(*eventfd.EventOperations); !ok {
+ // Not an event FD.
+ return syserror.EINVAL
+ }
+ }
+
+ ioseq, err := memoryFor(t, cb)
+ if err != nil {
+ return err
+ }
+
+ // Check offset for reads/writes.
+ switch cb.OpCode {
+ case _IOCB_CMD_PREAD, _IOCB_CMD_PREADV, _IOCB_CMD_PWRITE, _IOCB_CMD_PWRITEV:
+ if cb.Offset < 0 {
+ return syserror.EINVAL
+ }
+ }
+
+ // Prepare the request.
+ ctx, ok := t.MemoryManager().LookupAIOContext(t, id)
+ if !ok {
+ return syserror.EINVAL
+ }
+ if ready := ctx.Prepare(); !ready {
+ // Context is busy.
+ return syserror.EAGAIN
+ }
+
+ if eventFile != nil {
+ // The request is set. Make sure there's a ref on the file.
+ //
+ // This is necessary when the callback executes on completion,
+ // which is also what will release this reference.
+ eventFile.IncRef()
+ }
+
+ // Perform the request asynchronously.
+ file.IncRef()
+ fs.Async(func() { performCallback(t, file, cbAddr, cb, ioseq, ctx, eventFile) })
+
+ // All set.
+ return nil
+}
+
+// IoSubmit implements linux syscall io_submit(2).
+func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Uint64()
+ nrEvents := args[1].Int()
+ addr := args[2].Pointer()
+
+ if nrEvents < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ for i := int32(0); i < nrEvents; i++ {
+ // Copy in the address.
+ cbAddrNative := t.Arch().Native(0)
+ if _, err := t.CopyIn(addr, cbAddrNative); err != nil {
+ if i > 0 {
+ // Some successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Copy in this callback.
+ var cb ioCallback
+ cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
+ if _, err := t.CopyIn(cbAddr, &cb); err != nil {
+
+ if i > 0 {
+ // Some have been successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Process this callback.
+ if err := submitCallback(t, id, &cb, cbAddr); err != nil {
+ if i > 0 {
+ // Partial success.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Advance to the next one.
+ addr += usermem.Addr(t.Arch().Width())
+ }
+
+ return uintptr(nrEvents), nil, nil
+}
+
+// IoCancel implements linux syscall io_cancel(2).
+//
+// It is not presently supported (ENOSYS indicates no support on this
+// architecture).
+func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.ENOSYS
+}
diff --git a/pkg/sentry/syscalls/linux/sys_capability.go b/pkg/sentry/syscalls/linux/sys_capability.go
new file mode 100644
index 000000000..622cb8d0d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_capability.go
@@ -0,0 +1,149 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+func lookupCaps(t *kernel.Task, tid kernel.ThreadID) (permitted, inheritable, effective auth.CapabilitySet, err error) {
+ if tid < 0 {
+ err = syserror.EINVAL
+ return
+ }
+ if tid > 0 {
+ t = t.PIDNamespace().TaskWithID(tid)
+ }
+ if t == nil {
+ err = syserror.ESRCH
+ return
+ }
+ creds := t.Credentials()
+ permitted, inheritable, effective = creds.PermittedCaps, creds.InheritableCaps, creds.EffectiveCaps
+ return
+}
+
+// Capget implements Linux syscall capget.
+func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ hdrAddr := args[0].Pointer()
+ dataAddr := args[1].Pointer()
+
+ var hdr linux.CapUserHeader
+ if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ return 0, nil, err
+ }
+ // hdr.Pid doesn't need to be valid if this capget() is a "version probe"
+ // (hdr.Version is unrecognized and dataAddr is null), so we can't do the
+ // lookup yet.
+ switch hdr.Version {
+ case linux.LINUX_CAPABILITY_VERSION_1:
+ if dataAddr == 0 {
+ return 0, nil, nil
+ }
+ p, i, e, err := lookupCaps(t, kernel.ThreadID(hdr.Pid))
+ if err != nil {
+ return 0, nil, err
+ }
+ data := linux.CapUserData{
+ Effective: uint32(e),
+ Permitted: uint32(p),
+ Inheritable: uint32(i),
+ }
+ _, err = t.CopyOut(dataAddr, &data)
+ return 0, nil, err
+
+ case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
+ if dataAddr == 0 {
+ return 0, nil, nil
+ }
+ p, i, e, err := lookupCaps(t, kernel.ThreadID(hdr.Pid))
+ if err != nil {
+ return 0, nil, err
+ }
+ data := [2]linux.CapUserData{
+ {
+ Effective: uint32(e),
+ Permitted: uint32(p),
+ Inheritable: uint32(i),
+ },
+ {
+ Effective: uint32(e >> 32),
+ Permitted: uint32(p >> 32),
+ Inheritable: uint32(i >> 32),
+ },
+ }
+ _, err = t.CopyOut(dataAddr, &data)
+ return 0, nil, err
+
+ default:
+ hdr.Version = linux.HighestCapabilityVersion
+ if _, err := t.CopyOut(hdrAddr, &hdr); err != nil {
+ return 0, nil, err
+ }
+ if dataAddr != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ return 0, nil, nil
+ }
+}
+
+// Capset implements Linux syscall capset.
+func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ hdrAddr := args[0].Pointer()
+ dataAddr := args[1].Pointer()
+
+ var hdr linux.CapUserHeader
+ if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ return 0, nil, err
+ }
+ switch hdr.Version {
+ case linux.LINUX_CAPABILITY_VERSION_1:
+ if tid := kernel.ThreadID(hdr.Pid); tid != 0 && tid != t.ThreadID() {
+ return 0, nil, syserror.EPERM
+ }
+ var data linux.CapUserData
+ if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ return 0, nil, err
+ }
+ p := auth.CapabilitySet(data.Permitted) & auth.AllCapabilities
+ i := auth.CapabilitySet(data.Inheritable) & auth.AllCapabilities
+ e := auth.CapabilitySet(data.Effective) & auth.AllCapabilities
+ return 0, nil, t.SetCapabilitySets(p, i, e)
+
+ case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
+ if tid := kernel.ThreadID(hdr.Pid); tid != 0 && tid != t.ThreadID() {
+ return 0, nil, syserror.EPERM
+ }
+ var data [2]linux.CapUserData
+ if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ return 0, nil, err
+ }
+ p := (auth.CapabilitySet(data[0].Permitted) | (auth.CapabilitySet(data[1].Permitted) << 32)) & auth.AllCapabilities
+ i := (auth.CapabilitySet(data[0].Inheritable) | (auth.CapabilitySet(data[1].Inheritable) << 32)) & auth.AllCapabilities
+ e := (auth.CapabilitySet(data[0].Effective) | (auth.CapabilitySet(data[1].Effective) << 32)) & auth.AllCapabilities
+ return 0, nil, t.SetCapabilitySets(p, i, e)
+
+ default:
+ hdr.Version = linux.HighestCapabilityVersion
+ if _, err := t.CopyOut(hdrAddr, &hdr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
new file mode 100644
index 000000000..1467feb4e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -0,0 +1,171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/epoll"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/syscalls"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// EpollCreate1 implements the epoll_create1(2) linux syscall.
+func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags & ^syscall.EPOLL_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ closeOnExec := flags&syscall.EPOLL_CLOEXEC != 0
+ fd, err := syscalls.CreateEpoll(t, closeOnExec)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// EpollCreate implements the epoll_create(2) linux syscall.
+func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+
+ if size <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ fd, err := syscalls.CreateEpoll(t, false)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// EpollCtl implements the epoll_ctl(2) linux syscall.
+func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := kdefs.FD(args[0].Int())
+ op := args[1].Int()
+ fd := kdefs.FD(args[2].Int())
+ eventAddr := args[3].Pointer()
+
+ // Capture the event state if needed.
+ flags := epoll.EntryFlags(0)
+ mask := waiter.EventMask(0)
+ var data [2]int32
+ if op != syscall.EPOLL_CTL_DEL {
+ var e syscall.EpollEvent
+ if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ return 0, nil, err
+ }
+
+ if e.Events&syscall.EPOLLONESHOT != 0 {
+ flags |= epoll.OneShot
+ }
+
+ // syscall.EPOLLET is incorrectly generated as a negative number
+ // in Go, see https://github.com/golang/go/issues/5328 for
+ // details.
+ if e.Events&-syscall.EPOLLET != 0 {
+ flags |= epoll.EdgeTriggered
+ }
+
+ mask = waiter.EventMaskFromLinux(e.Events)
+ data[0] = e.Fd
+ data[1] = e.Pad
+ }
+
+ // Perform the requested operations.
+ switch op {
+ case syscall.EPOLL_CTL_ADD:
+ // See fs/eventpoll.c.
+ mask |= waiter.EventHUp | waiter.EventErr
+ return 0, nil, syscalls.AddEpoll(t, epfd, fd, flags, mask, data)
+ case syscall.EPOLL_CTL_DEL:
+ return 0, nil, syscalls.RemoveEpoll(t, epfd, fd)
+ case syscall.EPOLL_CTL_MOD:
+ // Same as EPOLL_CTL_ADD.
+ mask |= waiter.EventHUp | waiter.EventErr
+ return 0, nil, syscalls.UpdateEpoll(t, epfd, fd, flags, mask, data)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// copyOutEvents copies epoll events from the kernel to user memory.
+func copyOutEvents(t *kernel.Task, addr usermem.Addr, e []epoll.Event) error {
+ const itemLen = 12
+ if _, ok := addr.AddLength(uint64(len(e)) * itemLen); !ok {
+ return syserror.EFAULT
+ }
+
+ b := t.CopyScratchBuffer(itemLen)
+ for i := range e {
+ usermem.ByteOrder.PutUint32(b[0:], e[i].Events)
+ usermem.ByteOrder.PutUint32(b[4:], uint32(e[i].Data[0]))
+ usermem.ByteOrder.PutUint32(b[8:], uint32(e[i].Data[1]))
+ if _, err := t.CopyOutBytes(addr, b); err != nil {
+ return err
+ }
+ addr += itemLen
+ }
+
+ return nil
+}
+
+// EpollWait implements the epoll_wait(2) linux syscall.
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := kdefs.FD(args[0].Int())
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeout := int(args[3].Int())
+
+ r, err := syscalls.WaitEpoll(t, epfd, maxEvents, timeout)
+ if err != nil {
+ return 0, nil, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ if len(r) != 0 {
+ if err := copyOutEvents(t, eventsAddr, r); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ return uintptr(len(r)), nil, nil
+}
+
+// EpollPwait implements the epoll_pwait(2) linux syscall.
+func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+
+ if maskAddr != 0 {
+ mask, err := copyInSigSet(t, maskAddr, maskSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ }
+
+ return EpollWait(t, args)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_eventfd.go b/pkg/sentry/syscalls/linux/sys_eventfd.go
new file mode 100644
index 000000000..ca4ead488
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_eventfd.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/eventfd"
+)
+
+const (
+ // EFD_SEMAPHORE is a flag used in syscall eventfd(2) and eventfd2(2). Please
+ // see its man page for more information.
+ EFD_SEMAPHORE = 1
+ EFD_NONBLOCK = 0x800
+ EFD_CLOEXEC = 0x80000
+)
+
+// Eventfd2 implements linux syscall eventfd2(2).
+func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ initVal := args[0].Int()
+ flags := uint(args[1].Uint())
+ allOps := uint(EFD_SEMAPHORE | EFD_NONBLOCK | EFD_CLOEXEC)
+
+ if flags & ^allOps != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+
+ event := eventfd.New(t, uint64(initVal), flags&EFD_SEMAPHORE != 0)
+ event.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags&EFD_NONBLOCK != 0,
+ })
+ defer event.DecRef()
+
+ fd, err := t.FDMap().NewFDFrom(0, event, kernel.FDFlags{
+ CloseOnExec: flags&EFD_CLOEXEC != 0,
+ },
+ t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// Eventfd implements linux syscall eventfd(2).
+func Eventfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ args[1].Value = 0
+ return Eventfd2(t, args)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
new file mode 100644
index 000000000..19f579930
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -0,0 +1,2088 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/lock"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/fasync"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/pipe"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// fileOpAt performs an operation on the second last component in the path.
+func fileOpAt(t *kernel.Task, dirFD kdefs.FD, path string, fn func(root *fs.Dirent, d *fs.Dirent, name string) error) error {
+ // Extract the last component.
+ dir, name := fs.SplitLast(path)
+ if dir == "/" {
+ // Common case: we are accessing a file in the root.
+ root := t.FSContext().RootDirectory()
+ err := fn(root, root, name)
+ root.DecRef()
+ return err
+ } else if dir == "." && dirFD == linux.AT_FDCWD {
+ // Common case: we are accessing a file relative to the current
+ // working directory; skip the look-up.
+ wd := t.FSContext().WorkingDirectory()
+ root := t.FSContext().RootDirectory()
+ err := fn(root, wd, name)
+ wd.DecRef()
+ root.DecRef()
+ return err
+ }
+
+ return fileOpOn(t, dirFD, dir, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ return fn(root, d, name)
+ })
+}
+
+// fileOpOn performs an operation on the last entry of the path.
+func fileOpOn(t *kernel.Task, dirFD kdefs.FD, path string, resolve bool, fn func(root *fs.Dirent, d *fs.Dirent) error) error {
+ var (
+ d *fs.Dirent // The file.
+ wd *fs.Dirent // The working directory (if required.)
+ rel *fs.Dirent // The relative directory for search (if required.)
+ f *fs.File // The file corresponding to dirFD (if required.)
+ err error
+ )
+
+ // Extract the working directory (maybe).
+ if len(path) > 0 && path[0] == '/' {
+ // Absolute path; rel can be nil.
+ } else if dirFD == linux.AT_FDCWD {
+ // Need to reference the working directory.
+ wd = t.FSContext().WorkingDirectory()
+ rel = wd
+ } else {
+ // Need to extract the given FD.
+ f = t.FDMap().GetFile(dirFD)
+ if f == nil {
+ return syserror.EBADF
+ }
+ rel = f.Dirent
+ if !fs.IsDir(rel.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+ }
+
+ // Grab the root (always required.)
+ root := t.FSContext().RootDirectory()
+
+ // Lookup the node.
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ if resolve {
+ d, err = t.MountNamespace().FindInode(t, root, rel, path, &remainingTraversals)
+ } else {
+ d, err = t.MountNamespace().FindLink(t, root, rel, path, &remainingTraversals)
+ }
+ root.DecRef()
+ if wd != nil {
+ wd.DecRef()
+ }
+ if f != nil {
+ f.DecRef()
+ }
+ if err != nil {
+ return err
+ }
+
+ err = fn(root, d)
+ d.DecRef()
+ return err
+}
+
+// copyInPath copies a path in.
+func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string, dirPath bool, err error) {
+ path, err = t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return "", false, err
+ }
+ if path == "" && !allowEmpty {
+ return "", false, syserror.ENOENT
+ }
+
+ // If the path ends with a /, then checks must be enforced in various
+ // ways in the different callers. We pass this back to the caller.
+ path, dirPath = fs.TrimTrailingSlashes(path)
+
+ return path, dirPath, nil
+}
+
+func openAt(t *kernel.Task, dirFD kdefs.FD, addr usermem.Addr, flags uint) (fd uintptr, err error) {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, err
+ }
+
+ resolve := flags&linux.O_NOFOLLOW == 0
+ err = fileOpOn(t, dirFD, path, resolve, func(root *fs.Dirent, d *fs.Dirent) error {
+ // First check a few things about the filesystem before trying to get the file
+ // reference.
+ //
+ // It's required that Check does not try to open files not that aren't backed by
+ // this dirent (e.g. pipes and sockets) because this would result in opening these
+ // files an extra time just to check permissions.
+ if err := d.Inode.CheckPermission(t, flagsToPermissions(flags)); err != nil {
+ return err
+ }
+
+ if fs.IsSymlink(d.Inode.StableAttr) && !resolve {
+ return syserror.ELOOP
+ }
+
+ fileFlags := linuxToFlags(flags)
+ // Linux always adds the O_LARGEFILE flag when running in 64-bit mode.
+ fileFlags.LargeFile = true
+ if fs.IsDir(d.Inode.StableAttr) {
+ // Don't allow directories to be opened writable.
+ if fileFlags.Write {
+ return syserror.EISDIR
+ }
+ } else {
+ // If O_DIRECTORY is set, but the file is not a directory, then fail.
+ if fileFlags.Directory {
+ return syserror.ENOTDIR
+ }
+ // If it's a directory, then make sure.
+ if dirPath {
+ return syserror.ENOTDIR
+ }
+ if flags&linux.O_TRUNC != 0 {
+ if err := d.Inode.Truncate(t, d, 0); err != nil {
+ return err
+ }
+ }
+ }
+
+ file, err := d.Inode.GetFile(t, d, fileFlags)
+ if err != nil {
+ return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+ }
+ defer file.DecRef()
+
+ // Success.
+ fdFlags := kernel.FDFlags{CloseOnExec: flags&linux.O_CLOEXEC != 0}
+ newFD, err := t.FDMap().NewFDFrom(0, file, fdFlags, t.ThreadGroup().Limits())
+ if err != nil {
+ return err
+ }
+
+ // Set return result in frame.
+ fd = uintptr(newFD)
+
+ // Generate notification for opened file.
+ d.InotifyEvent(linux.IN_OPEN, 0)
+
+ return nil
+ })
+ return fd, err // Use result in frame.
+}
+
+func mknodAt(t *kernel.Task, dirFD kdefs.FD, addr usermem.Addr, mode linux.FileMode) error {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+ if dirPath {
+ return syserror.ENOENT
+ }
+
+ return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Do we have the appropriate permissions on the parent?
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ // Attempt a creation.
+ perms := fs.FilePermsFromMode(mode &^ linux.FileMode(t.FSContext().Umask()))
+
+ switch mode.FileType() {
+ case 0:
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ fallthrough
+ case linux.ModeRegular:
+ // We are not going to return the file, so the actual
+ // flags used don't matter, but they cannot be empty or
+ // Create will complain.
+ flags := fs.FileFlags{Read: true, Write: true}
+ file, err := d.Create(t, root, name, flags, perms)
+ if err != nil {
+ return err
+ }
+ file.DecRef()
+ return nil
+
+ case linux.ModeNamedPipe:
+ return d.CreateFifo(t, root, name, perms)
+
+ case linux.ModeSocket:
+ // While it is possible create a unix domain socket file on linux
+ // using mknod(2), in practice this is pretty useless from an
+ // application. Linux internally uses mknod() to create the socket
+ // node during bind(2), but we implement bind(2) independently. If
+ // an application explicitly creates a socket node using mknod(),
+ // you can't seem to bind() or connect() to the resulting socket.
+ //
+ // Instead of emulating this seemingly useless behaviour, we'll
+ // indicate that the filesystem doesn't support the creation of
+ // sockets.
+ return syserror.EOPNOTSUPP
+
+ case linux.ModeCharacterDevice:
+ fallthrough
+ case linux.ModeBlockDevice:
+ // TODO(b/72101894): We don't support creating block or character
+ // devices at the moment.
+ //
+ // When we start supporting block and character devices, we'll
+ // need to check for CAP_MKNOD here.
+ return syserror.EPERM
+
+ default:
+ // "EINVAL - mode requested creation of something other than a
+ // regular file, device special file, FIFO or socket." - mknod(2)
+ return syserror.EINVAL
+ }
+ })
+}
+
+// Mknod implements the linux syscall mknod(2).
+func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ path := args[0].Pointer()
+ mode := linux.FileMode(args[1].ModeT())
+ // We don't need this argument until we support creation of device nodes.
+ _ = args[2].Uint() // dev
+
+ return 0, nil, mknodAt(t, linux.AT_FDCWD, path, mode)
+}
+
+// Mknodat implements the linux syscall mknodat(2).
+func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := kdefs.FD(args[0].Int())
+ path := args[1].Pointer()
+ mode := linux.FileMode(args[2].ModeT())
+ // We don't need this argument until we support creation of device nodes.
+ _ = args[3].Uint() // dev
+
+ return 0, nil, mknodAt(t, dirFD, path, mode)
+}
+
+func createAt(t *kernel.Task, dirFD kdefs.FD, addr usermem.Addr, flags uint, mode linux.FileMode) (fd uintptr, err error) {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, err
+ }
+ if dirPath {
+ return 0, syserror.ENOENT
+ }
+
+ err = fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ fileFlags := linuxToFlags(flags)
+ // Linux always adds the O_LARGEFILE flag when running in 64-bit mode.
+ fileFlags.LargeFile = true
+
+ // Does this file exist already?
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ targetDirent, err := t.MountNamespace().FindInode(t, root, d, name, &remainingTraversals)
+ var newFile *fs.File
+ switch err {
+ case nil:
+ // The file existed.
+ defer targetDirent.DecRef()
+
+ // Check if we wanted to create.
+ if flags&linux.O_EXCL != 0 {
+ return syserror.EEXIST
+ }
+
+ // Like sys_open, check for a few things about the
+ // filesystem before trying to get a reference to the
+ // fs.File. The same constraints on Check apply.
+ if err := targetDirent.Inode.CheckPermission(t, flagsToPermissions(flags)); err != nil {
+ return err
+ }
+
+ // Should we truncate the file?
+ if flags&linux.O_TRUNC != 0 {
+ if err := targetDirent.Inode.Truncate(t, targetDirent, 0); err != nil {
+ return err
+ }
+ }
+
+ // Create a new fs.File.
+ newFile, err = targetDirent.Inode.GetFile(t, targetDirent, fileFlags)
+ if err != nil {
+ return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+ }
+ defer newFile.DecRef()
+ case syserror.ENOENT:
+ // File does not exist. Proceed with creation.
+
+ // Do we have write permissions on the parent?
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ // Attempt a creation.
+ perms := fs.FilePermsFromMode(mode &^ linux.FileMode(t.FSContext().Umask()))
+ newFile, err = d.Create(t, root, name, fileFlags, perms)
+ if err != nil {
+ // No luck, bail.
+ return err
+ }
+ defer newFile.DecRef()
+ targetDirent = newFile.Dirent
+ default:
+ return err
+ }
+
+ // Success.
+ fdFlags := kernel.FDFlags{CloseOnExec: flags&linux.O_CLOEXEC != 0}
+ newFD, err := t.FDMap().NewFDFrom(0, newFile, fdFlags, t.ThreadGroup().Limits())
+ if err != nil {
+ return err
+ }
+
+ // Set result in frame.
+ fd = uintptr(newFD)
+
+ // Queue the open inotify event. The creation event is
+ // automatically queued when the dirent is targetDirent. The
+ // open events are implemented at the syscall layer so we need
+ // to manually queue one here.
+ targetDirent.InotifyEvent(linux.IN_OPEN, 0)
+
+ return nil
+ })
+ return fd, err // Use result in frame.
+}
+
+// Open implements linux syscall open(2).
+func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := uint(args[1].Uint())
+ if flags&linux.O_CREAT != 0 {
+ mode := linux.FileMode(args[2].ModeT())
+ n, err := createAt(t, linux.AT_FDCWD, addr, flags, mode)
+ return n, nil, err
+ }
+ n, err := openAt(t, linux.AT_FDCWD, addr, flags)
+ return n, nil, err
+}
+
+// Openat implements linux syscall openat(2).
+func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ flags := uint(args[2].Uint())
+ if flags&linux.O_CREAT != 0 {
+ mode := linux.FileMode(args[3].ModeT())
+ n, err := createAt(t, dirFD, addr, flags, mode)
+ return n, nil, err
+ }
+ n, err := openAt(t, dirFD, addr, flags)
+ return n, nil, err
+}
+
+// Creat implements linux syscall creat(2).
+func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := linux.FileMode(args[1].ModeT())
+ n, err := createAt(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_TRUNC, mode)
+ return n, nil, err
+}
+
+// accessContext is a context that overrides the credentials used, but
+// otherwise carries the same values as the embedded context.
+//
+// accessContext should only be used for access(2).
+type accessContext struct {
+ context.Context
+ creds *auth.Credentials
+}
+
+// Value implements context.Context.
+func (ac accessContext) Value(key interface{}) interface{} {
+ switch key {
+ case auth.CtxCredentials:
+ return ac.creds
+ default:
+ return ac.Context.Value(key)
+ }
+}
+
+func accessAt(t *kernel.Task, dirFD kdefs.FD, addr usermem.Addr, resolve bool, mode uint) error {
+ const rOK = 4
+ const wOK = 2
+ const xOK = 1
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ // Sanity check the mode.
+ if mode&^(rOK|wOK|xOK) != 0 {
+ return syserror.EINVAL
+ }
+
+ return fileOpOn(t, dirFD, path, resolve, func(root *fs.Dirent, d *fs.Dirent) error {
+ // access(2) and faccessat(2) check permissions using real
+ // UID/GID, not effective UID/GID.
+ //
+ // "access() needs to use the real uid/gid, not the effective
+ // uid/gid. We do this by temporarily clearing all FS-related
+ // capabilities and switching the fsuid/fsgid around to the
+ // real ones." -fs/open.c:faccessat
+ creds := t.Credentials().Fork()
+ creds.EffectiveKUID = creds.RealKUID
+ creds.EffectiveKGID = creds.RealKGID
+ if creds.EffectiveKUID.In(creds.UserNamespace) == auth.RootUID {
+ creds.EffectiveCaps = creds.PermittedCaps
+ } else {
+ creds.EffectiveCaps = 0
+ }
+
+ ctx := &accessContext{
+ Context: t,
+ creds: creds,
+ }
+
+ return d.Inode.CheckPermission(ctx, fs.PermMask{
+ Read: mode&rOK != 0,
+ Write: mode&wOK != 0,
+ Execute: mode&xOK != 0,
+ })
+ })
+}
+
+// Access implements linux syscall access(2).
+func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+
+ return 0, nil, accessAt(t, linux.AT_FDCWD, addr, true, mode)
+}
+
+// Faccessat implements linux syscall faccessat(2).
+func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ flags := args[3].Int()
+
+ return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode)
+}
+
+// Ioctl implements linux syscall ioctl(2).
+func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ request := int(args[1].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Shared flags between file and socket.
+ switch request {
+ case linux.FIONCLEX:
+ t.FDMap().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: false,
+ })
+ return 0, nil, nil
+ case linux.FIOCLEX:
+ t.FDMap().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: true,
+ })
+ return 0, nil, nil
+
+ case linux.FIONBIO:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.Flags()
+ if set != 0 {
+ flags.NonBlocking = true
+ } else {
+ flags.NonBlocking = false
+ }
+ file.SetFlags(flags.Settable())
+ return 0, nil, nil
+
+ case linux.FIOASYNC:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.Flags()
+ if set != 0 {
+ flags.Async = true
+ } else {
+ flags.Async = false
+ }
+ file.SetFlags(flags.Settable())
+ return 0, nil, nil
+
+ case linux.FIOSETOWN, linux.SIOCSPGRP:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ fSetOwn(t, file, set)
+ return 0, nil, nil
+
+ case linux.FIOGETOWN, linux.SIOCGPGRP:
+ who := fGetOwn(t, file)
+ _, err := t.CopyOut(args[2].Pointer(), &who)
+ return 0, nil, err
+
+ default:
+ ret, err := file.FileOperations.Ioctl(t, t.MemoryManager(), args)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return ret, nil, nil
+ }
+}
+
+// Getcwd implements the linux syscall getcwd(2).
+func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ size := args[1].SizeT()
+ cwd := t.FSContext().WorkingDirectory()
+ defer cwd.DecRef()
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+
+ // Get our fullname from the root and preprend unreachable if the root was
+ // unreachable from our current dirent this is the same behavior as on linux.
+ s, reachable := cwd.FullName(root)
+ if !reachable {
+ s = "(unreachable)" + s
+ }
+
+ // Note this is >= because we need a terminator.
+ if uint(len(s)) >= size {
+ return 0, nil, syserror.ERANGE
+ }
+
+ // Copy out the path name for the node.
+ bytes, err := t.CopyOutBytes(addr, []byte(s))
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Top it off with a terminator.
+ _, err = t.CopyOut(addr+usermem.Addr(bytes), []byte("\x00"))
+ return uintptr(bytes + 1), nil, err
+}
+
+// Chroot implements the linux syscall chroot(2).
+func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ if !t.HasCapability(linux.CAP_SYS_CHROOT) {
+ return 0, nil, syserror.EPERM
+ }
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ // Is it a directory?
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Does it have execute permissions?
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Execute: true}); err != nil {
+ return err
+ }
+
+ t.FSContext().SetRootDirectory(d)
+ return nil
+ })
+}
+
+// Chdir implements the linux syscall chdir(2).
+func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ // Is it a directory?
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Does it have execute permissions?
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Execute: true}); err != nil {
+ return err
+ }
+
+ t.FSContext().SetWorkingDirectory(d)
+ return nil
+ })
+}
+
+// Fchdir implements the linux syscall fchdir(2).
+func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Is it a directory?
+ if !fs.IsDir(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.ENOTDIR
+ }
+
+ // Does it have execute permissions?
+ if err := file.Dirent.Inode.CheckPermission(t, fs.PermMask{Execute: true}); err != nil {
+ return 0, nil, err
+ }
+
+ t.FSContext().SetWorkingDirectory(file.Dirent)
+ return 0, nil, nil
+}
+
+// Close implements linux syscall close(2).
+func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+
+ file, ok := t.FDMap().Remove(fd)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.Flush(t)
+ return 0, nil, handleIOError(t, false /* partial */, err, syscall.EINTR, "close", file)
+}
+
+// Dup implements linux syscall dup(2).
+func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newfd, err := t.FDMap().NewFDFrom(0, file, kernel.FDFlags{}, t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, nil, syserror.EMFILE
+ }
+ return uintptr(newfd), nil, nil
+}
+
+// Dup2 implements linux syscall dup2(2).
+func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := kdefs.FD(args[0].Int())
+ newfd := kdefs.FD(args[1].Int())
+
+ // If oldfd is a valid file descriptor, and newfd has the same value as oldfd,
+ // then dup2() does nothing, and returns newfd.
+ if oldfd == newfd {
+ oldFile := t.FDMap().GetFile(oldfd)
+ if oldFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer oldFile.DecRef()
+
+ return uintptr(newfd), nil, nil
+ }
+
+ // Zero out flags arg to be used by Dup3.
+ args[2].Value = 0
+ return Dup3(t, args)
+}
+
+// Dup3 implements linux syscall dup3(2).
+func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := kdefs.FD(args[0].Int())
+ newfd := kdefs.FD(args[1].Int())
+ flags := args[2].Uint()
+
+ if oldfd == newfd {
+ return 0, nil, syserror.EINVAL
+ }
+
+ oldFile := t.FDMap().GetFile(oldfd)
+ if oldFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer oldFile.DecRef()
+
+ err := t.FDMap().NewFDAt(newfd, oldFile, kernel.FDFlags{CloseOnExec: flags&linux.O_CLOEXEC != 0}, t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(newfd), nil, nil
+}
+
+func fGetOwn(t *kernel.Task, file *fs.File) int32 {
+ ma := file.Async(nil)
+ if ma == nil {
+ return 0
+ }
+ a := ma.(*fasync.FileAsync)
+ ot, otg, opg := a.Owner()
+ switch {
+ case ot != nil:
+ return int32(t.PIDNamespace().IDOfTask(ot))
+ case otg != nil:
+ return int32(t.PIDNamespace().IDOfThreadGroup(otg))
+ case opg != nil:
+ return int32(-t.PIDNamespace().IDOfProcessGroup(opg))
+ default:
+ return 0
+ }
+}
+
+// fSetOwn sets the file's owner with the semantics of F_SETOWN in Linux.
+//
+// If who is positive, it represents a PID. If negative, it represents a PGID.
+// If the PID or PGID is invalid, the owner is silently unset.
+func fSetOwn(t *kernel.Task, file *fs.File, who int32) {
+ a := file.Async(fasync.New).(*fasync.FileAsync)
+ if who < 0 {
+ pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(-who))
+ a.SetOwnerProcessGroup(t, pg)
+ }
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who))
+ a.SetOwnerThreadGroup(t, tg)
+}
+
+// Fcntl implements linux syscall fcntl(2).
+func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ cmd := args[1].Int()
+
+ file, flags := t.FDMap().GetDescriptor(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
+ from := kdefs.FD(args[2].Int())
+ fdFlags := kernel.FDFlags{CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC}
+ fd, err := t.FDMap().NewFDFrom(from, file, fdFlags, t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+ case linux.F_GETFD:
+ return uintptr(flags.ToLinuxFDFlags()), nil, nil
+ case linux.F_SETFD:
+ flags := args[2].Uint()
+ t.FDMap().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: flags&linux.FD_CLOEXEC != 0,
+ })
+ case linux.F_GETFL:
+ return uintptr(file.Flags().ToLinux()), nil, nil
+ case linux.F_SETFL:
+ flags := uint(args[2].Uint())
+ file.SetFlags(linuxToFlags(flags).Settable())
+ case linux.F_SETLK, linux.F_SETLKW:
+ // In Linux the file system can choose to provide lock operations for an inode.
+ // Normally pipe and socket types lack lock operations. We diverge and use a heavy
+ // hammer by only allowing locks on files and directories.
+ if !fs.IsFile(file.Dirent.Inode.StableAttr) && !fs.IsDir(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Copy in the lock request.
+ flockAddr := args[2].Pointer()
+ var flock syscall.Flock_t
+ if _, err := t.CopyIn(flockAddr, &flock); err != nil {
+ return 0, nil, err
+ }
+
+ // Compute the lock whence.
+ var sw fs.SeekWhence
+ switch flock.Whence {
+ case 0:
+ sw = fs.SeekSet
+ case 1:
+ sw = fs.SeekCurrent
+ case 2:
+ sw = fs.SeekEnd
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Compute the lock offset.
+ var off int64
+ switch sw {
+ case fs.SeekSet:
+ off = 0
+ case fs.SeekCurrent:
+ // Note that Linux does not hold any mutexes while retrieving the file offset,
+ // see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk.
+ off = file.Offset()
+ case fs.SeekEnd:
+ uattr, err := file.Dirent.Inode.UnstableAttr(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ off = uattr.Size
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Compute the lock range.
+ rng, err := lock.ComputeRange(flock.Start, flock.Len, off)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // The lock uid is that of the Task's FDMap.
+ lockUniqueID := lock.UniqueID(t.FDMap().ID())
+
+ // These locks don't block; execute the non-blocking operation using the inode's lock
+ // context directly.
+ switch flock.Type {
+ case syscall.F_RDLCK:
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+ if cmd == syscall.F_SETLK {
+ // Non-blocking lock, provide a nil lock.Blocker.
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) {
+ return 0, nil, syserror.EAGAIN
+ }
+ } else {
+ // Blocking lock, pass in the task to satisfy the lock.Blocker interface.
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, t) {
+ return 0, nil, syserror.EINTR
+ }
+ }
+ return 0, nil, nil
+ case syscall.F_WRLCK:
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+ if cmd == syscall.F_SETLK {
+ // Non-blocking lock, provide a nil lock.Blocker.
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) {
+ return 0, nil, syserror.EAGAIN
+ }
+ } else {
+ // Blocking lock, pass in the task to satisfy the lock.Blocker interface.
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, t) {
+ return 0, nil, syserror.EINTR
+ }
+ }
+ return 0, nil, nil
+ case syscall.F_UNLCK:
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(lockUniqueID, rng)
+ return 0, nil, nil
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ case linux.F_GETOWN:
+ return uintptr(fGetOwn(t, file)), nil, nil
+ case linux.F_SETOWN:
+ fSetOwn(t, file, args[2].Int())
+ return 0, nil, nil
+ case linux.F_GET_SEALS:
+ val, err := tmpfs.GetSeals(file.Dirent.Inode)
+ return uintptr(val), nil, err
+ case linux.F_ADD_SEALS:
+ if !file.Flags().Write {
+ return 0, nil, syserror.EPERM
+ }
+ err := tmpfs.AddSeals(file.Dirent.Inode, args[2].Uint())
+ return 0, nil, err
+ case linux.F_GETPIPE_SZ:
+ sz, ok := file.FileOperations.(pipe.Sizer)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ return uintptr(sz.PipeSize()), nil, nil
+ case linux.F_SETPIPE_SZ:
+ sz, ok := file.FileOperations.(pipe.Sizer)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ n, err := sz.SetPipeSize(int64(args[2].Int()))
+ return uintptr(n), nil, err
+ default:
+ // Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+ return 0, nil, nil
+}
+
+const (
+ _FADV_NORMAL = 0
+ _FADV_RANDOM = 1
+ _FADV_SEQUENTIAL = 2
+ _FADV_WILLNEED = 3
+ _FADV_DONTNEED = 4
+ _FADV_NOREUSE = 5
+)
+
+// Fadvise64 implements linux syscall fadvise64(2).
+// This implementation currently ignores the provided advice.
+func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ length := args[2].Int64()
+ advice := args[3].Int()
+
+ // Note: offset is allowed to be negative.
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // If the FD refers to a pipe or FIFO, return error.
+ if fs.IsPipe(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ switch advice {
+ case _FADV_NORMAL:
+ case _FADV_RANDOM:
+ case _FADV_SEQUENTIAL:
+ case _FADV_WILLNEED:
+ case _FADV_DONTNEED:
+ case _FADV_NOREUSE:
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Sure, whatever.
+ return 0, nil, nil
+}
+
+func mkdirAt(t *kernel.Task, dirFD kdefs.FD, addr usermem.Addr, mode linux.FileMode) error {
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Does this directory exist already?
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ f, err := t.MountNamespace().FindInode(t, root, d, name, &remainingTraversals)
+ switch err {
+ case nil:
+ // The directory existed.
+ defer f.DecRef()
+ return syserror.EEXIST
+ case syserror.EACCES:
+ // Permission denied while walking to the directory.
+ return err
+ default:
+ // Do we have write permissions on the parent?
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+
+ // Create the directory.
+ perms := fs.FilePermsFromMode(mode &^ linux.FileMode(t.FSContext().Umask()))
+ return d.CreateDirectory(t, root, name, perms)
+ }
+ })
+}
+
+// Mkdir implements linux syscall mkdir(2).
+func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := linux.FileMode(args[1].ModeT())
+
+ return 0, nil, mkdirAt(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Mkdirat implements linux syscall mkdirat(2).
+func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ mode := linux.FileMode(args[2].ModeT())
+
+ return 0, nil, mkdirAt(t, dirFD, addr, mode)
+}
+
+func rmdirAt(t *kernel.Task, dirFD kdefs.FD, addr usermem.Addr) error {
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ // Special case: removing the root always returns EBUSY.
+ if path == "/" {
+ return syserror.EBUSY
+ }
+
+ return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Linux returns different ernos when the path ends in single
+ // dot vs. double dots.
+ switch name {
+ case ".":
+ return syserror.EINVAL
+ case "..":
+ return syserror.ENOTEMPTY
+ }
+
+ if err := fs.MayDelete(t, root, d, name); err != nil {
+ return err
+ }
+
+ return d.RemoveDirectory(t, root, name)
+ })
+}
+
+// Rmdir implements linux syscall rmdir(2).
+func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ return 0, nil, rmdirAt(t, linux.AT_FDCWD, addr)
+}
+
+func symlinkAt(t *kernel.Task, dirFD kdefs.FD, newAddr usermem.Addr, oldAddr usermem.Addr) error {
+ newPath, dirPath, err := copyInPath(t, newAddr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+ if dirPath {
+ return syserror.ENOENT
+ }
+
+ // The oldPath is copied in verbatim. This is because the symlink
+ // will include all details, including trailing slashes.
+ oldPath, err := t.CopyInString(oldAddr, linux.PATH_MAX)
+ if err != nil {
+ return err
+ }
+ if oldPath == "" {
+ return syserror.ENOENT
+ }
+
+ return fileOpAt(t, dirFD, newPath, func(root *fs.Dirent, d *fs.Dirent, name string) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Make sure we have write permissions on the parent directory.
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+ return d.CreateLink(t, root, oldPath, name)
+ })
+}
+
+// Symlink implements linux syscall symlink(2).
+func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldAddr := args[0].Pointer()
+ newAddr := args[1].Pointer()
+
+ return 0, nil, symlinkAt(t, linux.AT_FDCWD, newAddr, oldAddr)
+}
+
+// Symlinkat implements linux syscall symlinkat(2).
+func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldAddr := args[0].Pointer()
+ dirFD := kdefs.FD(args[1].Int())
+ newAddr := args[2].Pointer()
+
+ return 0, nil, symlinkAt(t, dirFD, newAddr, oldAddr)
+}
+
+// mayLinkAt determines whether t can create a hard link to target.
+//
+// This corresponds to Linux's fs/namei.c:may_linkat.
+func mayLinkAt(t *kernel.Task, target *fs.Inode) error {
+ // Linux will impose the following restrictions on hard links only if
+ // sysctl_protected_hardlinks is enabled. The kernel disables this
+ // setting by default for backward compatibility (see commit
+ // 561ec64ae67e), but also recommends that distributions enable it (and
+ // Debian does:
+ // https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=889098).
+ //
+ // gVisor currently behaves as though sysctl_protected_hardlinks is
+ // always enabled, and thus imposes the following restrictions on hard
+ // links.
+
+ if target.CheckOwnership(t) {
+ // fs/namei.c:may_linkat: "Source inode owner (or CAP_FOWNER)
+ // can hardlink all they like."
+ return nil
+ }
+
+ // If we are not the owner, then the file must be regular and have
+ // Read+Write permissions.
+ if !fs.IsRegular(target.StableAttr) {
+ return syserror.EPERM
+ }
+ if target.CheckPermission(t, fs.PermMask{Read: true, Write: true}) != nil {
+ return syserror.EPERM
+ }
+
+ return nil
+}
+
+// linkAt creates a hard link to the target specified by oldDirFD and oldAddr,
+// specified by newDirFD and newAddr. If resolve is true, then the symlinks
+// will be followed when evaluating the target.
+func linkAt(t *kernel.Task, oldDirFD kdefs.FD, oldAddr usermem.Addr, newDirFD kdefs.FD, newAddr usermem.Addr, resolve, allowEmpty bool) error {
+ oldPath, _, err := copyInPath(t, oldAddr, allowEmpty)
+ if err != nil {
+ return err
+ }
+ newPath, dirPath, err := copyInPath(t, newAddr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+ if dirPath {
+ return syserror.ENOENT
+ }
+
+ if allowEmpty && oldPath == "" {
+ target := t.FDMap().GetFile(oldDirFD)
+ if target == nil {
+ return syserror.EBADF
+ }
+ defer target.DecRef()
+ if err := mayLinkAt(t, target.Dirent.Inode); err != nil {
+ return err
+ }
+
+ // Resolve the target directory.
+ return fileOpAt(t, newDirFD, newPath, func(root *fs.Dirent, newParent *fs.Dirent, newName string) error {
+ if !fs.IsDir(newParent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Make sure we have write permissions on the parent directory.
+ if err := newParent.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+ return newParent.CreateHardLink(t, root, target.Dirent, newName)
+ })
+ }
+
+ // Resolve oldDirFD and oldAddr to a dirent. The "resolve" argument
+ // only applies to this name.
+ return fileOpOn(t, oldDirFD, oldPath, resolve, func(root *fs.Dirent, target *fs.Dirent) error {
+ if err := mayLinkAt(t, target.Inode); err != nil {
+ return err
+ }
+
+ // Next resolve newDirFD and newAddr to the parent dirent and name.
+ return fileOpAt(t, newDirFD, newPath, func(root *fs.Dirent, newParent *fs.Dirent, newName string) error {
+ if !fs.IsDir(newParent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Make sure we have write permissions on the parent directory.
+ if err := newParent.Inode.CheckPermission(t, fs.PermMask{Write: true, Execute: true}); err != nil {
+ return err
+ }
+ return newParent.CreateHardLink(t, root, target, newName)
+ })
+ })
+}
+
+// Link implements linux syscall link(2).
+func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldAddr := args[0].Pointer()
+ newAddr := args[1].Pointer()
+
+ // man link(2):
+ // POSIX.1-2001 says that link() should dereference oldpath if it is a
+ // symbolic link. However, since kernel 2.0, Linux does not do so: if
+ // oldpath is a symbolic link, then newpath is created as a (hard) link
+ // to the same symbolic link file (i.e., newpath becomes a symbolic
+ // link to the same file that oldpath refers to).
+ resolve := false
+ return 0, nil, linkAt(t, linux.AT_FDCWD, oldAddr, linux.AT_FDCWD, newAddr, resolve, false /* allowEmpty */)
+}
+
+// Linkat implements linux syscall linkat(2).
+func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldDirFD := kdefs.FD(args[0].Int())
+ oldAddr := args[1].Pointer()
+ newDirFD := kdefs.FD(args[2].Int())
+ newAddr := args[3].Pointer()
+
+ // man linkat(2):
+ // By default, linkat(), does not dereference oldpath if it is a
+ // symbolic link (like link(2)). Since Linux 2.6.18, the flag
+ // AT_SYMLINK_FOLLOW can be specified in flags to cause oldpath to be
+ // dereferenced if it is a symbolic link.
+ flags := args[4].Int()
+
+ // Sanity check flags.
+ if flags&^(linux.AT_SYMLINK_FOLLOW|linux.AT_EMPTY_PATH) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ resolve := flags&linux.AT_SYMLINK_FOLLOW == linux.AT_SYMLINK_FOLLOW
+ allowEmpty := flags&linux.AT_EMPTY_PATH == linux.AT_EMPTY_PATH
+
+ if allowEmpty && !t.HasCapabilityIn(linux.CAP_DAC_READ_SEARCH, t.UserNamespace().Root()) {
+ return 0, nil, syserror.ENOENT
+ }
+
+ return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty)
+}
+
+func readlinkAt(t *kernel.Task, dirFD kdefs.FD, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, err
+ }
+ if dirPath {
+ return 0, syserror.ENOENT
+ }
+
+ err = fileOpOn(t, dirFD, path, false /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ // Check for Read permission.
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Read: true}); err != nil {
+ return err
+ }
+
+ s, err := d.Inode.Readlink(t)
+ if err == syserror.ENOLINK {
+ return syserror.EINVAL
+ }
+ if err != nil {
+ return err
+ }
+
+ buffer := []byte(s)
+ if uint(len(buffer)) > size {
+ buffer = buffer[:size]
+ }
+
+ n, err := t.CopyOutBytes(bufAddr, buffer)
+
+ // Update frame return value.
+ copied = uintptr(n)
+
+ return err
+ })
+ return copied, err // Return frame value.
+}
+
+// Readlink implements linux syscall readlink(2).
+func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ n, err := readlinkAt(t, linux.AT_FDCWD, addr, bufAddr, size)
+ return n, nil, err
+}
+
+// Readlinkat implements linux syscall readlinkat(2).
+func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ bufAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ n, err := readlinkAt(t, dirFD, addr, bufAddr, size)
+ return n, nil, err
+}
+
+func unlinkAt(t *kernel.Task, dirFD kdefs.FD, addr usermem.Addr) error {
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+ if dirPath {
+ return syserror.ENOENT
+ }
+
+ return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string) error {
+ if !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ if err := fs.MayDelete(t, root, d, name); err != nil {
+ return err
+ }
+
+ return d.Remove(t, root, name)
+ })
+}
+
+// Unlink implements linux syscall unlink(2).
+func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ return 0, nil, unlinkAt(t, linux.AT_FDCWD, addr)
+}
+
+// Unlinkat implements linux syscall unlinkat(2).
+func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ flags := args[2].Uint()
+ if flags&linux.AT_REMOVEDIR != 0 {
+ return 0, nil, rmdirAt(t, dirFD, addr)
+ }
+ return 0, nil, unlinkAt(t, dirFD, addr)
+}
+
+// Truncate implements linux syscall truncate(2).
+func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+ if dirPath {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
+ t.SendSignal(&arch.SignalInfo{
+ Signo: int32(syscall.SIGXFSZ),
+ Code: arch.SignalInfoUser,
+ })
+ return 0, nil, syserror.EFBIG
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ if fs.IsDir(d.Inode.StableAttr) {
+ return syserror.EISDIR
+ }
+ if !fs.IsFile(d.Inode.StableAttr) {
+ return syserror.EINVAL
+ }
+
+ // Reject truncation if the access permissions do not allow truncation.
+ // This is different from the behavior of sys_ftruncate, see below.
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+
+ if err := d.Inode.Truncate(t, d, length); err != nil {
+ return err
+ }
+
+ // File length modified, generate notification.
+ d.InotifyEvent(linux.IN_MODIFY, 0)
+
+ return nil
+ })
+}
+
+// Ftruncate implements linux syscall ftruncate(2).
+func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ length := args[1].Int64()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Reject truncation if the file flags do not permit this operation.
+ // This is different from truncate(2) above.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Note that this is different from truncate(2) above, where a
+ // directory returns EISDIR.
+ if !fs.IsFile(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
+ t.SendSignal(&arch.SignalInfo{
+ Signo: int32(syscall.SIGXFSZ),
+ Code: arch.SignalInfoUser,
+ })
+ return 0, nil, syserror.EFBIG
+ }
+
+ if err := file.Dirent.Inode.Truncate(t, file.Dirent, length); err != nil {
+ return 0, nil, err
+ }
+
+ // File length modified, generate notification.
+ file.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+
+ return 0, nil, nil
+}
+
+// Umask implements linux syscall umask(2).
+func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ mask := args[0].ModeT()
+ mask = t.FSContext().SwapUmask(mask & 0777)
+ return uintptr(mask), nil, nil
+}
+
+// Change ownership of a file.
+//
+// uid and gid may be -1, in which case they will not be changed.
+func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
+ owner := fs.FileOwner{
+ UID: auth.NoID,
+ GID: auth.NoID,
+ }
+
+ uattr, err := d.Inode.UnstableAttr(t)
+ if err != nil {
+ return err
+ }
+ c := t.Credentials()
+ hasCap := d.Inode.CheckCapability(t, linux.CAP_CHOWN)
+ isOwner := uattr.Owner.UID == c.EffectiveKUID
+ if uid.Ok() {
+ kuid := c.UserNamespace.MapToKUID(uid)
+ // Valid UID must be supplied if UID is to be changed.
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+
+ // "Only a privileged process (CAP_CHOWN) may change the owner
+ // of a file." -chown(2)
+ //
+ // Linux also allows chown if you own the file and are
+ // explicitly not changing its UID.
+ isNoop := uattr.Owner.UID == kuid
+ if !(hasCap || (isOwner && isNoop)) {
+ return syserror.EPERM
+ }
+
+ owner.UID = kuid
+ }
+ if gid.Ok() {
+ kgid := c.UserNamespace.MapToKGID(gid)
+ // Valid GID must be supplied if GID is to be changed.
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+
+ // "The owner of a file may change the group of the file to any
+ // group of which that owner is a member. A privileged process
+ // (CAP_CHOWN) may change the group arbitrarily." -chown(2)
+ isNoop := uattr.Owner.GID == kgid
+ isMemberGroup := c.InGroup(kgid)
+ if !(hasCap || (isOwner && (isNoop || isMemberGroup))) {
+ return syserror.EPERM
+ }
+
+ owner.GID = kgid
+ }
+
+ // FIXME(b/62949101): This is racy; the inode's owner may have changed in
+ // the meantime. (Linux holds i_mutex while calling
+ // fs/attr.c:notify_change() => inode_operations::setattr =>
+ // inode_change_ok().)
+ if err := d.Inode.SetOwner(t, d, owner); err != nil {
+ return err
+ }
+
+ // When the owner or group are changed by an unprivileged user,
+ // chown(2) also clears the set-user-ID and set-group-ID bits, but
+ // we do not support them.
+ return nil
+}
+
+func chownAt(t *kernel.Task, fd kdefs.FD, addr usermem.Addr, resolve, allowEmpty bool, uid auth.UID, gid auth.GID) error {
+ path, _, err := copyInPath(t, addr, allowEmpty)
+ if err != nil {
+ return err
+ }
+
+ if path == "" {
+ // Annoying. What's wrong with fchown?
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return chown(t, file.Dirent, uid, gid)
+ }
+
+ return fileOpOn(t, fd, path, resolve, func(root *fs.Dirent, d *fs.Dirent) error {
+ return chown(t, d, uid, gid)
+ })
+}
+
+// Chown implements linux syscall chown(2).
+func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ uid := auth.UID(args[1].Uint())
+ gid := auth.GID(args[2].Uint())
+
+ return 0, nil, chownAt(t, linux.AT_FDCWD, addr, true /* resolve */, false /* allowEmpty */, uid, gid)
+}
+
+// Lchown implements linux syscall lchown(2).
+func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ uid := auth.UID(args[1].Uint())
+ gid := auth.GID(args[2].Uint())
+
+ return 0, nil, chownAt(t, linux.AT_FDCWD, addr, false /* resolve */, false /* allowEmpty */, uid, gid)
+}
+
+// Fchown implements linux syscall fchown(2).
+func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ uid := auth.UID(args[1].Uint())
+ gid := auth.GID(args[2].Uint())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, chown(t, file.Dirent, uid, gid)
+}
+
+// Fchownat implements Linux syscall fchownat(2).
+func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ uid := auth.UID(args[2].Uint())
+ gid := auth.GID(args[3].Uint())
+ flags := args[4].Int()
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, chownAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, flags&linux.AT_EMPTY_PATH != 0, uid, gid)
+}
+
+func chmod(t *kernel.Task, d *fs.Dirent, mode linux.FileMode) error {
+ // Must own file to change mode.
+ if !d.Inode.CheckOwnership(t) {
+ return syserror.EPERM
+ }
+
+ p := fs.FilePermsFromMode(mode)
+ if !d.Inode.SetPermissions(t, d, p) {
+ return syserror.EPERM
+ }
+
+ // File attribute changed, generate notification.
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+
+ return nil
+}
+
+func chmodAt(t *kernel.Task, fd kdefs.FD, addr usermem.Addr, mode linux.FileMode) error {
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ return fileOpOn(t, fd, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ return chmod(t, d, mode)
+ })
+}
+
+// Chmod implements linux syscall chmod(2).
+func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := linux.FileMode(args[1].ModeT())
+
+ return 0, nil, chmodAt(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Fchmod implements linux syscall fchmod(2).
+func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ mode := linux.FileMode(args[1].ModeT())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, chmod(t, file.Dirent, mode)
+}
+
+// Fchmodat implements linux syscall fchmodat(2).
+func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ mode := linux.FileMode(args[2].ModeT())
+
+ return 0, nil, chmodAt(t, fd, addr, mode)
+}
+
+// defaultSetToSystemTimeSpec returns a TimeSpec that will set ATime and MTime
+// to the system time.
+func defaultSetToSystemTimeSpec() fs.TimeSpec {
+ return fs.TimeSpec{
+ ATimeSetSystemTime: true,
+ MTimeSetSystemTime: true,
+ }
+}
+
+func utimes(t *kernel.Task, dirFD kdefs.FD, addr usermem.Addr, ts fs.TimeSpec, resolve bool) error {
+ setTimestamp := func(root *fs.Dirent, d *fs.Dirent) error {
+ // Does the task own the file?
+ if !d.Inode.CheckOwnership(t) {
+ // Trying to set a specific time? Must be owner.
+ if (ts.ATimeOmit || !ts.ATimeSetSystemTime) && (ts.MTimeOmit || !ts.MTimeSetSystemTime) {
+ return syserror.EPERM
+ }
+
+ // Trying to set to current system time? Must have write access.
+ if err := d.Inode.CheckPermission(t, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+ }
+
+ if err := d.Inode.SetTimestamps(t, d, ts); err != nil {
+ return err
+ }
+
+ // File attribute changed, generate notification.
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+ return nil
+ }
+
+ // From utimes.c:
+ // "If filename is NULL and dfd refers to an open file, then operate on
+ // the file. Otherwise look up filename, possibly using dfd as a
+ // starting point."
+ if addr == 0 && dirFD != linux.AT_FDCWD {
+ if !resolve {
+ // Linux returns EINVAL in this case. See utimes.c.
+ return syserror.EINVAL
+ }
+ f := t.FDMap().GetFile(dirFD)
+ if f == nil {
+ return syserror.EBADF
+ }
+ defer f.DecRef()
+
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+
+ return setTimestamp(root, f.Dirent)
+ }
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ return fileOpOn(t, dirFD, path, resolve, setTimestamp)
+}
+
+// Utime implements linux syscall utime(2).
+func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ filenameAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ // No timesAddr argument will be interpreted as current system time.
+ ts := defaultSetToSystemTimeSpec()
+ if timesAddr != 0 {
+ var times syscall.Utimbuf
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ ts = fs.TimeSpec{
+ ATime: ktime.FromSeconds(times.Actime),
+ MTime: ktime.FromSeconds(times.Modtime),
+ }
+ }
+ return 0, nil, utimes(t, linux.AT_FDCWD, filenameAddr, ts, true)
+}
+
+// Utimes implements linux syscall utimes(2).
+func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ filenameAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ // No timesAddr argument will be interpreted as current system time.
+ ts := defaultSetToSystemTimeSpec()
+ if timesAddr != 0 {
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ ts = fs.TimeSpec{
+ ATime: ktime.FromTimeval(times[0]),
+ MTime: ktime.FromTimeval(times[1]),
+ }
+ }
+ return 0, nil, utimes(t, linux.AT_FDCWD, filenameAddr, ts, true)
+}
+
+// timespecIsValid checks that the timespec is valid for use in utimensat.
+func timespecIsValid(ts linux.Timespec) bool {
+ // Nsec must be UTIME_OMIT, UTIME_NOW, or less than 10^9.
+ return ts.Nsec == linux.UTIME_OMIT || ts.Nsec == linux.UTIME_NOW || ts.Nsec < 1e9
+}
+
+// Utimensat implements linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := kdefs.FD(args[0].Int())
+ pathnameAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ // No timesAddr argument will be interpreted as current system time.
+ ts := defaultSetToSystemTimeSpec()
+ if timesAddr != 0 {
+ var times [2]linux.Timespec
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ if !timespecIsValid(times[0]) || !timespecIsValid(times[1]) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // If both are UTIME_OMIT, this is a noop.
+ if times[0].Nsec == linux.UTIME_OMIT && times[1].Nsec == linux.UTIME_OMIT {
+ return 0, nil, nil
+ }
+
+ ts = fs.TimeSpec{
+ ATime: ktime.FromTimespec(times[0]),
+ ATimeOmit: times[0].Nsec == linux.UTIME_OMIT,
+ ATimeSetSystemTime: times[0].Nsec == linux.UTIME_NOW,
+ MTime: ktime.FromTimespec(times[1]),
+ MTimeOmit: times[1].Nsec == linux.UTIME_OMIT,
+ MTimeSetSystemTime: times[0].Nsec == linux.UTIME_NOW,
+ }
+ }
+ return 0, nil, utimes(t, dirFD, pathnameAddr, ts, flags&linux.AT_SYMLINK_NOFOLLOW == 0)
+}
+
+// Futimesat implements linux syscall futimesat(2).
+func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := kdefs.FD(args[0].Int())
+ pathnameAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+
+ // No timesAddr argument will be interpreted as current system time.
+ ts := defaultSetToSystemTimeSpec()
+ if timesAddr != 0 {
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ if times[0].Usec >= 1e6 || times[0].Usec < 0 ||
+ times[1].Usec >= 1e6 || times[1].Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ts = fs.TimeSpec{
+ ATime: ktime.FromTimeval(times[0]),
+ MTime: ktime.FromTimeval(times[1]),
+ }
+ }
+ return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true)
+}
+
+func renameAt(t *kernel.Task, oldDirFD kdefs.FD, oldAddr usermem.Addr, newDirFD kdefs.FD, newAddr usermem.Addr) error {
+ newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+ oldPath, _, err := copyInPath(t, oldAddr, false /* allowEmpty */)
+ if err != nil {
+ return err
+ }
+
+ return fileOpAt(t, oldDirFD, oldPath, func(root *fs.Dirent, oldParent *fs.Dirent, oldName string) error {
+ if !fs.IsDir(oldParent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Rename rejects paths that end in ".", "..", or empty (i.e.
+ // the root) with EBUSY.
+ switch oldName {
+ case "", ".", "..":
+ return syserror.EBUSY
+ }
+
+ return fileOpAt(t, newDirFD, newPath, func(root *fs.Dirent, newParent *fs.Dirent, newName string) error {
+ if !fs.IsDir(newParent.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ // Rename rejects paths that end in ".", "..", or empty
+ // (i.e. the root) with EBUSY.
+ switch newName {
+ case "", ".", "..":
+ return syserror.EBUSY
+ }
+
+ return fs.Rename(t, root, oldParent, oldName, newParent, newName)
+ })
+ })
+}
+
+// Rename implements linux syscall rename(2).
+func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldPathAddr := args[0].Pointer()
+ newPathAddr := args[1].Pointer()
+ return 0, nil, renameAt(t, linux.AT_FDCWD, oldPathAddr, linux.AT_FDCWD, newPathAddr)
+}
+
+// Renameat implements linux syscall renameat(2).
+func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldDirFD := kdefs.FD(args[0].Int())
+ oldPathAddr := args[1].Pointer()
+ newDirFD := kdefs.FD(args[2].Int())
+ newPathAddr := args[3].Pointer()
+ return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr)
+}
+
+// Fallocate implements linux system call fallocate(2).
+func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ mode := args[1].Int64()
+ offset := args[2].Int64()
+ length := args[3].Int64()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ if offset < 0 || length <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if mode != 0 {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOTSUP
+ }
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+ if fs.IsPipe(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.ESPIPE
+ }
+ if fs.IsDir(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EISDIR
+ }
+ if !fs.IsRegular(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.ENODEV
+ }
+ size := offset + length
+ if size < 0 {
+ return 0, nil, syserror.EFBIG
+ }
+ if uint64(size) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
+ t.SendSignal(&arch.SignalInfo{
+ Signo: int32(syscall.SIGXFSZ),
+ Code: arch.SignalInfoUser,
+ })
+ return 0, nil, syserror.EFBIG
+ }
+
+ if err := file.Dirent.Inode.Allocate(t, file.Dirent, offset, length); err != nil {
+ return 0, nil, err
+ }
+
+ // File length modified, generate notification.
+ file.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+
+ return 0, nil, nil
+}
+
+// Flock implements linux syscall flock(2).
+func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ operation := args[1].Int()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ // flock(2): EBADF fd is not an open file descriptor.
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ nonblocking := operation&linux.LOCK_NB != 0
+ operation &^= linux.LOCK_NB
+
+ // flock(2):
+ // Locks created by flock() are associated with an open file table entry. This means that
+ // duplicate file descriptors (created by, for example, fork(2) or dup(2)) refer to the
+ // same lock, and this lock may be modified or released using any of these descriptors. Furthermore,
+ // the lock is released either by an explicit LOCK_UN operation on any of these duplicate
+ // descriptors, or when all such descriptors have been closed.
+ //
+ // If a process uses open(2) (or similar) to obtain more than one descriptor for the same file,
+ // these descriptors are treated independently by flock(). An attempt to lock the file using
+ // one of these file descriptors may be denied by a lock that the calling process has already placed via
+ // another descriptor.
+ //
+ // We use the File UniqueID as the lock UniqueID because it needs to reference the same lock across dup(2)
+ // and fork(2).
+ lockUniqueID := lock.UniqueID(file.UniqueID)
+
+ // A BSD style lock spans the entire file.
+ rng := lock.LockRange{
+ Start: 0,
+ End: lock.LockEOF,
+ }
+
+ switch operation {
+ case linux.LOCK_EX:
+ if nonblocking {
+ // Since we're nonblocking we pass a nil lock.Blocker implementation.
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) {
+ return 0, nil, syserror.EWOULDBLOCK
+ }
+ } else {
+ // Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, t) {
+ return 0, nil, syserror.EINTR
+ }
+ }
+ case linux.LOCK_SH:
+ if nonblocking {
+ // Since we're nonblocking we pass a nil lock.Blocker implementation.
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) {
+ return 0, nil, syserror.EWOULDBLOCK
+ }
+ } else {
+ // Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, t) {
+ return 0, nil, syserror.EINTR
+ }
+ }
+ case linux.LOCK_UN:
+ file.Dirent.Inode.LockCtx.BSD.UnlockRegion(lockUniqueID, rng)
+ default:
+ // flock(2): EINVAL operation is invalid.
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+}
+
+const (
+ memfdPrefix = "/memfd:"
+ memfdAllFlags = uint32(linux.MFD_CLOEXEC | linux.MFD_ALLOW_SEALING)
+ memfdMaxNameLen = linux.NAME_MAX - len(memfdPrefix) + 1
+)
+
+// MemfdCreate implements the linux syscall memfd_create(2).
+func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+
+ if flags&^memfdAllFlags != 0 {
+ // Unknown bits in flags.
+ return 0, nil, syserror.EINVAL
+ }
+
+ allowSeals := flags&linux.MFD_ALLOW_SEALING != 0
+ cloExec := flags&linux.MFD_CLOEXEC != 0
+
+ name, err := t.CopyInString(addr, syscall.PathMax-len(memfdPrefix))
+ if err != nil {
+ return 0, nil, err
+ }
+ if len(name) > memfdMaxNameLen {
+ return 0, nil, syserror.EINVAL
+ }
+ name = memfdPrefix + name
+
+ inode := tmpfs.NewMemfdInode(t, allowSeals)
+ dirent := fs.NewDirent(inode, name)
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with
+ // FMODE_READ | FMODE_WRITE.
+ file, err := inode.GetFile(t, dirent, fs.FileFlags{Read: true, Write: true})
+ if err != nil {
+ return 0, nil, err
+ }
+
+ defer dirent.DecRef()
+ defer file.DecRef()
+
+ fdFlags := kernel.FDFlags{CloseOnExec: cloExec}
+ newFD, err := t.FDMap().NewFDFrom(0, file, fdFlags, t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(newFD), nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
new file mode 100644
index 000000000..7cef4b50c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -0,0 +1,278 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// futexWaitRestartBlock encapsulates the state required to restart futex(2)
+// via restart_syscall(2).
+//
+// +stateify savable
+type futexWaitRestartBlock struct {
+ duration time.Duration
+
+ // addr stored as uint64 since uintptr is not save-able.
+ addr uint64
+ private bool
+ val uint32
+ mask uint32
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (f *futexWaitRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return futexWaitDuration(t, f.duration, false, usermem.Addr(f.addr), f.private, f.val, f.mask)
+}
+
+// futexWaitAbsolute performs a FUTEX_WAIT_BITSET, blocking until the wait is
+// complete.
+//
+// The wait blocks forever if forever is true, otherwise it blocks until ts.
+//
+// If blocking is interrupted, the syscall is restarted with the original
+// arguments.
+func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, forever bool, addr usermem.Addr, private bool, val, mask uint32) (uintptr, error) {
+ w := t.FutexWaiter()
+ err := t.Futex().WaitPrepare(w, t, addr, private, val, mask)
+ if err != nil {
+ return 0, err
+ }
+
+ if forever {
+ err = t.Block(w.C)
+ } else if clockRealtime {
+ notifier, tchan := ktime.NewChannelNotifier()
+ timer := ktime.NewTimer(t.Kernel().RealtimeClock(), notifier)
+ timer.Swap(ktime.Setting{
+ Enabled: true,
+ Next: ktime.FromTimespec(ts),
+ })
+ err = t.BlockWithTimer(w.C, tchan)
+ timer.Destroy()
+ } else {
+ err = t.BlockWithDeadline(w.C, true, ktime.FromTimespec(ts))
+ }
+
+ t.Futex().WaitComplete(w)
+ return 0, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+// futexWaitDuration performs a FUTEX_WAIT, blocking until the wait is
+// complete.
+//
+// The wait blocks forever if forever is true, otherwise is blocks for
+// duration.
+//
+// If blocking is interrupted, forever determines how to restart the
+// syscall. If forever is true, the syscall is restarted with the original
+// arguments. If forever is false, duration is a relative timeout and the
+// syscall is restarted with the remaining timeout.
+func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, addr usermem.Addr, private bool, val, mask uint32) (uintptr, error) {
+ w := t.FutexWaiter()
+ err := t.Futex().WaitPrepare(w, t, addr, private, val, mask)
+ if err != nil {
+ return 0, err
+ }
+
+ remaining, err := t.BlockWithTimeout(w.C, !forever, duration)
+ t.Futex().WaitComplete(w)
+ if err == nil {
+ return 0, nil
+ }
+
+ // The wait was unsuccessful for some reason other than interruption. Simply
+ // forward the error.
+ if err != syserror.ErrInterrupted {
+ return 0, err
+ }
+
+ // The wait was interrupted and we need to restart. Decide how.
+
+ // The wait duration was absolute, restart with the original arguments.
+ if forever {
+ return 0, kernel.ERESTARTSYS
+ }
+
+ // The wait duration was relative, restart with the remaining duration.
+ t.SetSyscallRestartBlock(&futexWaitRestartBlock{
+ duration: remaining,
+ addr: uint64(addr),
+ private: private,
+ val: val,
+ mask: mask,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+}
+
+func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.Addr, private bool) error {
+ w := t.FutexWaiter()
+ locked, err := t.Futex().LockPI(w, t, addr, uint32(t.ThreadID()), private, false)
+ if err != nil {
+ return err
+ }
+ if locked {
+ // Futex acquired, we're done!
+ return nil
+ }
+
+ if forever {
+ err = t.Block(w.C)
+ } else {
+ notifier, tchan := ktime.NewChannelNotifier()
+ timer := ktime.NewTimer(t.Kernel().RealtimeClock(), notifier)
+ timer.Swap(ktime.Setting{
+ Enabled: true,
+ Next: ktime.FromTimespec(ts),
+ })
+ err = t.BlockWithTimer(w.C, tchan)
+ timer.Destroy()
+ }
+
+ t.Futex().WaitComplete(w)
+ return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+func tryLockPI(t *kernel.Task, addr usermem.Addr, private bool) error {
+ w := t.FutexWaiter()
+ locked, err := t.Futex().LockPI(w, t, addr, uint32(t.ThreadID()), private, true)
+ if err != nil {
+ return err
+ }
+ if !locked {
+ return syserror.EWOULDBLOCK
+ }
+ return nil
+}
+
+// Futex implements linux syscall futex(2).
+// It provides a method for a program to wait for a value at a given address to
+// change, and a method to wake up anyone waiting on a particular address.
+func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ futexOp := args[1].Int()
+ val := int(args[2].Int())
+ nreq := int(args[3].Int())
+ timeout := args[3].Pointer()
+ naddr := args[4].Pointer()
+ val3 := args[5].Int()
+
+ cmd := futexOp &^ (linux.FUTEX_PRIVATE_FLAG | linux.FUTEX_CLOCK_REALTIME)
+ private := (futexOp & linux.FUTEX_PRIVATE_FLAG) != 0
+ clockRealtime := (futexOp & linux.FUTEX_CLOCK_REALTIME) == linux.FUTEX_CLOCK_REALTIME
+ mask := uint32(val3)
+
+ switch cmd {
+ case linux.FUTEX_WAIT, linux.FUTEX_WAIT_BITSET:
+ // WAIT{_BITSET} wait forever if the timeout isn't passed.
+ forever := (timeout == 0)
+
+ var timespec linux.Timespec
+ if !forever {
+ var err error
+ timespec, err = copyTimespecIn(t, timeout)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ switch cmd {
+ case linux.FUTEX_WAIT:
+ // WAIT uses a relative timeout.
+ mask = ^uint32(0)
+ var timeoutDur time.Duration
+ if !forever {
+ timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond
+ }
+ n, err := futexWaitDuration(t, timeoutDur, forever, addr, private, uint32(val), mask)
+ return n, nil, err
+
+ case linux.FUTEX_WAIT_BITSET:
+ // WAIT_BITSET uses an absolute timeout which is either
+ // CLOCK_MONOTONIC or CLOCK_REALTIME.
+ if mask == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ n, err := futexWaitAbsolute(t, clockRealtime, timespec, forever, addr, private, uint32(val), mask)
+ return n, nil, err
+ default:
+ panic("unreachable")
+ }
+
+ case linux.FUTEX_WAKE:
+ mask = ^uint32(0)
+ fallthrough
+
+ case linux.FUTEX_WAKE_BITSET:
+ if mask == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ n, err := t.Futex().Wake(t, addr, private, mask, val)
+ return uintptr(n), nil, err
+
+ case linux.FUTEX_REQUEUE:
+ n, err := t.Futex().Requeue(t, addr, naddr, private, val, nreq)
+ return uintptr(n), nil, err
+
+ case linux.FUTEX_CMP_REQUEUE:
+ // 'val3' contains the value to be checked at 'addr' and
+ // 'val' is the number of waiters that should be woken up.
+ nval := uint32(val3)
+ n, err := t.Futex().RequeueCmp(t, addr, naddr, private, nval, val, nreq)
+ return uintptr(n), nil, err
+
+ case linux.FUTEX_WAKE_OP:
+ op := uint32(val3)
+ n, err := t.Futex().WakeOp(t, addr, naddr, private, val, nreq, op)
+ return uintptr(n), nil, err
+
+ case linux.FUTEX_LOCK_PI:
+ forever := (timeout == 0)
+
+ var timespec linux.Timespec
+ if !forever {
+ var err error
+ timespec, err = copyTimespecIn(t, timeout)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ err := futexLockPI(t, timespec, forever, addr, private)
+ return 0, nil, err
+
+ case linux.FUTEX_TRYLOCK_PI:
+ err := tryLockPI(t, addr, private)
+ return 0, nil, err
+
+ case linux.FUTEX_UNLOCK_PI:
+ err := t.Futex().UnlockPI(t, addr, uint32(t.ThreadID()), private)
+ return 0, nil, err
+
+ case linux.FUTEX_WAIT_REQUEUE_PI, linux.FUTEX_CMP_REQUEUE_PI:
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+
+ default:
+ // We don't even know about this command.
+ return 0, nil, syserror.ENOSYS
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
new file mode 100644
index 000000000..1b597d5bc
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -0,0 +1,269 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "bytes"
+ "io"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Getdents implements linux syscall getdents(2) for 64bit systems.
+func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ minSize := int(smallestDirent(t.Arch()))
+ if size < minSize {
+ // size is smaller than smallest possible dirent.
+ return 0, nil, syserror.EINVAL
+ }
+
+ n, err := getdents(t, fd, addr, size, (*dirent).Serialize)
+ return n, nil, err
+}
+
+// Getdents64 implements linux syscall getdents64(2).
+func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ minSize := int(smallestDirent64(t.Arch()))
+ if size < minSize {
+ // size is smaller than smallest possible dirent.
+ return 0, nil, syserror.EINVAL
+ }
+
+ n, err := getdents(t, fd, addr, size, (*dirent).Serialize64)
+ return n, nil, err
+}
+
+// getdents implements the core of getdents(2)/getdents64(2).
+// f is the syscall implementation dirent serialization function.
+func getdents(t *kernel.Task, fd kdefs.FD, addr usermem.Addr, size int, f func(*dirent, io.Writer) (int, error)) (uintptr, error) {
+ dir := t.FDMap().GetFile(fd)
+ if dir == nil {
+ return 0, syserror.EBADF
+ }
+ defer dir.DecRef()
+
+ w := &usermem.IOReadWriter{
+ Ctx: t,
+ IO: t.MemoryManager(),
+ Addr: addr,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ }
+
+ ds := newDirentSerializer(f, w, t.Arch(), size)
+ rerr := dir.Readdir(t, ds)
+
+ switch err := handleIOError(t, ds.Written() > 0, rerr, kernel.ERESTARTSYS, "getdents", dir); err {
+ case nil:
+ dir.Dirent.InotifyEvent(syscall.IN_ACCESS, 0)
+ return uintptr(ds.Written()), nil
+ case io.EOF:
+ return 0, nil
+ default:
+ return 0, err
+ }
+}
+
+// oldDirentHdr is a fixed sized header matching the fixed size
+// fields found in the old linux dirent struct.
+type oldDirentHdr struct {
+ Ino uint64
+ Off uint64
+ Reclen uint16
+}
+
+// direntHdr is a fixed sized header matching the fixed size
+// fields found in the new linux dirent struct.
+type direntHdr struct {
+ OldHdr oldDirentHdr
+ Typ uint8
+}
+
+// dirent contains the data pointed to by a new linux dirent struct.
+type dirent struct {
+ Hdr direntHdr
+ Name []byte
+}
+
+// newDirent returns a dirent from an fs.InodeOperationsInfo.
+func newDirent(width uint, name string, attr fs.DentAttr, offset uint64) *dirent {
+ d := &dirent{
+ Hdr: direntHdr{
+ OldHdr: oldDirentHdr{
+ Ino: attr.InodeID,
+ Off: offset,
+ },
+ Typ: toType(attr.Type),
+ },
+ Name: []byte(name),
+ }
+ d.Hdr.OldHdr.Reclen = d.padRec(int(width))
+ return d
+}
+
+// smallestDirent returns the size of the smallest possible dirent using
+// the old linux dirent format.
+func smallestDirent(a arch.Context) uint {
+ d := dirent{}
+ return uint(binary.Size(d.Hdr.OldHdr)) + a.Width() + 1
+}
+
+// smallestDirent64 returns the size of the smallest possible dirent using
+// the new linux dirent format.
+func smallestDirent64(a arch.Context) uint {
+ d := dirent{}
+ return uint(binary.Size(d.Hdr)) + a.Width()
+}
+
+// toType converts an fs.InodeOperationsInfo to a linux dirent typ field.
+func toType(nodeType fs.InodeType) uint8 {
+ switch nodeType {
+ case fs.RegularFile, fs.SpecialFile:
+ return syscall.DT_REG
+ case fs.Symlink:
+ return syscall.DT_LNK
+ case fs.Directory, fs.SpecialDirectory:
+ return syscall.DT_DIR
+ case fs.Pipe:
+ return syscall.DT_FIFO
+ case fs.CharacterDevice:
+ return syscall.DT_CHR
+ case fs.BlockDevice:
+ return syscall.DT_BLK
+ case fs.Socket:
+ return syscall.DT_SOCK
+ default:
+ return syscall.DT_UNKNOWN
+ }
+}
+
+// padRec pads the name field until the rec length is a multiple of the width,
+// which must be a power of 2. It returns the padded rec length.
+func (d *dirent) padRec(width int) uint16 {
+ a := int(binary.Size(d.Hdr)) + len(d.Name)
+ r := (a + width) &^ (width - 1)
+ padding := r - a
+ d.Name = append(d.Name, make([]byte, padding)...)
+ return uint16(r)
+}
+
+// Serialize64 serializes a Dirent struct to a byte slice, keeping the new
+// linux dirent format. Returns the number of bytes serialized or an error.
+func (d *dirent) Serialize64(w io.Writer) (int, error) {
+ n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr))
+ if err != nil {
+ return 0, err
+ }
+ n2, err := w.Write(d.Name)
+ if err != nil {
+ return 0, err
+ }
+ return n1 + n2, nil
+}
+
+// Serialize serializes a Dirent struct to a byte slice, using the old linux
+// dirent format.
+// Returns the number of bytes serialized or an error.
+func (d *dirent) Serialize(w io.Writer) (int, error) {
+ n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr.OldHdr))
+ if err != nil {
+ return 0, err
+ }
+ n2, err := w.Write(d.Name)
+ if err != nil {
+ return 0, err
+ }
+ n3, err := w.Write([]byte{d.Hdr.Typ})
+ if err != nil {
+ return 0, err
+ }
+ return n1 + n2 + n3, nil
+}
+
+// direntSerializer implements fs.InodeOperationsInfoSerializer, serializing dirents to an
+// io.Writer.
+type direntSerializer struct {
+ serialize func(*dirent, io.Writer) (int, error)
+ w io.Writer
+ // width is the arch native value width.
+ width uint
+ // offset is the current dirent offset.
+ offset uint64
+ // written is the total bytes serialized.
+ written int
+ // size is the size of the buffer to serialize into.
+ size int
+}
+
+func newDirentSerializer(f func(d *dirent, w io.Writer) (int, error), w io.Writer, ac arch.Context, size int) *direntSerializer {
+ return &direntSerializer{
+ serialize: f,
+ w: w,
+ width: ac.Width(),
+ size: size,
+ }
+}
+
+// CopyOut implements fs.InodeOperationsInfoSerializer.CopyOut.
+// It serializes and writes the fs.DentAttr to the direntSerializer io.Writer.
+func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error {
+ ds.offset++
+
+ d := newDirent(ds.width, name, attr, ds.offset)
+
+ // Serialize dirent into a temp buffer.
+ var b bytes.Buffer
+ n, err := ds.serialize(d, &b)
+ if err != nil {
+ ds.offset--
+ return err
+ }
+
+ // Check that we have enough room remaining to write the dirent.
+ if n > (ds.size - ds.written) {
+ ds.offset--
+ return io.EOF
+ }
+
+ // Write out the temp buffer.
+ if _, err := b.WriteTo(ds.w); err != nil {
+ ds.offset--
+ return err
+ }
+
+ ds.written += n
+ return nil
+}
+
+// Written returns the total number of bytes written.
+func (ds *direntSerializer) Written() int {
+ return ds.written
+}
diff --git a/pkg/sentry/syscalls/linux/sys_identity.go b/pkg/sentry/syscalls/linux/sys_identity.go
new file mode 100644
index 000000000..27e765a2d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_identity.go
@@ -0,0 +1,180 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const (
+ // As NGROUPS_MAX in include/uapi/linux/limits.h.
+ maxNGroups = 65536
+)
+
+// Getuid implements the Linux syscall getuid.
+func Getuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ c := t.Credentials()
+ ruid := c.RealKUID.In(c.UserNamespace).OrOverflow()
+ return uintptr(ruid), nil, nil
+}
+
+// Geteuid implements the Linux syscall geteuid.
+func Geteuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ c := t.Credentials()
+ euid := c.EffectiveKUID.In(c.UserNamespace).OrOverflow()
+ return uintptr(euid), nil, nil
+}
+
+// Getresuid implements the Linux syscall getresuid.
+func Getresuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ruidAddr := args[0].Pointer()
+ euidAddr := args[1].Pointer()
+ suidAddr := args[2].Pointer()
+ c := t.Credentials()
+ ruid := c.RealKUID.In(c.UserNamespace).OrOverflow()
+ euid := c.EffectiveKUID.In(c.UserNamespace).OrOverflow()
+ suid := c.SavedKUID.In(c.UserNamespace).OrOverflow()
+ if _, err := t.CopyOut(ruidAddr, ruid); err != nil {
+ return 0, nil, err
+ }
+ if _, err := t.CopyOut(euidAddr, euid); err != nil {
+ return 0, nil, err
+ }
+ if _, err := t.CopyOut(suidAddr, suid); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, nil
+}
+
+// Getgid implements the Linux syscall getgid.
+func Getgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ c := t.Credentials()
+ rgid := c.RealKGID.In(c.UserNamespace).OrOverflow()
+ return uintptr(rgid), nil, nil
+}
+
+// Getegid implements the Linux syscall getegid.
+func Getegid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ c := t.Credentials()
+ egid := c.EffectiveKGID.In(c.UserNamespace).OrOverflow()
+ return uintptr(egid), nil, nil
+}
+
+// Getresgid implements the Linux syscall getresgid.
+func Getresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ rgidAddr := args[0].Pointer()
+ egidAddr := args[1].Pointer()
+ sgidAddr := args[2].Pointer()
+ c := t.Credentials()
+ rgid := c.RealKGID.In(c.UserNamespace).OrOverflow()
+ egid := c.EffectiveKGID.In(c.UserNamespace).OrOverflow()
+ sgid := c.SavedKGID.In(c.UserNamespace).OrOverflow()
+ if _, err := t.CopyOut(rgidAddr, rgid); err != nil {
+ return 0, nil, err
+ }
+ if _, err := t.CopyOut(egidAddr, egid); err != nil {
+ return 0, nil, err
+ }
+ if _, err := t.CopyOut(sgidAddr, sgid); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, nil
+}
+
+// Setuid implements the Linux syscall setuid.
+func Setuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ uid := auth.UID(args[0].Int())
+ return 0, nil, t.SetUID(uid)
+}
+
+// Setreuid implements the Linux syscall setreuid.
+func Setreuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ruid := auth.UID(args[0].Int())
+ euid := auth.UID(args[1].Int())
+ return 0, nil, t.SetREUID(ruid, euid)
+}
+
+// Setresuid implements the Linux syscall setreuid.
+func Setresuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ruid := auth.UID(args[0].Int())
+ euid := auth.UID(args[1].Int())
+ suid := auth.UID(args[2].Int())
+ return 0, nil, t.SetRESUID(ruid, euid, suid)
+}
+
+// Setgid implements the Linux syscall setgid.
+func Setgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ gid := auth.GID(args[0].Int())
+ return 0, nil, t.SetGID(gid)
+}
+
+// Setregid implements the Linux syscall setregid.
+func Setregid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ rgid := auth.GID(args[0].Int())
+ egid := auth.GID(args[1].Int())
+ return 0, nil, t.SetREGID(rgid, egid)
+}
+
+// Setresgid implements the Linux syscall setregid.
+func Setresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ rgid := auth.GID(args[0].Int())
+ egid := auth.GID(args[1].Int())
+ sgid := auth.GID(args[2].Int())
+ return 0, nil, t.SetRESGID(rgid, egid, sgid)
+}
+
+// Getgroups implements the Linux syscall getgroups.
+func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := int(args[0].Int())
+ if size < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ kgids := t.Credentials().ExtraKGIDs
+ // "If size is zero, list is not modified, but the total number of
+ // supplementary group IDs for the process is returned." - getgroups(2)
+ if size == 0 {
+ return uintptr(len(kgids)), nil, nil
+ }
+ if size < len(kgids) {
+ return 0, nil, syserror.EINVAL
+ }
+ gids := make([]auth.GID, len(kgids))
+ for i, kgid := range kgids {
+ gids[i] = kgid.In(t.UserNamespace()).OrOverflow()
+ }
+ if _, err := t.CopyOut(args[1].Pointer(), gids); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(len(gids)), nil, nil
+}
+
+// Setgroups implements the Linux syscall setgroups.
+func Setgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+ if size < 0 || size > maxNGroups {
+ return 0, nil, syserror.EINVAL
+ }
+ if size == 0 {
+ return 0, nil, t.SetExtraGIDs(nil)
+ }
+ gids := make([]auth.GID, size)
+ if _, err := t.CopyIn(args[1].Pointer(), &gids); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, t.SetExtraGIDs(gids)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_inotify.go b/pkg/sentry/syscalls/linux/sys_inotify.go
new file mode 100644
index 000000000..20269a769
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_inotify.go
@@ -0,0 +1,135 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/anon"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+)
+
+const allFlags = int(linux.IN_NONBLOCK | linux.IN_CLOEXEC)
+
+// InotifyInit1 implements the inotify_init1() syscalls.
+func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+
+ if flags&^allFlags != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+
+ dirent := fs.NewDirent(anon.NewInode(t), "inotify")
+ fileFlags := fs.FileFlags{
+ Read: true,
+ Write: true,
+ NonBlocking: flags&linux.IN_NONBLOCK != 0,
+ }
+ n := fs.NewFile(t, dirent, fileFlags, fs.NewInotify(t))
+ defer n.DecRef()
+
+ fd, err := t.FDMap().NewFDFrom(0, n, kernel.FDFlags{
+ CloseOnExec: flags&linux.IN_CLOEXEC != 0,
+ }, t.ThreadGroup().Limits())
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// InotifyInit implements the inotify_init() syscalls.
+func InotifyInit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ args[0].Value = 0
+ return InotifyInit1(t, args)
+}
+
+// fdToInotify resolves an fd to an inotify object. If successful, the file will
+// have an extra ref and the caller is responsible for releasing the ref.
+func fdToInotify(t *kernel.Task, fd kdefs.FD) (*fs.Inotify, *fs.File, error) {
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ // Invalid fd.
+ return nil, nil, syscall.EBADF
+ }
+
+ ino, ok := file.FileOperations.(*fs.Inotify)
+ if !ok {
+ // Not an inotify fd.
+ file.DecRef()
+ return nil, nil, syscall.EINVAL
+ }
+
+ return ino, file, nil
+}
+
+// InotifyAddWatch implements the inotify_add_watch() syscall.
+func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ mask := args[2].Uint()
+
+ // "IN_DONT_FOLLOW: Don't dereference pathname if it is a symbolic link."
+ // -- inotify(7)
+ resolve := mask&linux.IN_DONT_FOLLOW == 0
+
+ // "EINVAL: The given event mask contains no valid events."
+ // -- inotify_add_watch(2)
+ if validBits := mask & linux.ALL_INOTIFY_BITS; validBits == 0 {
+ return 0, nil, syscall.EINVAL
+ }
+
+ ino, file, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ err = fileOpOn(t, linux.AT_FDCWD, path, resolve, func(root *fs.Dirent, dirent *fs.Dirent) error {
+ // "IN_ONLYDIR: Only watch pathname if it is a directory." -- inotify(7)
+ if onlyDir := mask&linux.IN_ONLYDIR != 0; onlyDir && !fs.IsDir(dirent.Inode.StableAttr) {
+ return syscall.ENOTDIR
+ }
+
+ // Copy out to the return frame.
+ fd = kdefs.FD(ino.AddWatch(dirent, mask))
+
+ return nil
+ })
+ return uintptr(fd), nil, err // Return from the existing value.
+}
+
+// InotifyRmWatch implements the inotify_rm_watch() syscall.
+func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ wd := args[1].Int()
+
+ ino, file, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ return 0, nil, ino.RmWatch(wd)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
new file mode 100644
index 000000000..8aadc6d8c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Lseek implements linux syscall lseek(2).
+func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ offset := args[1].Int64()
+ whence := args[2].Int()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var sw fs.SeekWhence
+ switch whence {
+ case 0:
+ sw = fs.SeekSet
+ case 1:
+ sw = fs.SeekCurrent
+ case 2:
+ sw = fs.SeekEnd
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ offset, serr := file.Seek(t, sw, offset)
+ err := handleIOError(t, false /* partialResult */, serr, kernel.ERESTARTSYS, "lseek", file)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(offset), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
new file mode 100644
index 000000000..64a6e639c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -0,0 +1,470 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "bytes"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memmap"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Brk implements linux syscall brk(2).
+func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr, _ := t.MemoryManager().Brk(t, args[0].Pointer())
+ // "However, the actual Linux system call returns the new program break on
+ // success. On failure, the system call returns the current break." -
+ // brk(2)
+ return uintptr(addr), nil, nil
+}
+
+// Mmap implements linux syscall mmap(2).
+func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ prot := args[2].Int()
+ flags := args[3].Int()
+ fd := kdefs.FD(args[4].Int())
+ fixed := flags&linux.MAP_FIXED != 0
+ private := flags&linux.MAP_PRIVATE != 0
+ shared := flags&linux.MAP_SHARED != 0
+ anon := flags&linux.MAP_ANONYMOUS != 0
+ map32bit := flags&linux.MAP_32BIT != 0
+
+ // Require exactly one of MAP_PRIVATE and MAP_SHARED.
+ if private == shared {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := memmap.MMapOpts{
+ Length: args[1].Uint64(),
+ Offset: args[5].Uint64(),
+ Addr: args[0].Pointer(),
+ Fixed: fixed,
+ Unmap: fixed,
+ Map32Bit: map32bit,
+ Private: private,
+ Perms: usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ },
+ MaxPerms: usermem.AnyAccess,
+ GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
+ Precommit: linux.MAP_POPULATE&flags != 0,
+ }
+ if linux.MAP_LOCKED&flags != 0 {
+ opts.MLockMode = memmap.MLockEager
+ }
+ defer func() {
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef()
+ }
+ }()
+
+ if !anon {
+ // Convert the passed FD to a file reference.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ flags := file.Flags()
+ // mmap unconditionally requires that the FD is readable.
+ if !flags.Read {
+ return 0, nil, syserror.EACCES
+ }
+ // MAP_SHARED requires that the FD be writable for PROT_WRITE.
+ if shared && !flags.Write {
+ opts.MaxPerms.Write = false
+ }
+
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ rv, err := t.MemoryManager().MMap(t, opts)
+ return uintptr(rv), nil, err
+}
+
+// Munmap implements linux syscall munmap(2).
+func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64())
+}
+
+// Mremap implements linux syscall mremap(2).
+func Mremap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldAddr := args[0].Pointer()
+ oldSize := args[1].Uint64()
+ newSize := args[2].Uint64()
+ flags := args[3].Uint64()
+ newAddr := args[4].Pointer()
+
+ if flags&^(linux.MREMAP_MAYMOVE|linux.MREMAP_FIXED) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ mayMove := flags&linux.MREMAP_MAYMOVE != 0
+ fixed := flags&linux.MREMAP_FIXED != 0
+ var moveMode mm.MRemapMoveMode
+ switch {
+ case !mayMove && !fixed:
+ moveMode = mm.MRemapNoMove
+ case mayMove && !fixed:
+ moveMode = mm.MRemapMayMove
+ case mayMove && fixed:
+ moveMode = mm.MRemapMustMove
+ case !mayMove && fixed:
+ // "If MREMAP_FIXED is specified, then MREMAP_MAYMOVE must also be
+ // specified." - mremap(2)
+ return 0, nil, syserror.EINVAL
+ }
+
+ rv, err := t.MemoryManager().MRemap(t, oldAddr, oldSize, newSize, mm.MRemapOpts{
+ Move: moveMode,
+ NewAddr: newAddr,
+ })
+ return uintptr(rv), nil, err
+}
+
+// Mprotect implements linux syscall mprotect(2).
+func Mprotect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ length := args[1].Uint64()
+ prot := args[2].Int()
+ err := t.MemoryManager().MProtect(args[0].Pointer(), length, usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ }, linux.PROT_GROWSDOWN&prot != 0)
+ return 0, nil, err
+}
+
+// Madvise implements linux syscall madvise(2).
+func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := uint64(args[1].SizeT())
+ adv := args[2].Int()
+
+ // "The Linux implementation requires that the address addr be
+ // page-aligned, and allows length to be zero." - madvise(2)
+ if addr.RoundDown() != addr {
+ return 0, nil, syserror.EINVAL
+ }
+ if length == 0 {
+ return 0, nil, nil
+ }
+ // Not explicitly stated: length need not be page-aligned.
+ lenAddr, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ length = uint64(lenAddr)
+
+ switch adv {
+ case linux.MADV_DONTNEED:
+ return 0, nil, t.MemoryManager().Decommit(addr, length)
+ case linux.MADV_HUGEPAGE, linux.MADV_NOHUGEPAGE:
+ fallthrough
+ case linux.MADV_MERGEABLE, linux.MADV_UNMERGEABLE:
+ fallthrough
+ case linux.MADV_DONTDUMP, linux.MADV_DODUMP:
+ // TODO(b/72045799): Core dumping isn't implemented, so these are
+ // no-ops.
+ fallthrough
+ case linux.MADV_NORMAL, linux.MADV_RANDOM, linux.MADV_SEQUENTIAL, linux.MADV_WILLNEED:
+ // Do nothing, we totally ignore the suggestions above.
+ return 0, nil, nil
+ case linux.MADV_REMOVE, linux.MADV_DOFORK, linux.MADV_DONTFORK:
+ // These "suggestions" have application-visible side effects, so we
+ // have to indicate that we don't support them.
+ return 0, nil, syserror.ENOSYS
+ case linux.MADV_HWPOISON:
+ // Only privileged processes are allowed to poison pages.
+ return 0, nil, syserror.EPERM
+ default:
+ // If adv is not a valid value tell the caller.
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+func copyOutIfNotNull(t *kernel.Task, ptr usermem.Addr, val interface{}) (int, error) {
+ if ptr != 0 {
+ return t.CopyOut(ptr, val)
+ }
+ return 0, nil
+}
+
+// GetMempolicy implements the syscall get_mempolicy(2).
+func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ mode := args[0].Pointer()
+ nodemask := args[1].Pointer()
+ maxnode := args[2].Uint()
+ addr := args[3].Pointer()
+ flags := args[4].Uint()
+
+ memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0
+ nodeFlag := flags&linux.MPOL_F_NODE != 0
+ addrFlag := flags&linux.MPOL_F_ADDR != 0
+
+ // TODO(rahat): Once sysfs is implemented, report a single numa node in
+ // /sys/devices/system/node.
+ if nodemask != 0 && maxnode < 1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // 'addr' provided iff 'addrFlag' set.
+ if addrFlag == (addr == 0) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Default policy for the thread.
+ if flags == 0 {
+ policy, nodemaskVal := t.NumaPolicy()
+ if _, err := copyOutIfNotNull(t, mode, policy); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err := copyOutIfNotNull(t, nodemask, nodemaskVal); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ return 0, nil, nil
+ }
+
+ // Report all nodes available to caller.
+ if memsAllowed {
+ // MPOL_F_NODE and MPOL_F_ADDR not allowed with MPOL_F_MEMS_ALLOWED.
+ if nodeFlag || addrFlag {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Report a single numa node.
+ if _, err := copyOutIfNotNull(t, nodemask, uint32(0x1)); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ return 0, nil, nil
+ }
+
+ if addrFlag {
+ if nodeFlag {
+ // Return the id for the node where 'addr' resides, via 'mode'.
+ //
+ // The real get_mempolicy(2) allocates the page referenced by 'addr'
+ // by simulating a read, if it is unallocated before the call. It
+ // then returns the node the page is allocated on through the mode
+ // pointer.
+ b := t.CopyScratchBuffer(1)
+ _, err := t.CopyInBytes(addr, b)
+ if err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err := copyOutIfNotNull(t, mode, int32(0)); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ } else {
+ storedPolicy, _ := t.NumaPolicy()
+ // Return the policy governing the memory referenced by 'addr'.
+ if _, err := copyOutIfNotNull(t, mode, int32(storedPolicy)); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ }
+ return 0, nil, nil
+ }
+
+ storedPolicy, _ := t.NumaPolicy()
+ if nodeFlag && (storedPolicy&^linux.MPOL_MODE_FLAGS == linux.MPOL_INTERLEAVE) {
+ // Policy for current thread is to interleave memory between
+ // nodes. Return the next node we'll allocate on. Since we only have a
+ // single node, this is always node 0.
+ if _, err := copyOutIfNotNull(t, mode, int32(0)); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ return 0, nil, nil
+ }
+
+ return 0, nil, syserror.EINVAL
+}
+
+func allowedNodesMask() uint32 {
+ const maxNodes = 1
+ return ^uint32((1 << maxNodes) - 1)
+}
+
+// SetMempolicy implements the syscall set_mempolicy(2).
+func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ modeWithFlags := args[0].Int()
+ nodemask := args[1].Pointer()
+ maxnode := args[2].Uint()
+
+ if nodemask != 0 && maxnode < 1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if modeWithFlags&linux.MPOL_MODE_FLAGS == linux.MPOL_MODE_FLAGS {
+ // Can't specify multiple modes simultaneously.
+ return 0, nil, syserror.EINVAL
+ }
+
+ mode := modeWithFlags &^ linux.MPOL_MODE_FLAGS
+ if mode < 0 || mode >= linux.MPOL_MAX {
+ // Must specify a valid mode.
+ return 0, nil, syserror.EINVAL
+ }
+
+ var nodemaskVal uint32
+ // Nodemask may be empty for some policy modes.
+ if nodemask != 0 && maxnode > 0 {
+ if _, err := t.CopyIn(nodemask, &nodemaskVal); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ }
+
+ if (mode == linux.MPOL_INTERLEAVE || mode == linux.MPOL_BIND) && nodemaskVal == 0 {
+ // Mode requires a non-empty nodemask, but got an empty nodemask.
+ return 0, nil, syserror.EINVAL
+ }
+
+ if nodemaskVal&allowedNodesMask() != 0 {
+ // Invalid node specified.
+ return 0, nil, syserror.EINVAL
+ }
+
+ t.SetNumaPolicy(int32(modeWithFlags), nodemaskVal)
+
+ return 0, nil, nil
+}
+
+// Mincore implements the syscall mincore(2).
+func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+ vec := args[2].Pointer()
+
+ if addr != addr.RoundDown() {
+ return 0, nil, syserror.EINVAL
+ }
+ // "The length argument need not be a multiple of the page size, but since
+ // residency information is returned for whole pages, length is effectively
+ // rounded up to the next multiple of the page size." - mincore(2)
+ la, ok := usermem.Addr(length).RoundUp()
+ if !ok {
+ return 0, nil, syserror.ENOMEM
+ }
+ ar, ok := addr.ToRange(uint64(la))
+ if !ok {
+ return 0, nil, syserror.ENOMEM
+ }
+
+ // Pretend that all mapped pages are "resident in core".
+ mapped := t.MemoryManager().VirtualMemorySizeRange(ar)
+ // "ENOMEM: addr to addr + length contained unmapped memory."
+ if mapped != uint64(la) {
+ return 0, nil, syserror.ENOMEM
+ }
+ resident := bytes.Repeat([]byte{1}, int(mapped/usermem.PageSize))
+ _, err := t.CopyOut(vec, resident)
+ return 0, nil, err
+}
+
+// Msync implements Linux syscall msync(2).
+func Msync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+ flags := args[2].Int()
+
+ // "The flags argument should specify exactly one of MS_ASYNC and MS_SYNC,
+ // and may additionally include the MS_INVALIDATE bit. ... However, Linux
+ // permits a call to msync() that specifies neither of these flags, with
+ // semantics that are (currently) equivalent to specifying MS_ASYNC." -
+ // msync(2)
+ if flags&^(linux.MS_ASYNC|linux.MS_SYNC|linux.MS_INVALIDATE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ sync := flags&linux.MS_SYNC != 0
+ if sync && flags&linux.MS_ASYNC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ err := t.MemoryManager().MSync(t, addr, uint64(length), mm.MSyncOpts{
+ Sync: sync,
+ Invalidate: flags&linux.MS_INVALIDATE != 0,
+ })
+ // MSync calls fsync, the same interrupt conversion rules apply, see
+ // mm/msync.c, fsync POSIX.1-2008.
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+// Mlock implements linux syscall mlock(2).
+func Mlock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+
+ return 0, nil, t.MemoryManager().MLock(t, addr, uint64(length), memmap.MLockEager)
+}
+
+// Mlock2 implements linux syscall mlock2(2).
+func Mlock2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+ flags := args[2].Int()
+
+ if flags&^(linux.MLOCK_ONFAULT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ mode := memmap.MLockEager
+ if flags&linux.MLOCK_ONFAULT != 0 {
+ mode = memmap.MLockLazy
+ }
+ return 0, nil, t.MemoryManager().MLock(t, addr, uint64(length), mode)
+}
+
+// Munlock implements linux syscall munlock(2).
+func Munlock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+
+ return 0, nil, t.MemoryManager().MLock(t, addr, uint64(length), memmap.MLockNone)
+}
+
+// Mlockall implements linux syscall mlockall(2).
+func Mlockall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+
+ if flags&^(linux.MCL_CURRENT|linux.MCL_FUTURE|linux.MCL_ONFAULT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ mode := memmap.MLockEager
+ if flags&linux.MCL_ONFAULT != 0 {
+ mode = memmap.MLockLazy
+ }
+ return 0, nil, t.MemoryManager().MLockAll(t, mm.MLockAllOpts{
+ Current: flags&linux.MCL_CURRENT != 0,
+ Future: flags&linux.MCL_FUTURE != 0,
+ Mode: mode,
+ })
+}
+
+// Munlockall implements linux syscall munlockall(2).
+func Munlockall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.MemoryManager().MLockAll(t, mm.MLockAllOpts{
+ Current: true,
+ Future: true,
+ Mode: memmap.MLockNone,
+ })
+}
diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go
new file mode 100644
index 000000000..cf613bad0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_mount.go
@@ -0,0 +1,146 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Mount implements Linux syscall mount(2).
+func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sourceAddr := args[0].Pointer()
+ targetAddr := args[1].Pointer()
+ typeAddr := args[2].Pointer()
+ flags := args[3].Uint64()
+ dataAddr := args[4].Pointer()
+
+ fsType, err := t.CopyInString(typeAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ sourcePath, _, err := copyInPath(t, sourceAddr, true /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ targetPath, _, err := copyInPath(t, targetAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ data := ""
+ if dataAddr != 0 {
+ // In Linux, a full page is always copied in regardless of null
+ // character placement, and the address is passed to each file system.
+ // Most file systems always treat this data as a string, though, and so
+ // do all of the ones we implement.
+ data, err = t.CopyInString(dataAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ // Ignore magic value that was required before Linux 2.4.
+ if flags&linux.MS_MGC_MSK == linux.MS_MGC_VAL {
+ flags = flags &^ linux.MS_MGC_MSK
+ }
+
+ // Must have CAP_SYS_ADMIN in the mount namespace's associated user
+ // namespace.
+ if !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespace().UserNamespace()) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupportedOps = linux.MS_REMOUNT | linux.MS_BIND |
+ linux.MS_SHARED | linux.MS_PRIVATE | linux.MS_SLAVE |
+ linux.MS_UNBINDABLE | linux.MS_MOVE
+
+ // Silently allow MS_NOSUID, since we don't implement set-id bits
+ // anyway.
+ const unsupportedFlags = linux.MS_NODEV |
+ linux.MS_NODIRATIME | linux.MS_STRICTATIME
+
+ // Linux just allows passing any flags to mount(2) - it won't fail when
+ // unknown or unsupported flags are passed. Since we don't implement
+ // everything, we fail explicitly on flags that are unimplemented.
+ if flags&(unsupportedOps|unsupportedFlags) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ rsys, ok := fs.FindFilesystem(fsType)
+ if !ok {
+ return 0, nil, syserror.ENODEV
+ }
+ if !rsys.AllowUserMount() {
+ return 0, nil, syserror.EPERM
+ }
+
+ var superFlags fs.MountSourceFlags
+ if flags&linux.MS_NOATIME == linux.MS_NOATIME {
+ superFlags.NoAtime = true
+ }
+ if flags&linux.MS_RDONLY == linux.MS_RDONLY {
+ superFlags.ReadOnly = true
+ }
+ if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
+ superFlags.NoExec = true
+ }
+
+ rootInode, err := rsys.Mount(t, sourcePath, superFlags, data, nil)
+ if err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, targetPath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ return t.MountNamespace().Mount(t, d, rootInode)
+ })
+}
+
+// Umount2 implements Linux syscall umount2(2).
+func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+
+ const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE
+ if flags&unsupported != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Must have CAP_SYS_ADMIN in the mount namespace's associated user
+ // namespace.
+ //
+ // Currently, this is always the init task's user namespace.
+ if !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespace().UserNamespace()) {
+ return 0, nil, syserror.EPERM
+ }
+
+ resolve := flags&linux.UMOUNT_NOFOLLOW != linux.UMOUNT_NOFOLLOW
+ detachOnly := flags&linux.MNT_DETACH == linux.MNT_DETACH
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolve, func(root *fs.Dirent, d *fs.Dirent) error {
+ return t.MountNamespace().Unmount(t, d, detachOnly)
+ })
+}
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
new file mode 100644
index 000000000..036845c13
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// pipe2 implements the actual system call with flags.
+func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
+ if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
+ return 0, syscall.EINVAL
+ }
+ r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize)
+
+ r.SetFlags(linuxToFlags(flags).Settable())
+ defer r.DecRef()
+
+ w.SetFlags(linuxToFlags(flags).Settable())
+ defer w.DecRef()
+
+ rfd, err := t.FDMap().NewFDFrom(0, r, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0},
+ t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, err
+ }
+
+ wfd, err := t.FDMap().NewFDFrom(0, w, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0},
+ t.ThreadGroup().Limits())
+ if err != nil {
+ t.FDMap().Remove(rfd)
+ return 0, err
+ }
+
+ if _, err := t.CopyOut(addr, []kdefs.FD{rfd, wfd}); err != nil {
+ t.FDMap().Remove(rfd)
+ t.FDMap().Remove(wfd)
+ return 0, syscall.EFAULT
+ }
+ return 0, nil
+}
+
+// Pipe implements linux syscall pipe(2).
+func Pipe(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ n, err := pipe2(t, addr, 0)
+ return n, nil, err
+}
+
+// Pipe2 implements linux syscall pipe2(2).
+func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := uint(args[1].Uint())
+
+ n, err := pipe2(t, addr, flags)
+ return n, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
new file mode 100644
index 000000000..e32099dd4
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -0,0 +1,549 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// fileCap is the maximum allowable files for poll & select.
+const fileCap = 1024 * 1024
+
+// Masks for "readable", "writable", and "exceptional" events as defined by
+// select(2).
+const (
+ // selectReadEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLIN_SET.
+ selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR
+
+ // selectWriteEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLOUT_SET.
+ selectWriteEvents = linux.POLLOUT | linux.POLLERR
+
+ // selectExceptEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLEX_SET.
+ selectExceptEvents = linux.POLLPRI
+)
+
+// pollState tracks the associated file descriptor and waiter of a PollFD.
+type pollState struct {
+ file *fs.File
+ waiter waiter.Entry
+}
+
+// initReadiness gets the current ready mask for the file represented by the FD
+// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
+// used to register with the file for event notifications, and a reference to
+// the file is stored in "state".
+func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
+ if pfd.FD < 0 {
+ pfd.REvents = 0
+ return
+ }
+
+ file := t.FDMap().GetFile(kdefs.FD(pfd.FD))
+ if file == nil {
+ pfd.REvents = linux.POLLNVAL
+ return
+ }
+
+ if ch == nil {
+ defer file.DecRef()
+ } else {
+ state.file = file
+ state.waiter, _ = waiter.NewChannelEntry(ch)
+ file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ }
+
+ r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ pfd.REvents = int16(r.ToLinux()) & pfd.Events
+}
+
+// releaseState releases all the pollState in "state".
+func releaseState(state []pollState) {
+ for i := range state {
+ if state[i].file != nil {
+ state[i].file.EventUnregister(&state[i].waiter)
+ state[i].file.DecRef()
+ }
+ }
+}
+
+// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
+// when "timeout" is greater than zero.
+//
+// pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or
+// positive if interrupted by a signal.
+func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
+ var ch chan struct{}
+ if timeout != 0 {
+ ch = make(chan struct{}, 1)
+ }
+
+ // Register for event notification in the files involved if we may
+ // block (timeout not zero). Once we find a file that has a non-zero
+ // result, we stop registering for events but still go through all files
+ // to get their ready masks.
+ state := make([]pollState, len(pfd))
+ defer releaseState(state)
+ n := uintptr(0)
+ for i := range pfd {
+ initReadiness(t, &pfd[i], &state[i], ch)
+ if pfd[i].REvents != 0 {
+ n++
+ ch = nil
+ }
+ }
+
+ if timeout == 0 {
+ return timeout, n, nil
+ }
+
+ forever := timeout < 0
+
+ for n == 0 {
+ var err error
+ // Wait for a notification.
+ timeout, err = t.BlockWithTimeout(ch, !forever, timeout)
+ if err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ return timeout, 0, err
+ }
+
+ // We got notified, count how many files are ready. If none,
+ // then this was a spurious notification, and we just go back
+ // to sleep with the remaining timeout.
+ for i := range state {
+ if state[i].file == nil {
+ continue
+ }
+
+ r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
+ rl := int16(r.ToLinux()) & pfd[i].Events
+ if rl != 0 {
+ pfd[i].REvents = rl
+ n++
+ }
+ }
+ }
+
+ return timeout, n, nil
+}
+
+// CopyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
+func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+ if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
+ return nil, syserror.EINVAL
+ }
+
+ pfd := make([]linux.PollFD, nfds)
+ if nfds > 0 {
+ if _, err := t.CopyIn(addr, &pfd); err != nil {
+ return nil, err
+ }
+ }
+
+ return pfd, nil
+}
+
+func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+ pfd, err := CopyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return timeout, 0, err
+ }
+
+ // Compatibility warning: Linux adds POLLHUP and POLLERR just before
+ // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
+ // polling, changing event masks here is an application-visible difference.
+ // (Linux also doesn't copy out event masks at all, only revents.)
+ for i := range pfd {
+ pfd[i].Events |= linux.POLLHUP | linux.POLLERR
+ }
+ remainingTimeout, n, err := pollBlock(t, pfd, timeout)
+ err = syserror.ConvertIntr(err, syserror.EINTR)
+
+ // The poll entries are copied out regardless of whether
+ // any are set or not. This aligns with the Linux behavior.
+ if nfds > 0 && err == nil {
+ if _, err := t.CopyOut(addr, pfd); err != nil {
+ return remainingTimeout, 0, err
+ }
+ }
+
+ return remainingTimeout, n, err
+}
+
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+ if nfds < 0 || nfds > fileCap {
+ return 0, syserror.EINVAL
+ }
+
+ // Capture all the provided input vectors.
+ //
+ // N.B. This only works on little-endian architectures.
+ byteCount := (nfds + 7) / 8
+
+ bitsInLastPartialByte := uint(nfds % 8)
+ r := make([]byte, byteCount)
+ w := make([]byte, byteCount)
+ e := make([]byte, byteCount)
+
+ if readFDs != 0 {
+ if _, err := t.CopyIn(readFDs, &r); err != nil {
+ return 0, err
+ }
+ // Mask out bits above nfds.
+ if bitsInLastPartialByte != 0 {
+ r[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyIn(writeFDs, &w); err != nil {
+ return 0, err
+ }
+ if bitsInLastPartialByte != 0 {
+ w[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyIn(exceptFDs, &e); err != nil {
+ return 0, err
+ }
+ if bitsInLastPartialByte != 0 {
+ e[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
+ }
+ }
+
+ // Count how many FDs are actually being requested so that we can build
+ // a PollFD array.
+ fdCount := 0
+ for i := 0; i < byteCount; i++ {
+ v := r[i] | w[i] | e[i]
+ for v != 0 {
+ v &= (v - 1)
+ fdCount++
+ }
+ }
+
+ // Build the PollFD array.
+ pfd := make([]linux.PollFD, 0, fdCount)
+ var fd int32
+ for i := 0; i < byteCount; i++ {
+ rV, wV, eV := r[i], w[i], e[i]
+ v := rV | wV | eV
+ m := byte(1)
+ for j := 0; j < 8; j++ {
+ if (v & m) != 0 {
+ // Make sure the fd is valid and decrement the reference
+ // immediately to ensure we don't leak. Note, another thread
+ // might be about to close fd. This is racy, but that's
+ // OK. Linux is racy in the same way.
+ file := t.FDMap().GetFile(kdefs.FD(fd))
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ file.DecRef()
+
+ var mask int16
+ if (rV & m) != 0 {
+ mask |= selectReadEvents
+ }
+
+ if (wV & m) != 0 {
+ mask |= selectWriteEvents
+ }
+
+ if (eV & m) != 0 {
+ mask |= selectExceptEvents
+ }
+
+ pfd = append(pfd, linux.PollFD{
+ FD: fd,
+ Events: mask,
+ })
+ }
+
+ fd++
+ m <<= 1
+ }
+ }
+
+ // Do the syscall, then count the number of bits set.
+ _, _, err := pollBlock(t, pfd, timeout)
+ if err != nil {
+ return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ // r, w, and e are currently event mask bitsets; unset bits corresponding
+ // to events that *didn't* occur.
+ bitSetCount := uintptr(0)
+ for idx := range pfd {
+ events := pfd[idx].REvents
+ i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
+ m := byte(1) << j
+ if r[i]&m != 0 {
+ if (events & selectReadEvents) != 0 {
+ bitSetCount++
+ } else {
+ r[i] &^= m
+ }
+ }
+ if w[i]&m != 0 {
+ if (events & selectWriteEvents) != 0 {
+ bitSetCount++
+ } else {
+ w[i] &^= m
+ }
+ }
+ if e[i]&m != 0 {
+ if (events & selectExceptEvents) != 0 {
+ bitSetCount++
+ } else {
+ e[i] &^= m
+ }
+ }
+ }
+
+ // Copy updated vectors back.
+ if readFDs != 0 {
+ if _, err := t.CopyOut(readFDs, r); err != nil {
+ return 0, err
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyOut(writeFDs, w); err != nil {
+ return 0, err
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ return 0, err
+ }
+ }
+
+ return bitSetCount, nil
+}
+
+// timeoutRemaining returns the amount of time remaining for the specified
+// timeout or 0 if it has elapsed.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
+ now := t.Kernel().MonotonicClock().Now()
+ remaining := timeout - now.Sub(startNs)
+ if remaining < 0 {
+ remaining = 0
+ }
+ return remaining
+}
+
+// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
+ return copyTimespecOut(t, timespecAddr, &tsRemaining)
+}
+
+// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
+ return copyTimevalOut(t, timevalAddr, &tvRemaining)
+}
+
+// pollRestartBlock encapsulates the state required to restart poll(2) via
+// restart_syscall(2).
+//
+// +stateify savable
+type pollRestartBlock struct {
+ pfdAddr usermem.Addr
+ nfds uint
+ timeout time.Duration
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return poll(t, p.pfdAddr, p.nfds, p.timeout)
+}
+
+func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+ remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ // On an interrupt poll(2) is restarted with the remaining timeout.
+ if err == syserror.EINTR {
+ t.SetSyscallRestartBlock(&pollRestartBlock{
+ pfdAddr: pfdAddr,
+ nfds: nfds,
+ timeout: remainingTimeout,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+ }
+ return n, err
+}
+
+// Poll implements linux syscall poll(2).
+func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timeout := time.Duration(args[2].Int()) * time.Millisecond
+ n, err := poll(t, pfdAddr, nfds, timeout)
+ return n, nil, err
+}
+
+// Ppoll implements linux syscall ppoll(2).
+func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timespecAddr := args[2].Pointer()
+ maskAddr := args[3].Pointer()
+ maskSize := uint(args[4].Uint())
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskAddr != 0 {
+ mask, err := copyInSigSet(t, maskAddr, maskSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ }
+
+ _, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // doPoll returns EINTR if interrupted, but ppoll is normally restartable
+ // if interrupted by something other than a signal handled by the
+ // application (i.e. returns ERESTARTNOHAND). However, if
+ // copyOutTimespecRemaining failed, then the restarted ppoll would use the
+ // wrong timeout, so the error should be left as EINTR.
+ //
+ // Note that this means that if err is nil but copyErr is not, copyErr is
+ // ignored. This is consistent with Linux.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Select implements linux syscall select(2).
+func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timevalAddr := args[4].Pointer()
+
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timevalAddr != 0 {
+ timeval, err := copyTimevalIn(t, timevalAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if timeval.Sec < 0 || timeval.Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(timeval.ToNsecCapped())
+ }
+ startNs := t.Kernel().MonotonicClock().Now()
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Pselect implements linux syscall pselect(2).
+func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+ maskWithSizeAddr := args[5].Pointer()
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskWithSizeAddr != 0 {
+ maskAddr, size, err := copyInSigSetWithSize(t, maskWithSizeAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if maskAddr != 0 {
+ mask, err := copyInSigSet(t, maskAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ }
+ }
+
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
new file mode 100644
index 000000000..117ae1a0e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -0,0 +1,201 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+)
+
+// Prctl implements linux syscall prctl(2).
+// It has a list of subfunctions which operate on the process. The arguments are
+// all based on each subfunction.
+func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ option := args[0].Int()
+
+ switch option {
+ case linux.PR_SET_PDEATHSIG:
+ sig := linux.Signal(args[1].Int())
+ if sig != 0 && !sig.IsValid() {
+ return 0, nil, syscall.EINVAL
+ }
+ t.SetParentDeathSignal(sig)
+ return 0, nil, nil
+
+ case linux.PR_GET_PDEATHSIG:
+ _, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal()))
+ return 0, nil, err
+
+ case linux.PR_GET_KEEPCAPS:
+ if t.Credentials().KeepCaps {
+ return 1, nil, nil
+ }
+
+ return 0, nil, nil
+
+ case linux.PR_SET_KEEPCAPS:
+ val := args[1].Int()
+ // prctl(2): arg2 must be either 0 (permitted capabilities are cleared)
+ // or 1 (permitted capabilities are kept).
+ if val == 0 {
+ t.SetKeepCaps(false)
+ } else if val == 1 {
+ t.SetKeepCaps(true)
+ } else {
+ return 0, nil, syscall.EINVAL
+ }
+
+ return 0, nil, nil
+
+ case linux.PR_SET_NAME:
+ addr := args[1].Pointer()
+ name, err := t.CopyInString(addr, linux.TASK_COMM_LEN-1)
+ if err != nil && err != syscall.ENAMETOOLONG {
+ return 0, nil, err
+ }
+ t.SetName(name)
+
+ case linux.PR_GET_NAME:
+ addr := args[1].Pointer()
+ buf := t.CopyScratchBuffer(linux.TASK_COMM_LEN)
+ len := copy(buf, t.Name())
+ if len < linux.TASK_COMM_LEN {
+ buf[len] = 0
+ len++
+ }
+ _, err := t.CopyOut(addr, buf[:len])
+ if err != nil {
+ return 0, nil, err
+ }
+
+ case linux.PR_SET_MM:
+ if !t.HasCapability(linux.CAP_SYS_RESOURCE) {
+ return 0, nil, syscall.EPERM
+ }
+
+ switch args[1].Int() {
+ case linux.PR_SET_MM_EXE_FILE:
+ fd := kdefs.FD(args[2].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // They trying to set exe to a non-file?
+ if !fs.IsFile(file.Dirent.Inode.StableAttr) {
+ return 0, nil, syscall.EBADF
+ }
+
+ // Set the underlying executable.
+ t.MemoryManager().SetExecutable(file.Dirent)
+
+ case linux.PR_SET_MM_AUXV,
+ linux.PR_SET_MM_START_CODE,
+ linux.PR_SET_MM_END_CODE,
+ linux.PR_SET_MM_START_DATA,
+ linux.PR_SET_MM_END_DATA,
+ linux.PR_SET_MM_START_STACK,
+ linux.PR_SET_MM_START_BRK,
+ linux.PR_SET_MM_BRK,
+ linux.PR_SET_MM_ARG_START,
+ linux.PR_SET_MM_ARG_END,
+ linux.PR_SET_MM_ENV_START,
+ linux.PR_SET_MM_ENV_END:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ fallthrough
+ default:
+ return 0, nil, syscall.EINVAL
+ }
+
+ case linux.PR_SET_NO_NEW_PRIVS:
+ if args[1].Int() != 1 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ // no_new_privs is assumed to always be set. See
+ // kernel.Task.updateCredsForExec.
+ return 0, nil, nil
+
+ case linux.PR_GET_NO_NEW_PRIVS:
+ if args[1].Int() != 0 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ return 1, nil, nil
+
+ case linux.PR_SET_SECCOMP:
+ if args[1].Int() != linux.SECCOMP_MODE_FILTER {
+ // Unsupported mode.
+ return 0, nil, syscall.EINVAL
+ }
+
+ return 0, nil, seccomp(t, linux.SECCOMP_SET_MODE_FILTER, 0, args[2].Pointer())
+
+ case linux.PR_GET_SECCOMP:
+ return uintptr(t.SeccompMode()), nil, nil
+
+ case linux.PR_CAPBSET_READ:
+ cp := linux.Capability(args[1].Uint64())
+ if !cp.Ok() {
+ return 0, nil, syscall.EINVAL
+ }
+ var rv uintptr
+ if auth.CapabilitySetOf(cp)&t.Credentials().BoundingCaps != 0 {
+ rv = 1
+ }
+ return rv, nil, nil
+
+ case linux.PR_CAPBSET_DROP:
+ cp := linux.Capability(args[1].Uint64())
+ if !cp.Ok() {
+ return 0, nil, syscall.EINVAL
+ }
+ return 0, nil, t.DropBoundingCapability(cp)
+
+ case linux.PR_GET_DUMPABLE,
+ linux.PR_SET_DUMPABLE,
+ linux.PR_GET_TIMING,
+ linux.PR_SET_TIMING,
+ linux.PR_GET_TSC,
+ linux.PR_SET_TSC,
+ linux.PR_TASK_PERF_EVENTS_DISABLE,
+ linux.PR_TASK_PERF_EVENTS_ENABLE,
+ linux.PR_GET_TIMERSLACK,
+ linux.PR_SET_TIMERSLACK,
+ linux.PR_MCE_KILL,
+ linux.PR_MCE_KILL_GET,
+ linux.PR_GET_TID_ADDRESS,
+ linux.PR_SET_CHILD_SUBREAPER,
+ linux.PR_GET_CHILD_SUBREAPER,
+ linux.PR_GET_THP_DISABLE,
+ linux.PR_SET_THP_DISABLE,
+ linux.PR_MPX_ENABLE_MANAGEMENT,
+ linux.PR_MPX_DISABLE_MANAGEMENT:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ fallthrough
+ default:
+ return 0, nil, syscall.EINVAL
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_random.go b/pkg/sentry/syscalls/linux/sys_random.go
new file mode 100644
index 000000000..fc3959a7e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_random.go
@@ -0,0 +1,92 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "io"
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const (
+ _GRND_NONBLOCK = 0x1
+ _GRND_RANDOM = 0x2
+)
+
+// GetRandom implements the linux syscall getrandom(2).
+//
+// In a multi-tenant/shared environment, the only valid implementation is to
+// fetch data from the urandom pool, otherwise starvation attacks become
+// possible. The urandom pool is also expected to have plenty of entropy, thus
+// the GRND_RANDOM flag is ignored. The GRND_NONBLOCK flag does not apply, as
+// the pool will already be initialized.
+func GetRandom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].SizeT()
+ flags := args[2].Int()
+
+ // Flags are checked for validity but otherwise ignored. See above.
+ if flags & ^(_GRND_NONBLOCK|_GRND_RANDOM) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if length > math.MaxInt32 {
+ length = math.MaxInt32
+ }
+ ar, ok := addr.ToRange(uint64(length))
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+
+ // "If the urandom source has been initialized, reads of up to 256 bytes
+ // will always return as many bytes as requested and will not be
+ // interrupted by signals. No such guarantees apply for larger buffer
+ // sizes." - getrandom(2)
+ min := int(length)
+ if min > 256 {
+ min = 256
+ }
+ n, err := t.MemoryManager().CopyOutFrom(t, usermem.AddrRangeSeqOf(ar), safemem.FromIOReader{&randReader{-1, min}}, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if n >= int64(min) {
+ return uintptr(n), nil, nil
+ }
+ return 0, nil, err
+}
+
+// randReader is a io.Reader that handles partial reads from rand.Reader.
+type randReader struct {
+ done int
+ min int
+}
+
+// Read implements io.Reader.Read.
+func (r *randReader) Read(dst []byte) (int, error) {
+ if r.done >= r.min {
+ return rand.Reader.Read(dst)
+ }
+ min := r.min - r.done
+ if min > len(dst) {
+ min = len(dst)
+ }
+ return io.ReadAtLeast(rand.Reader, dst, min)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
new file mode 100644
index 000000000..48b0fd49d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -0,0 +1,357 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // EventMaskRead contains events that can be triggerd on reads.
+ EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
+)
+
+// Read implements linux syscall read(2). Note that we try to get a buffer that
+// is exactly the size requested because some applications like qemu expect
+// they can do large reads all at once. Bug for bug. Same for other read
+// calls below.
+func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := readv(t, file, dst)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
+}
+
+// Pread64 implements linux syscall pread64(2).
+func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is reading at an offset supported?
+ if !file.Flags().Pread {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := preadv(t, file, dst, offset)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file)
+}
+
+// Readv implements linux syscall readv(2).
+func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := readv(t, file, dst)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "readv", file)
+}
+
+// Preadv implements linux syscall preadv(2).
+func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is reading at an offset supported?
+ if !file.Flags().Pread {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := preadv(t, file, dst, offset)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file)
+}
+
+// Preadv2 implements linux syscall preadv2(2).
+// TODO(b/120162627): Implement RWF_HIPRI functionality.
+func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the syscall is
+ // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the linux internal call
+ // (https://elixir.bootlin.com/linux/v4.18/source/fs/read_write.c#L1248)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 5th argument.
+
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := int(args[5].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is reading at an offset supported?
+ if offset > -1 && !file.Flags().Pread {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check flags field.
+ if flags&^linux.RWF_VALID != 0 {
+ return 0, nil, syserror.EOPNOTSUPP
+ }
+
+ // Read the iovecs that specify the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // If preadv2 is called with an offset of -1, readv is called.
+ if offset == -1 {
+ n, err := readv(t, file, dst)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+ }
+
+ n, err := preadv(t, file, dst, offset)
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+}
+
+func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) {
+ n, err := f.Readv(t, dst)
+ if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if n > 0 {
+ // Queue notification if we read anything.
+ f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
+ }
+ return n, err
+ }
+
+ // Sockets support read timeouts.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if s, ok := f.FileOperations.(socket.Socket); ok {
+ dl := s.RecvTimeout()
+ if dl < 0 && err == syserror.ErrWouldBlock {
+ return n, err
+ }
+ if dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ }
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ f.EventRegister(&w, EventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst64(n)
+
+ // Issue the request and break out if it completes with anything
+ // other than "would block".
+ n, err = f.Readv(t, dst)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+
+ f.EventUnregister(&w)
+
+ if total > 0 {
+ // Queue notification if we read anything.
+ f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
+ }
+
+ return total, err
+}
+
+func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ n, err := f.Preadv(t, dst, offset)
+ if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if n > 0 {
+ // Queue notification if we read anything.
+ f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ f.EventRegister(&w, EventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst64(n)
+
+ // Issue the request and break out if it completes with anything
+ // other than "would block".
+ n, err = f.Preadv(t, dst, offset+total)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.Block(ch); err != nil {
+ break
+ }
+ }
+
+ f.EventUnregister(&w)
+
+ if total > 0 {
+ // Queue notification if we read anything.
+ f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
+ }
+
+ return total, err
+}
diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go
new file mode 100644
index 000000000..8b0379779
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_rlimit.go
@@ -0,0 +1,224 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/limits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// rlimit describes an implementation of 'struct rlimit', which may vary from
+// system-to-system.
+type rlimit interface {
+ // toLimit converts an rlimit to a limits.Limit.
+ toLimit() *limits.Limit
+
+ // fromLimit converts a limits.Limit to an rlimit.
+ fromLimit(lim limits.Limit)
+
+ // copyIn copies an rlimit from the untrusted app to the kernel.
+ copyIn(t *kernel.Task, addr usermem.Addr) error
+
+ // copyOut copies an rlimit from the kernel to the untrusted app.
+ copyOut(t *kernel.Task, addr usermem.Addr) error
+}
+
+// newRlimit returns the appropriate rlimit type for 'struct rlimit' on this system.
+func newRlimit(t *kernel.Task) (rlimit, error) {
+ switch t.Arch().Width() {
+ case 8:
+ // On 64-bit system, struct rlimit and struct rlimit64 are identical.
+ return &rlimit64{}, nil
+ default:
+ return nil, syserror.ENOSYS
+ }
+}
+
+type rlimit64 struct {
+ Cur uint64
+ Max uint64
+}
+
+func (r *rlimit64) toLimit() *limits.Limit {
+ return &limits.Limit{
+ Cur: limits.FromLinux(r.Cur),
+ Max: limits.FromLinux(r.Max),
+ }
+}
+
+func (r *rlimit64) fromLimit(lim limits.Limit) {
+ *r = rlimit64{
+ Cur: limits.ToLinux(lim.Cur),
+ Max: limits.ToLinux(lim.Max),
+ }
+}
+
+func (r *rlimit64) copyIn(t *kernel.Task, addr usermem.Addr) error {
+ _, err := t.CopyIn(addr, r)
+ return err
+}
+
+func (r *rlimit64) copyOut(t *kernel.Task, addr usermem.Addr) error {
+ _, err := t.CopyOut(addr, *r)
+ return err
+}
+
+func makeRlimit64(lim limits.Limit) *rlimit64 {
+ return &rlimit64{Cur: lim.Cur, Max: lim.Max}
+}
+
+// setableLimits is the set of supported setable limits.
+var setableLimits = map[limits.LimitType]struct{}{
+ limits.NumberOfFiles: {},
+ limits.AS: {},
+ limits.CPU: {},
+ limits.Data: {},
+ limits.FileSize: {},
+ limits.MemoryLocked: {},
+ limits.Stack: {},
+ // These are not enforced, but we include them here to avoid returning
+ // EPERM, since some apps expect them to succeed.
+ limits.Core: {},
+ limits.ProcessCount: {},
+}
+
+func prlimit64(t *kernel.Task, resource limits.LimitType, newLim *limits.Limit) (limits.Limit, error) {
+ if newLim == nil {
+ return t.ThreadGroup().Limits().Get(resource), nil
+ }
+
+ if _, ok := setableLimits[resource]; !ok {
+ return limits.Limit{}, syserror.EPERM
+ }
+
+ // "A privileged process (under Linux: one with the CAP_SYS_RESOURCE
+ // capability in the initial user namespace) may make arbitrary changes
+ // to either limit value."
+ privileged := t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.Kernel().RootUserNamespace())
+
+ oldLim, err := t.ThreadGroup().Limits().Set(resource, *newLim, privileged)
+ if err != nil {
+ return limits.Limit{}, err
+ }
+
+ if resource == limits.CPU {
+ t.NotifyRlimitCPUUpdated()
+ }
+ return oldLim, nil
+}
+
+// Getrlimit implements linux syscall getrlimit(2).
+func Getrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ resource, ok := limits.FromLinuxResource[int(args[0].Int())]
+ if !ok {
+ // Return err; unknown limit.
+ return 0, nil, syserror.EINVAL
+ }
+ addr := args[1].Pointer()
+ rlim, err := newRlimit(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ lim, err := prlimit64(t, resource, nil)
+ if err != nil {
+ return 0, nil, err
+ }
+ rlim.fromLimit(lim)
+ return 0, nil, rlim.copyOut(t, addr)
+}
+
+// Setrlimit implements linux syscall setrlimit(2).
+func Setrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ resource, ok := limits.FromLinuxResource[int(args[0].Int())]
+ if !ok {
+ // Return err; unknown limit.
+ return 0, nil, syserror.EINVAL
+ }
+ addr := args[1].Pointer()
+ rlim, err := newRlimit(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ if err := rlim.copyIn(t, addr); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ _, err = prlimit64(t, resource, rlim.toLimit())
+ return 0, nil, err
+}
+
+// Prlimit64 implements linux syscall prlimit64(2).
+func Prlimit64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := kernel.ThreadID(args[0].Int())
+ resource, ok := limits.FromLinuxResource[int(args[1].Int())]
+ if !ok {
+ // Return err; unknown limit.
+ return 0, nil, syserror.EINVAL
+ }
+ newRlimAddr := args[2].Pointer()
+ oldRlimAddr := args[3].Pointer()
+
+ var newLim *limits.Limit
+ if newRlimAddr != 0 {
+ var nrl rlimit64
+ if err := nrl.copyIn(t, newRlimAddr); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ newLim = nrl.toLimit()
+ }
+
+ if tid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ ot := t
+ if tid > 0 {
+ if ot = t.PIDNamespace().TaskWithID(tid); ot == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ // "To set or get the resources of a process other than itself, the caller
+ // must have the CAP_SYS_RESOURCE capability, or the real, effective, and
+ // saved set user IDs of the target process must match the real user ID of
+ // the caller and the real, effective, and saved set group IDs of the
+ // target process must match the real group ID of the caller."
+ if !t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.PIDNamespace().UserNamespace()) {
+ cred, tcred := t.Credentials(), ot.Credentials()
+ if cred.RealKUID != tcred.RealKUID ||
+ cred.RealKUID != tcred.EffectiveKUID ||
+ cred.RealKUID != tcred.SavedKUID ||
+ cred.RealKGID != tcred.RealKGID ||
+ cred.RealKGID != tcred.EffectiveKGID ||
+ cred.RealKGID != tcred.SavedKGID {
+ return 0, nil, syserror.EPERM
+ }
+ }
+
+ oldLim, err := prlimit64(ot, resource, newLim)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if oldRlimAddr != 0 {
+ if err := makeRlimit64(oldLim).copyOut(t, oldRlimAddr); err != nil {
+ return 0, nil, syserror.EFAULT
+ }
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_rusage.go b/pkg/sentry/syscalls/linux/sys_rusage.go
new file mode 100644
index 000000000..003d718da
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_rusage.go
@@ -0,0 +1,112 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+func getrusage(t *kernel.Task, which int32) linux.Rusage {
+ var cs usage.CPUStats
+
+ switch which {
+ case linux.RUSAGE_SELF:
+ cs = t.ThreadGroup().CPUStats()
+
+ case linux.RUSAGE_CHILDREN:
+ cs = t.ThreadGroup().JoinedChildCPUStats()
+
+ case linux.RUSAGE_THREAD:
+ cs = t.CPUStats()
+
+ case linux.RUSAGE_BOTH:
+ tg := t.ThreadGroup()
+ cs = tg.CPUStats()
+ cs.Accumulate(tg.JoinedChildCPUStats())
+ }
+
+ return linux.Rusage{
+ UTime: linux.NsecToTimeval(cs.UserTime.Nanoseconds()),
+ STime: linux.NsecToTimeval(cs.SysTime.Nanoseconds()),
+ NVCSw: int64(cs.VoluntarySwitches),
+ MaxRSS: int64(t.MaxRSS(which) / 1024),
+ }
+}
+
+// Getrusage implements linux syscall getrusage(2).
+// marked "y" are supported now
+// marked "*" are not used on Linux
+// marked "p" are pending for support
+//
+// y struct timeval ru_utime; /* user CPU time used */
+// y struct timeval ru_stime; /* system CPU time used */
+// p long ru_maxrss; /* maximum resident set size */
+// * long ru_ixrss; /* integral shared memory size */
+// * long ru_idrss; /* integral unshared data size */
+// * long ru_isrss; /* integral unshared stack size */
+// p long ru_minflt; /* page reclaims (soft page faults) */
+// p long ru_majflt; /* page faults (hard page faults) */
+// * long ru_nswap; /* swaps */
+// p long ru_inblock; /* block input operations */
+// p long ru_oublock; /* block output operations */
+// * long ru_msgsnd; /* IPC messages sent */
+// * long ru_msgrcv; /* IPC messages received */
+// * long ru_nsignals; /* signals received */
+// y long ru_nvcsw; /* voluntary context switches */
+// y long ru_nivcsw; /* involuntary context switches */
+func Getrusage(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ which := args[0].Int()
+ addr := args[1].Pointer()
+
+ if which != linux.RUSAGE_SELF && which != linux.RUSAGE_CHILDREN && which != linux.RUSAGE_THREAD {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ru := getrusage(t, which)
+ _, err := t.CopyOut(addr, &ru)
+ return 0, nil, err
+}
+
+// Times implements linux syscall times(2).
+func Times(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ // Calculate the ticks first, and figure out if any additional work is
+ // necessary. Linux allows for a NULL addr, in which case only the
+ // return value is meaningful. We don't need to do anything else.
+ ticks := uintptr(ktime.NowFromContext(t).Nanoseconds() / linux.ClockTick.Nanoseconds())
+ if addr == 0 {
+ return ticks, nil, nil
+ }
+
+ cs1 := t.ThreadGroup().CPUStats()
+ cs2 := t.ThreadGroup().JoinedChildCPUStats()
+ r := linux.Tms{
+ UTime: linux.ClockTFromDuration(cs1.UserTime),
+ STime: linux.ClockTFromDuration(cs1.SysTime),
+ CUTime: linux.ClockTFromDuration(cs2.UserTime),
+ CSTime: linux.ClockTFromDuration(cs2.SysTime),
+ }
+ if _, err := t.CopyOut(addr, &r); err != nil {
+ return 0, nil, err
+ }
+
+ return ticks, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go
new file mode 100644
index 000000000..8aea03abe
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_sched.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+)
+
+const (
+ onlyScheduler = linux.SCHED_NORMAL
+ onlyPriority = 0
+)
+
+// SchedParam replicates struct sched_param in sched.h.
+type SchedParam struct {
+ schedPriority int64
+}
+
+// SchedGetparam implements linux syscall sched_getparam(2).
+func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := args[0].Int()
+ param := args[1].Pointer()
+ if param == 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ if pid < 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil {
+ return 0, nil, syscall.ESRCH
+ }
+ r := SchedParam{schedPriority: onlyPriority}
+ if _, err := t.CopyOut(param, r); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// SchedGetscheduler implements linux syscall sched_getscheduler(2).
+func SchedGetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := args[0].Int()
+ if pid < 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil {
+ return 0, nil, syscall.ESRCH
+ }
+ return onlyScheduler, nil, nil
+}
+
+// SchedSetscheduler implements linux syscall sched_setscheduler(2).
+func SchedSetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := args[0].Int()
+ policy := args[1].Int()
+ param := args[2].Pointer()
+ if pid < 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ if policy != onlyScheduler {
+ return 0, nil, syscall.EINVAL
+ }
+ if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil {
+ return 0, nil, syscall.ESRCH
+ }
+ var r SchedParam
+ if _, err := t.CopyIn(param, &r); err != nil {
+ return 0, nil, syscall.EINVAL
+ }
+ if r.schedPriority != onlyPriority {
+ return 0, nil, syscall.EINVAL
+ }
+ return 0, nil, nil
+}
+
+// SchedGetPriorityMax implements linux syscall sched_get_priority_max(2).
+func SchedGetPriorityMax(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return onlyPriority, nil, nil
+}
+
+// SchedGetPriorityMin implements linux syscall sched_get_priority_min(2).
+func SchedGetPriorityMin(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return onlyPriority, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go
new file mode 100644
index 000000000..b4262162a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_seccomp.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/bpf"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// userSockFprog is equivalent to Linux's struct sock_fprog on amd64.
+type userSockFprog struct {
+ // Len is the length of the filter in BPF instructions.
+ Len uint16
+
+ _ [6]byte // padding for alignment
+
+ // Filter is a user pointer to the struct sock_filter array that makes up
+ // the filter program. Filter is a uint64 rather than a usermem.Addr
+ // because usermem.Addr is actually uintptr, which is not a fixed-size
+ // type, and encoding/binary.Read objects to this.
+ Filter uint64
+}
+
+// seccomp applies a seccomp policy to the current task.
+func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error {
+ // We only support SECCOMP_SET_MODE_FILTER at the moment.
+ if mode != linux.SECCOMP_SET_MODE_FILTER {
+ // Unsupported mode.
+ return syscall.EINVAL
+ }
+
+ tsync := flags&linux.SECCOMP_FILTER_FLAG_TSYNC != 0
+
+ // The only flag we support now is SECCOMP_FILTER_FLAG_TSYNC.
+ if flags&^linux.SECCOMP_FILTER_FLAG_TSYNC != 0 {
+ // Unsupported flag.
+ return syscall.EINVAL
+ }
+
+ var fprog userSockFprog
+ if _, err := t.CopyIn(addr, &fprog); err != nil {
+ return err
+ }
+ filter := make([]linux.BPFInstruction, int(fprog.Len))
+ if _, err := t.CopyIn(usermem.Addr(fprog.Filter), &filter); err != nil {
+ return err
+ }
+ compiledFilter, err := bpf.Compile(filter)
+ if err != nil {
+ t.Debugf("Invalid seccomp-bpf filter: %v", err)
+ return syscall.EINVAL
+ }
+
+ return t.AppendSyscallFilter(compiledFilter, tsync)
+}
+
+// Seccomp implements linux syscall seccomp(2).
+func Seccomp(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, seccomp(t, args[0].Uint64(), args[1].Uint64(), args[2].Pointer())
+}
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
new file mode 100644
index 000000000..5bd61ab87
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -0,0 +1,241 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const opsMax = 500 // SEMOPM
+
+// Semget handles: semget(key_t key, int nsems, int semflg)
+func Semget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ key := args[0].Int()
+ nsems := args[1].Int()
+ flag := args[2].Int()
+
+ private := key == linux.IPC_PRIVATE
+ create := flag&linux.IPC_CREAT == linux.IPC_CREAT
+ exclusive := flag&linux.IPC_EXCL == linux.IPC_EXCL
+ mode := linux.FileMode(flag & 0777)
+
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set, err := r.FindOrCreate(t, key, nsems, mode, private, create, exclusive)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(set.ID), nil, nil
+}
+
+// Semop handles: semop(int semid, struct sembuf *sops, size_t nsops)
+func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Int()
+ sembufAddr := args[1].Pointer()
+ nsops := args[2].SizeT()
+
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, nil, syserror.EINVAL
+ }
+ if nsops <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if nsops > opsMax {
+ return 0, nil, syserror.E2BIG
+ }
+
+ ops := make([]linux.Sembuf, nsops)
+ if _, err := t.CopyIn(sembufAddr, ops); err != nil {
+ return 0, nil, err
+ }
+
+ creds := auth.CredentialsFromContext(t)
+ pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
+ for {
+ ch, num, err := set.ExecuteOps(t, ops, creds, int32(pid))
+ if ch == nil || err != nil {
+ // We're done (either on success or a failure).
+ return 0, nil, err
+ }
+ if err = t.Block(ch); err != nil {
+ set.AbortWait(num, ch)
+ return 0, nil, err
+ }
+ }
+}
+
+// Semctl handles: semctl(int semid, int semnum, int cmd, ...)
+func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Int()
+ num := args[1].Int()
+ cmd := args[2].Int()
+
+ switch cmd {
+ case linux.SETVAL:
+ val := args[3].Int()
+ if val > math.MaxInt16 {
+ return 0, nil, syserror.ERANGE
+ }
+ return 0, nil, setVal(t, id, num, int16(val))
+
+ case linux.SETALL:
+ array := args[3].Pointer()
+ return 0, nil, setValAll(t, id, array)
+
+ case linux.GETVAL:
+ v, err := getVal(t, id, num)
+ return uintptr(v), nil, err
+
+ case linux.GETALL:
+ array := args[3].Pointer()
+ return 0, nil, getValAll(t, id, array)
+
+ case linux.IPC_RMID:
+ return 0, nil, remove(t, id)
+
+ case linux.IPC_SET:
+ arg := args[3].Pointer()
+ s := linux.SemidDS{}
+ if _, err := t.CopyIn(arg, &s); err != nil {
+ return 0, nil, err
+ }
+
+ perms := fs.FilePermsFromMode(linux.FileMode(s.SemPerm.Mode & 0777))
+ return 0, nil, ipcSet(t, id, auth.UID(s.SemPerm.UID), auth.GID(s.SemPerm.GID), perms)
+
+ case linux.GETPID:
+ v, err := getPID(t, id, num)
+ return uintptr(v), nil, err
+
+ case linux.IPC_INFO,
+ linux.SEM_INFO,
+ linux.IPC_STAT,
+ linux.SEM_STAT,
+ linux.SEM_STAT_ANY,
+ linux.GETNCNT,
+ linux.GETZCNT:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ fallthrough
+
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+func remove(t *kernel.Task, id int32) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ creds := auth.CredentialsFromContext(t)
+ return r.RemoveID(id, creds)
+}
+
+func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FilePermissions) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+
+ creds := auth.CredentialsFromContext(t)
+ kuid := creds.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ kgid := creds.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ owner := fs.FileOwner{UID: kuid, GID: kgid}
+ return set.Change(t, creds, owner, perms)
+}
+
+func setVal(t *kernel.Task, id int32, num int32, val int16) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
+ return set.SetVal(t, num, val, creds, int32(pid))
+}
+
+func setValAll(t *kernel.Task, id int32, array usermem.Addr) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+ vals := make([]uint16, set.Size())
+ if _, err := t.CopyIn(array, vals); err != nil {
+ return err
+ }
+ creds := auth.CredentialsFromContext(t)
+ pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
+ return set.SetValAll(t, vals, creds, int32(pid))
+}
+
+func getVal(t *kernel.Task, id int32, num int32) (int16, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.GetVal(num, creds)
+}
+
+func getValAll(t *kernel.Task, id int32, array usermem.Addr) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ vals, err := set.GetValAll(creds)
+ if err != nil {
+ return err
+ }
+ _, err = t.CopyOut(array, vals)
+ return err
+}
+
+func getPID(t *kernel.Task, id int32, num int32) (int32, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ gpid, err := set.GetPID(num, creds)
+ if err != nil {
+ return 0, err
+ }
+ // Convert pid from init namespace to the caller's namespace.
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(gpid))
+ if tg == nil {
+ return 0, nil
+ }
+ return int32(tg.ID()), nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
new file mode 100644
index 000000000..d0eceac7c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -0,0 +1,156 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/shm"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Shmget implements shmget(2).
+func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ key := shm.Key(args[0].Int())
+ size := uint64(args[1].SizeT())
+ flag := args[2].Int()
+
+ private := key == linux.IPC_PRIVATE
+ create := flag&linux.IPC_CREAT == linux.IPC_CREAT
+ exclusive := flag&linux.IPC_EXCL == linux.IPC_EXCL
+ mode := linux.FileMode(flag & 0777)
+
+ pid := int32(t.ThreadGroup().ID())
+ r := t.IPCNamespace().ShmRegistry()
+ segment, err := r.FindOrCreate(t, pid, key, size, mode, private, create, exclusive)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(segment.ID), nil, nil
+}
+
+// findSegment retrives a shm segment by the given id.
+func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
+ r := t.IPCNamespace().ShmRegistry()
+ segment := r.FindByID(id)
+ if segment == nil {
+ // No segment with provided id.
+ return nil, syserror.EINVAL
+ }
+ return segment, nil
+}
+
+// Shmat implements shmat(2).
+func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := shm.ID(args[0].Int())
+ addr := args[1].Pointer()
+ flag := args[2].Int()
+
+ segment, err := findSegment(t, id)
+ if err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{
+ Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC,
+ Readonly: flag&linux.SHM_RDONLY == linux.SHM_RDONLY,
+ Remap: flag&linux.SHM_REMAP == linux.SHM_REMAP,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ defer segment.DecRef()
+ addr, err = t.MemoryManager().MMap(t, opts)
+ return uintptr(addr), nil, err
+}
+
+// Shmdt implements shmdt(2).
+func Shmdt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ err := t.MemoryManager().DetachShm(t, addr)
+ return 0, nil, err
+}
+
+// Shmctl implements shmctl(2).
+func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := shm.ID(args[0].Int())
+ cmd := args[1].Int()
+ buf := args[2].Pointer()
+
+ r := t.IPCNamespace().ShmRegistry()
+
+ switch cmd {
+ case linux.SHM_STAT:
+ // Technically, we should be treating id as "an index into the kernel's
+ // internal array that maintains information about all shared memory
+ // segments on the system". Since we don't track segments in an array,
+ // we'll just pretend the shmid is the index and do the same thing as
+ // IPC_STAT. Linux also uses the index as the shmid.
+ fallthrough
+ case linux.IPC_STAT:
+ segment, err := findSegment(t, id)
+ if err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+
+ stat, err := segment.IPCStat(t)
+ if err == nil {
+ _, err = t.CopyOut(buf, stat)
+ }
+ return 0, nil, err
+
+ case linux.IPC_INFO:
+ params := r.IPCInfo()
+ _, err := t.CopyOut(buf, params)
+ return 0, nil, err
+
+ case linux.SHM_INFO:
+ info := r.ShmInfo()
+ _, err := t.CopyOut(buf, info)
+ return 0, nil, err
+ }
+
+ // Remaining commands refer to a specific segment.
+ segment, err := findSegment(t, id)
+ if err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+
+ switch cmd {
+ case linux.IPC_SET:
+ var ds linux.ShmidDS
+ _, err = t.CopyIn(buf, &ds)
+ if err != nil {
+ return 0, nil, err
+ }
+ err = segment.Set(t, &ds)
+ return 0, nil, err
+
+ case linux.IPC_RMID:
+ segment.MarkDestroyed()
+ return 0, nil, nil
+
+ case linux.SHM_LOCK, linux.SHM_UNLOCK:
+ // We currently do not support memory locking anywhere.
+ // mlock(2)/munlock(2) are currently stubbed out as no-ops so do the
+ // same here.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, nil
+
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
new file mode 100644
index 000000000..7fbeb4fcd
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -0,0 +1,508 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "math"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// "For a process to have permission to send a signal it must
+// - either be privileged (CAP_KILL), or
+// - the real or effective user ID of the sending process must be equal to the
+// real or saved set-user-ID of the target process.
+//
+// In the case of SIGCONT it suffices when the sending and receiving processes
+// belong to the same session." - kill(2)
+//
+// Equivalent to kernel/signal.c:check_kill_permission.
+func mayKill(t *kernel.Task, target *kernel.Task, sig linux.Signal) bool {
+ // kernel/signal.c:check_kill_permission also allows a signal if the
+ // sending and receiving tasks share a thread group, which is not
+ // mentioned in kill(2) since kill does not allow task-level
+ // granularity in signal sending.
+ if t.ThreadGroup() == target.ThreadGroup() {
+ return true
+ }
+
+ if t.HasCapabilityIn(linux.CAP_KILL, target.UserNamespace()) {
+ return true
+ }
+
+ creds := t.Credentials()
+ tcreds := target.Credentials()
+ if creds.EffectiveKUID == tcreds.SavedKUID ||
+ creds.EffectiveKUID == tcreds.RealKUID ||
+ creds.RealKUID == tcreds.SavedKUID ||
+ creds.RealKUID == tcreds.RealKUID {
+ return true
+ }
+
+ if sig == linux.SIGCONT && target.ThreadGroup().Session() == t.ThreadGroup().Session() {
+ return true
+ }
+ return false
+}
+
+// Kill implements linux syscall kill(2).
+func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := kernel.ThreadID(args[0].Int())
+ sig := linux.Signal(args[1].Int())
+
+ switch {
+ case pid > 0:
+ // "If pid is positive, then signal sig is sent to the process with the
+ // ID specified by pid." - kill(2)
+ // This loops to handle races with execve where target dies between
+ // TaskWithID and SendGroupSignal. Compare Linux's
+ // kernel/signal.c:kill_pid_info().
+ for {
+ target := t.PIDNamespace().TaskWithID(pid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ info.SetPid(int32(target.PIDNamespace().IDOfTask(t)))
+ info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
+ if err := target.SendGroupSignal(info); err != syserror.ESRCH {
+ return 0, nil, err
+ }
+ }
+ case pid == -1:
+ // "If pid equals -1, then sig is sent to every process for which the
+ // calling process has permission to send signals, except for process 1
+ // (init), but see below. ... POSIX.1-2001 requires that kill(-1,sig)
+ // send sig to all processes that the calling process may send signals
+ // to, except possibly for some implementation-defined system
+ // processes. Linux allows a process to signal itself, but on Linux the
+ // call kill(-1,sig) does not signal the calling process."
+ var (
+ lastErr error
+ delivered int
+ )
+ for _, tg := range t.PIDNamespace().ThreadGroups() {
+ if tg == t.ThreadGroup() {
+ continue
+ }
+ if t.PIDNamespace().IDOfThreadGroup(tg) == kernel.InitTID {
+ continue
+ }
+
+ // If pid == -1, the returned error is the last non-EPERM error
+ // from any call to group_send_sig_info.
+ if !mayKill(t, tg.Leader(), sig) {
+ continue
+ }
+ // Here and below, whether or not kill returns an error may
+ // depend on the iteration order. We at least implement the
+ // semantics documented by the man page: "On success (at least
+ // one signal was sent), zero is returned."
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ err := tg.SendSignal(info)
+ if err == syserror.ESRCH {
+ // ESRCH is ignored because it means the task
+ // exited while we were iterating. This is a
+ // race which would not normally exist on
+ // Linux, so we suppress it.
+ continue
+ }
+ delivered++
+ if err != nil {
+ lastErr = err
+ }
+ }
+ if delivered > 0 {
+ return 0, nil, lastErr
+ }
+ return 0, nil, syserror.ESRCH
+ default:
+ // "If pid equals 0, then sig is sent to every process in the process
+ // group of the calling process."
+ //
+ // "If pid is less than -1, then sig is sent to every process
+ // in the process group whose ID is -pid."
+ pgid := kernel.ProcessGroupID(-pid)
+ if pgid == 0 {
+ pgid = t.PIDNamespace().IDOfProcessGroup(t.ThreadGroup().ProcessGroup())
+ }
+
+ // If pid != -1 (i.e. signalling a process group), the returned error
+ // is the last error from any call to group_send_sig_info.
+ lastErr := syserror.ESRCH
+ for _, tg := range t.PIDNamespace().ThreadGroups() {
+ if t.PIDNamespace().IDOfProcessGroup(tg.ProcessGroup()) == pgid {
+ if !mayKill(t, tg.Leader(), sig) {
+ lastErr = syserror.EPERM
+ continue
+ }
+
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoUser,
+ }
+ info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ // See note above regarding ESRCH race above.
+ if err := tg.SendSignal(info); err != syserror.ESRCH {
+ lastErr = err
+ }
+ }
+ }
+
+ return 0, nil, lastErr
+ }
+}
+
+func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalInfo {
+ info := &arch.SignalInfo{
+ Signo: int32(sig),
+ Code: arch.SignalInfoTkill,
+ }
+ info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
+ info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ return info
+}
+
+// Tkill implements linux syscall tkill(2).
+func Tkill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := kernel.ThreadID(args[0].Int())
+ sig := linux.Signal(args[1].Int())
+
+ // N.B. Inconsistent with man page, linux actually rejects calls with
+ // tid <=0 by EINVAL. This isn't the same for all signal calls.
+ if tid <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ target := t.PIDNamespace().TaskWithID(tid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+ return 0, nil, target.SendSignal(tkillSigInfo(t, target, sig))
+}
+
+// Tgkill implements linux syscall tgkill(2).
+func Tgkill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tgid := kernel.ThreadID(args[0].Int())
+ tid := kernel.ThreadID(args[1].Int())
+ sig := linux.Signal(args[2].Int())
+
+ // N.B. Inconsistent with man page, linux actually rejects calls with
+ // tgid/tid <=0 by EINVAL. This isn't the same for all signal calls.
+ if tgid <= 0 || tid <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ targetTG := t.PIDNamespace().ThreadGroupWithID(tgid)
+ target := t.PIDNamespace().TaskWithID(tid)
+ if targetTG == nil || target == nil || target.ThreadGroup() != targetTG {
+ return 0, nil, syserror.ESRCH
+ }
+
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+ return 0, nil, target.SendSignal(tkillSigInfo(t, target, sig))
+}
+
+// RtSigaction implements linux syscall rt_sigaction(2).
+func RtSigaction(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sig := linux.Signal(args[0].Int())
+ newactarg := args[1].Pointer()
+ oldactarg := args[2].Pointer()
+
+ var newactptr *arch.SignalAct
+ if newactarg != 0 {
+ newact, err := t.CopyInSignalAct(newactarg)
+ if err != nil {
+ return 0, nil, err
+ }
+ newactptr = &newact
+ }
+ oldact, err := t.ThreadGroup().SetSignalAct(sig, newactptr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if oldactarg != 0 {
+ if err := t.CopyOutSignalAct(oldactarg, &oldact); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// Sigreturn implements linux syscall sigreturn(2).
+func Sigreturn(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ctrl, err := t.SignalReturn(false)
+ return 0, ctrl, err
+}
+
+// RtSigreturn implements linux syscall rt_sigreturn(2).
+func RtSigreturn(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ ctrl, err := t.SignalReturn(true)
+ return 0, ctrl, err
+}
+
+// RtSigprocmask implements linux syscall rt_sigprocmask(2).
+func RtSigprocmask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ how := args[0].Int()
+ setaddr := args[1].Pointer()
+ oldaddr := args[2].Pointer()
+ sigsetsize := args[3].SizeT()
+
+ if sigsetsize != linux.SignalSetSize {
+ return 0, nil, syserror.EINVAL
+ }
+ oldmask := t.SignalMask()
+ if setaddr != 0 {
+ mask, err := copyInSigSet(t, setaddr, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch how {
+ case linux.SIG_BLOCK:
+ t.SetSignalMask(oldmask | mask)
+ case linux.SIG_UNBLOCK:
+ t.SetSignalMask(oldmask &^ mask)
+ case linux.SIG_SETMASK:
+ t.SetSignalMask(mask)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ }
+ if oldaddr != 0 {
+ return 0, nil, copyOutSigSet(t, oldaddr, oldmask)
+ }
+
+ return 0, nil, nil
+}
+
+// Sigaltstack implements linux syscall sigaltstack(2).
+func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ setaddr := args[0].Pointer()
+ oldaddr := args[1].Pointer()
+
+ alt := t.SignalStack()
+ if oldaddr != 0 {
+ if err := t.CopyOutSignalStack(oldaddr, &alt); err != nil {
+ return 0, nil, err
+ }
+ }
+ if setaddr != 0 {
+ alt, err := t.CopyInSignalStack(setaddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ // The signal stack cannot be changed if the task is currently
+ // on the stack. This is enforced at the lowest level because
+ // these semantics apply to changing the signal stack via a
+ // ucontext during a signal handler.
+ if !t.SetSignalStack(alt) {
+ return 0, nil, syserror.EPERM
+ }
+ }
+
+ return 0, nil, nil
+}
+
+// Pause implements linux syscall pause(2).
+func Pause(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.ConvertIntr(t.Block(nil), kernel.ERESTARTNOHAND)
+}
+
+// RtSigpending implements linux syscall rt_sigpending(2).
+func RtSigpending(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ pending := t.PendingSignals()
+ _, err := t.CopyOut(addr, pending)
+ return 0, nil, err
+}
+
+// RtSigtimedwait implements linux syscall rt_sigtimedwait(2).
+func RtSigtimedwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sigset := args[0].Pointer()
+ siginfo := args[1].Pointer()
+ timespec := args[2].Pointer()
+ sigsetsize := args[3].SizeT()
+
+ mask, err := copyInSigSet(t, sigset, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var timeout time.Duration
+ if timespec != 0 {
+ d, err := copyTimespecIn(t, timespec)
+ if err != nil {
+ return 0, nil, err
+ }
+ if !d.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(d.ToNsecCapped())
+ } else {
+ timeout = time.Duration(math.MaxInt64)
+ }
+
+ si, err := t.Sigtimedwait(mask, timeout)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if siginfo != 0 {
+ si.FixSignalCodeForUser()
+ if _, err := t.CopyOut(siginfo, si); err != nil {
+ return 0, nil, err
+ }
+ }
+ return uintptr(si.Signo), nil, nil
+}
+
+// RtSigqueueinfo implements linux syscall rt_sigqueueinfo(2).
+func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := kernel.ThreadID(args[0].Int())
+ sig := linux.Signal(args[1].Int())
+ infoAddr := args[2].Pointer()
+
+ // Copy in the info.
+ //
+ // We must ensure that the Signo is set (Linux overrides this in the
+ // same way), and that the code is in the allowed set. This same logic
+ // appears below in RtSigtgqueueinfo and should be kept in sync.
+ var info arch.SignalInfo
+ if _, err := t.CopyIn(infoAddr, &info); err != nil {
+ return 0, nil, err
+ }
+ info.Signo = int32(sig)
+
+ // This must loop to handle the race with execve described in Kill.
+ for {
+ // Deliver to the given task's thread group.
+ target := t.PIDNamespace().TaskWithID(pid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ // If the sender is not the receiver, it can't use si_codes used by the
+ // kernel or SI_TKILL.
+ if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t {
+ return 0, nil, syserror.EPERM
+ }
+
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+
+ if err := target.SendGroupSignal(&info); err != syserror.ESRCH {
+ return 0, nil, err
+ }
+ }
+}
+
+// RtTgsigqueueinfo implements linux syscall rt_tgsigqueueinfo(2).
+func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tgid := kernel.ThreadID(args[0].Int())
+ tid := kernel.ThreadID(args[1].Int())
+ sig := linux.Signal(args[2].Int())
+ infoAddr := args[3].Pointer()
+
+ // N.B. Inconsistent with man page, linux actually rejects calls with
+ // tgid/tid <=0 by EINVAL. This isn't the same for all signal calls.
+ if tgid <= 0 || tid <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy in the info. See RtSigqueueinfo above.
+ var info arch.SignalInfo
+ if _, err := t.CopyIn(infoAddr, &info); err != nil {
+ return 0, nil, err
+ }
+ info.Signo = int32(sig)
+
+ // Deliver to the given task.
+ targetTG := t.PIDNamespace().ThreadGroupWithID(tgid)
+ target := t.PIDNamespace().TaskWithID(tid)
+ if targetTG == nil || target == nil || target.ThreadGroup() != targetTG {
+ return 0, nil, syserror.ESRCH
+ }
+
+ // If the sender is not the receiver, it can't use si_codes used by the
+ // kernel or SI_TKILL.
+ if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t {
+ return 0, nil, syserror.EPERM
+ }
+
+ if !mayKill(t, target, sig) {
+ return 0, nil, syserror.EPERM
+ }
+ return 0, nil, target.SendSignal(&info)
+}
+
+// RtSigsuspend implements linux syscall rt_sigsuspend(2).
+func RtSigsuspend(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sigset := args[0].Pointer()
+
+ // Copy in the signal mask.
+ var mask linux.SignalSet
+ if _, err := t.CopyIn(sigset, &mask); err != nil {
+ return 0, nil, err
+ }
+ mask &^= kernel.UnblockableSignals
+
+ // Swap the mask.
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+
+ // Perform the wait.
+ return 0, nil, syserror.ConvertIntr(t.Block(nil), kernel.ERESTARTNOHAND)
+}
+
+// RestartSyscall implements the linux syscall restart_syscall(2).
+func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ if r := t.SyscallRestartBlock(); r != nil {
+ n, err := r.Restart(t)
+ return n, nil, err
+ }
+ // The restart block should never be nil here, but it's possible
+ // ERESTART_RESTARTBLOCK was set by ptrace without the current syscall
+ // setting up a restart block. If ptrace didn't manipulate the return value,
+ // finding a nil restart block is a bug. Linux ensures that the restart
+ // function is never null by (re)initializing it with one that translates
+ // the restart into EINTR. We'll emulate that behaviour.
+ t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?")
+ return 0, nil, syserror.EINTR
+}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
new file mode 100644
index 000000000..8f4dbf3bc
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -0,0 +1,1117 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// minListenBacklog is the minimum reasonable backlog for listening sockets.
+const minListenBacklog = 8
+
+// maxListenBacklog is the maximum allowed backlog for listening sockets.
+const maxListenBacklog = 1024
+
+// maxAddrLen is the maximum socket address length we're willing to accept.
+const maxAddrLen = 200
+
+// maxOptLen is the maximum sockopt parameter length we're willing to accept.
+const maxOptLen = 1024
+
+// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
+// willing to accept. Note that this limit is smaller than Linux, which allows
+// buffers upto INT_MAX.
+const maxControlLen = 10 * 1024 * 1024
+
+// nameLenOffset is the offset from the start of the MessageHeader64 struct to
+// the NameLen field.
+const nameLenOffset = 8
+
+// controlLenOffset is the offset form the start of the MessageHeader64 struct
+// to the ControlLen field.
+const controlLenOffset = 40
+
+// flagsOffset is the offset form the start of the MessageHeader64 struct
+// to the Flags field.
+const flagsOffset = 48
+
+// messageHeader64Len is the length of a MessageHeader64 struct.
+var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+
+// multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct.
+var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+
+// baseRecvFlags are the flags that are accepted across recvmsg(2),
+// recvmmsg(2), and recvfrom(2).
+const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | linux.MSG_NOSIGNAL | linux.MSG_WAITALL | linux.MSG_TRUNC | linux.MSG_CTRUNC
+
+// MessageHeader64 is the 64-bit representation of the msghdr struct used in
+// the recvmsg and sendmsg syscalls.
+type MessageHeader64 struct {
+ // Name is the optional pointer to a network address buffer.
+ Name uint64
+
+ // NameLen is the length of the buffer pointed to by Name.
+ NameLen uint32
+ _ uint32
+
+ // Iov is a pointer to an array of io vectors that describe the memory
+ // locations involved in the io operation.
+ Iov uint64
+
+ // IovLen is the length of the array pointed to by Iov.
+ IovLen uint64
+
+ // Control is the optional pointer to ancillary control data.
+ Control uint64
+
+ // ControlLen is the length of the data pointed to by Control.
+ ControlLen uint64
+
+ // Flags on the sent/received message.
+ Flags int32
+ _ int32
+}
+
+// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
+// the recvmmsg and sendmmsg syscalls.
+type multipleMessageHeader64 struct {
+ msgHdr MessageHeader64
+ msgLen uint32
+ _ int32
+}
+
+// CopyInMessageHeader64 copies a message header from user to kernel memory.
+func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
+ b := t.CopyScratchBuffer(52)
+ if _, err := t.CopyInBytes(addr, b); err != nil {
+ return err
+ }
+
+ msg.Name = usermem.ByteOrder.Uint64(b[0:])
+ msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
+ msg.Iov = usermem.ByteOrder.Uint64(b[16:])
+ msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
+ msg.Control = usermem.ByteOrder.Uint64(b[32:])
+ msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
+ msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
+
+ return nil
+}
+
+// CaptureAddress allocates memory for and copies a socket address structure
+// from the untrusted address space range.
+func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
+ if addrlen > maxAddrLen {
+ return nil, syscall.EINVAL
+ }
+
+ addrBuf := make([]byte, addrlen)
+ if _, err := t.CopyInBytes(addr, addrBuf); err != nil {
+ return nil, err
+ }
+
+ return addrBuf, nil
+}
+
+// writeAddress writes a sockaddr structure and its length to an output buffer
+// in the unstrusted address space range. If the address is bigger than the
+// buffer, it is truncated.
+func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+ // Get the buffer length.
+ var bufLen uint32
+ if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ return err
+ }
+
+ if int32(bufLen) < 0 {
+ return syscall.EINVAL
+ }
+
+ // Write the length unconditionally.
+ if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ return err
+ }
+
+ if addr == nil {
+ return nil
+ }
+
+ if bufLen > addrLen {
+ bufLen = addrLen
+ }
+
+ // Copy as much of the address as will fit in the buffer.
+ encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ if bufLen > uint32(len(encodedAddr)) {
+ bufLen = uint32(len(encodedAddr))
+ }
+ _, err := t.CopyOutBytes(addrPtr, encodedAddr[:int(bufLen)])
+ return err
+}
+
+// Socket implements the linux syscall socket(2).
+func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+
+ // Create the new socket.
+ s, e := socket.New(t, domain, transport.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ s.SetFlags(fs.SettableFileFlags{
+ NonBlocking: stype&linux.SOCK_NONBLOCK != 0,
+ })
+ defer s.DecRef()
+
+ fd, err := t.FDMap().NewFDFrom(0, s, kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ }, t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// SocketPair implements the linux syscall socketpair(2).
+func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+ socks := args[3].Pointer()
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+
+ fileFlags := fs.SettableFileFlags{
+ NonBlocking: stype&linux.SOCK_NONBLOCK != 0,
+ }
+ fdFlags := kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ }
+
+ // Create the socket pair.
+ s1, s2, e := socket.Pair(t, domain, transport.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ s1.SetFlags(fileFlags)
+ s2.SetFlags(fileFlags)
+ defer s1.DecRef()
+ defer s2.DecRef()
+
+ // Create the FDs for the sockets.
+ fd1, err := t.FDMap().NewFDFrom(0, s1, fdFlags, t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, nil, err
+ }
+ fd2, err := t.FDMap().NewFDFrom(0, s2, fdFlags, t.ThreadGroup().Limits())
+ if err != nil {
+ t.FDMap().Remove(fd1)
+ return 0, nil, err
+ }
+
+ // Copy the file descriptors out.
+ if _, err := t.CopyOut(socks, []int32{int32(fd1), int32(fd2)}); err != nil {
+ t.FDMap().Remove(fd1)
+ t.FDMap().Remove(fd2)
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// Connect implements the linux syscall connect(2).
+func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ blocking := !file.Flags().NonBlocking
+ return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS)
+}
+
+// accept is the implementation of the accept syscall. It is called by accept
+// and accept4 syscall handlers.
+func accept(t *kernel.Task, fd kdefs.FD, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) {
+ // Check that no unsupported flags are passed in.
+ if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, syscall.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, syscall.ENOTSOCK
+ }
+
+ // Call the syscall implementation for this socket, then copy the
+ // output address if one is specified.
+ blocking := !file.Flags().NonBlocking
+
+ peerRequested := addrLen != 0
+ nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ if peerRequested {
+ // NOTE(magi): Linux does not give you an error if it can't
+ // write the data back out so neither do we.
+ if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syscall.EINVAL {
+ return 0, err
+ }
+ }
+ return uintptr(nfd), nil
+}
+
+// Accept4 implements the linux syscall accept4(2).
+func Accept4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+ flags := int(args[3].Int())
+
+ n, err := accept(t, fd, addr, addrlen, flags)
+ return n, nil, err
+}
+
+// Accept implements the linux syscall accept(2).
+func Accept(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ n, err := accept(t, fd, addr, addrlen, 0)
+ return n, nil, err
+}
+
+// Bind implements the linux syscall bind(2).
+func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, s.Bind(t, a).ToError()
+}
+
+// Listen implements the linux syscall listen(2).
+func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ backlog := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Per Linux, the backlog is silently capped to reasonable values.
+ if backlog <= 0 {
+ backlog = minListenBacklog
+ }
+ if backlog > maxListenBacklog {
+ backlog = maxListenBacklog
+ }
+
+ return 0, nil, s.Listen(t, int(backlog)).ToError()
+}
+
+// Shutdown implements the linux syscall shutdown(2).
+func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ how := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Validate how, then call syscall implementation.
+ switch how {
+ case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR:
+ default:
+ return 0, nil, syscall.EINVAL
+ }
+
+ return 0, nil, s.Shutdown(t, int(how)).ToError()
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2).
+func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLenAddr := args[4].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Read the length if present. Reject negative values.
+ optLen := int32(0)
+ if optLenAddr != 0 {
+ if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ return 0, nil, err
+ }
+
+ if optLen < 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ }
+
+ // Call syscall implementation then copy both value and value len out.
+ v, e := s.GetSockOpt(t, int(level), int(name), int(optLen))
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+
+ if optLenAddr != 0 {
+ vLen := int32(binary.Size(v))
+ if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if v != nil {
+ if _, err := t.CopyOut(optValAddr, v); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ return 0, nil, nil
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2).
+//
+// Note that unlike Linux, enabling SO_PASSCRED does not autobind the socket.
+func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLen := args[4].Int()
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ if optLen <= 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ if optLen > maxOptLen {
+ return 0, nil, syscall.EINVAL
+ }
+ buf := t.CopyScratchBuffer(int(optLen))
+ if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ return 0, nil, err
+ }
+
+ // Call syscall implementation.
+ if err := s.SetSockOpt(t, int(level), int(name), buf); err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2).
+func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Get the socket name and copy it to the caller.
+ v, vl, err := s.GetSockName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// GetPeerName implements the linux syscall getpeername(2).
+func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Get the socket peer name and copy it to the caller.
+ v, vl, err := s.GetPeerName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// RecvMsg implements the linux syscall recvmsg(2).
+func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syscall.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline)
+ return n, nil, err
+}
+
+// RecvMMsg implements the linux syscall recvmmsg(2).
+func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+ toPtr := args[4].Pointer()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syscall.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if toPtr != 0 {
+ ts, err := copyTimespecIn(t, toPtr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if !ts.Valid() {
+ return 0, nil, syscall.EINVAL
+ }
+ deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration())
+ haveDeadline = true
+ }
+
+ if !haveDeadline {
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syscall.EFAULT
+ }
+ var n uintptr
+ if n, err = recvSingleMsg(t, s, mp, flags, haveDeadline, deadline); err != nil {
+ break
+ }
+
+ // Copy the received length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syscall.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
+ // Capture the message header and io vectors.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syscall.EMSGSIZE
+ }
+ dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ // FIXME(b/63594852): Pretend we have an empty error queue.
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return 0, syscall.EAGAIN
+ }
+
+ // Fast path when no control message nor name buffers are provided.
+ if msg.ControlLen == 0 && msg.NameLen == 0 {
+ n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
+ if err != nil {
+ return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS)
+ }
+ if !cms.Unix.Empty() {
+ mflags |= linux.MSG_CTRUNC
+ cms.Unix.Release()
+ }
+
+ if int(msg.Flags) != mflags {
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+ }
+
+ if msg.ControlLen > maxControlLen {
+ return 0, syscall.ENOBUFS
+ }
+ n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ defer cms.Unix.Release()
+
+ controlData := make([]byte, 0, msg.ControlLen)
+
+ if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
+ creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
+ controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
+ }
+
+ if cms.IP.HasTimestamp {
+ controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData)
+ }
+
+ if cms.Unix.Rights != nil {
+ controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
+ }
+
+ // Copy the address to the caller.
+ if msg.NameLen != 0 {
+ if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy the control data to the caller.
+ if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ return 0, err
+ }
+ if len(controlData) > 0 {
+ if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+
+ return uintptr(n), nil
+}
+
+// recvFrom is the implementation of the recvfrom syscall. It is called by
+// recvfrom and recv syscall handlers.
+func recvFrom(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) {
+ if int(bufLen) < 0 {
+ return 0, syscall.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 {
+ return 0, syscall.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, syscall.ENOTSOCK
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ dst, err := t.SingleIOSequence(bufPtr, int(bufLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
+ cm.Unix.Release()
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+
+ // Copy the address to the caller.
+ if nameLenPtr != 0 {
+ if err := writeAddress(t, sender, senderLen, namePtr, nameLenPtr); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+}
+
+// RecvFrom implements the linux syscall recvfrom(2).
+func RecvFrom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLenPtr := args[5].Pointer()
+
+ n, err := recvFrom(t, fd, bufPtr, bufLen, flags, namePtr, nameLenPtr)
+ return n, nil, err
+}
+
+// SendMsg implements the linux syscall sendmsg(2).
+func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syscall.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := sendSingleMsg(t, s, file, msgPtr, flags)
+ return n, nil, err
+}
+
+// SendMMsg implements the linux syscall sendmmsg(2).
+func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syscall.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, nil, syscall.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syscall.EFAULT
+ }
+ var n uintptr
+ if n, err = sendSingleMsg(t, s, file, mp, flags); err != nil {
+ break
+ }
+
+ // Copy the received length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syscall.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr usermem.Addr, flags int32) (uintptr, error) {
+ // Capture the message header.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ var controlData []byte
+ if msg.ControlLen > 0 {
+ // Put an upper bound to prevent large allocations.
+ if msg.ControlLen > maxControlLen {
+ return 0, syscall.ENOBUFS
+ }
+ controlData = make([]byte, msg.ControlLen)
+ if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ if msg.NameLen != 0 {
+ var err error
+ to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ // Read data then call the sendmsg implementation.
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syscall.EMSGSIZE
+ }
+ src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ controlMessages, err := control.Parse(t, s, controlData)
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: controlMessages})
+ err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
+ if err != nil {
+ controlMessages.Release()
+ }
+ return uintptr(n), err
+}
+
+// sendTo is the implementation of the sendto syscall. It is called by sendto
+// and send syscall handlers.
+func sendTo(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) {
+ bl := int(bufLen)
+ if bl < 0 {
+ return 0, syscall.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, syscall.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.FileOperations.(socket.Socket)
+ if !ok {
+ return 0, syscall.ENOTSOCK
+ }
+
+ if file.Flags().NonBlocking {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ var err error
+ if namePtr != 0 {
+ to, err = CaptureAddress(t, namePtr, nameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ src, err := t.SingleIOSequence(bufPtr, bl, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)})
+ return uintptr(n), handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file)
+}
+
+// SendTo implements the linux syscall sendto(2).
+func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLen := args[5].Uint()
+
+ n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
+ return n, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
new file mode 100644
index 000000000..37303606f
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -0,0 +1,293 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// doSplice implements a blocking splice operation.
+func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) {
+ var (
+ total int64
+ n int64
+ err error
+ ch chan struct{}
+ inW bool
+ outW bool
+ )
+ for opts.Length > 0 {
+ n, err = fs.Splice(t, outFile, inFile, opts)
+ opts.Length -= n
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ } else if err == syserror.ErrWouldBlock && nonBlocking {
+ break
+ }
+
+ // Are we a registered waiter?
+ if ch == nil {
+ ch = make(chan struct{}, 1)
+ }
+ if !inW && inFile.Readiness(EventMaskRead) == 0 && !inFile.Flags().NonBlocking {
+ w, _ := waiter.NewChannelEntry(ch)
+ inFile.EventRegister(&w, EventMaskRead)
+ defer inFile.EventUnregister(&w)
+ inW = true // Registered.
+ } else if !outW && outFile.Readiness(EventMaskWrite) == 0 && !outFile.Flags().NonBlocking {
+ w, _ := waiter.NewChannelEntry(ch)
+ outFile.EventRegister(&w, EventMaskWrite)
+ defer outFile.EventUnregister(&w)
+ outW = true // Registered.
+ }
+
+ // Was anything registered? If no, everything is non-blocking.
+ if !inW && !outW {
+ break
+ }
+
+ // Block until there's data.
+ if err = t.Block(ch); err != nil {
+ break
+ }
+ }
+
+ return total, err
+}
+
+// Sendfile implements linux system call sendfile(2).
+func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ outFD := kdefs.FD(args[0].Int())
+ inFD := kdefs.FD(args[1].Int())
+ offsetAddr := args[2].Pointer()
+ count := int64(args[3].SizeT())
+
+ // Don't send a negative number of bytes.
+ if count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get files.
+ outFile := t.FDMap().GetFile(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ inFile := t.FDMap().GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ // Verify that the outfile Append flag is not set. Note that fs.Splice
+ // itself validates that the output file is writable.
+ if outFile.Flags().Append {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Verify that we have a regular infile. This is a requirement; the
+ // same check appears in Linux (fs/splice.c:splice_direct_to_actor).
+ if !fs.IsRegular(inFile.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var (
+ n int64
+ err error
+ )
+ if offsetAddr != 0 {
+ // Verify that when offset address is not null, infile must be
+ // seekable. The fs.Splice routine itself validates basic read.
+ if !inFile.Flags().Pread {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Copy in the offset.
+ var offset int64
+ if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
+ return 0, nil, err
+ }
+
+ // The offset must be valid.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Do the splice.
+ n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
+ Length: count,
+ SrcOffset: true,
+ SrcStart: offset,
+ }, false)
+
+ // Copy out the new offset.
+ if _, err := t.CopyOut(offsetAddr, n+offset); err != nil {
+ return 0, nil, err
+ }
+ } else {
+ // Send data using splice.
+ n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
+ Length: count,
+ }, false)
+ }
+
+ // We can only pass a single file to handleIOError, so pick inFile
+ // arbitrarily. This is used only for debugging purposes.
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "sendfile", inFile)
+}
+
+// Splice implements splice(2).
+func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := kdefs.FD(args[0].Int())
+ inOffset := args[1].Pointer()
+ outFD := kdefs.FD(args[2].Int())
+ outOffset := args[3].Pointer()
+ count := int64(args[4].SizeT())
+ flags := args[5].Int()
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Only non-blocking is meaningful. Note that unlike in Linux, this
+ // flag is applied consistently. We will have either fully blocking or
+ // non-blocking behavior below, regardless of the underlying files
+ // being spliced to. It's unclear if this is a bug or not yet.
+ nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
+
+ // Get files.
+ outFile := t.FDMap().GetFile(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ inFile := t.FDMap().GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ // Construct our options.
+ //
+ // Note that exactly one of the underlying buffers must be a pipe. We
+ // don't actually have this constraint internally, but we enforce it
+ // for the semantics of the call.
+ opts := fs.SpliceOpts{
+ Length: count,
+ }
+ switch {
+ case fs.IsPipe(inFile.Dirent.Inode.StableAttr) && !fs.IsPipe(outFile.Dirent.Inode.StableAttr):
+ if inOffset != 0 {
+ return 0, nil, syserror.ESPIPE
+ }
+ if outOffset != 0 {
+ var offset int64
+ if _, err := t.CopyIn(outOffset, &offset); err != nil {
+ return 0, nil, err
+ }
+ // Use the destination offset.
+ opts.DstOffset = true
+ opts.DstStart = offset
+ }
+ case !fs.IsPipe(inFile.Dirent.Inode.StableAttr) && fs.IsPipe(outFile.Dirent.Inode.StableAttr):
+ if outOffset != 0 {
+ return 0, nil, syserror.ESPIPE
+ }
+ if inOffset != 0 {
+ var offset int64
+ if _, err := t.CopyIn(inOffset, &offset); err != nil {
+ return 0, nil, err
+ }
+ // Use the source offset.
+ opts.SrcOffset = true
+ opts.SrcStart = offset
+ }
+ case fs.IsPipe(inFile.Dirent.Inode.StableAttr) && fs.IsPipe(outFile.Dirent.Inode.StableAttr):
+ if inOffset != 0 || outOffset != 0 {
+ return 0, nil, syserror.ESPIPE
+ }
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // We may not refer to the same pipe; otherwise it's a continuous loop.
+ if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Splice data.
+ n, err := doSplice(t, outFile, inFile, opts, nonBlocking)
+
+ // See above; inFile is chosen arbitrarily here.
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile)
+}
+
+// Tee imlements tee(2).
+func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := kdefs.FD(args[0].Int())
+ outFD := kdefs.FD(args[1].Int())
+ count := int64(args[2].SizeT())
+ flags := args[3].Int()
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Only non-blocking is meaningful.
+ nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
+
+ // Get files.
+ outFile := t.FDMap().GetFile(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ inFile := t.FDMap().GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ // All files must be pipes.
+ if !fs.IsPipe(inFile.Dirent.Inode.StableAttr) || !fs.IsPipe(outFile.Dirent.Inode.StableAttr) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // We may not refer to the same pipe; see above.
+ if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Splice data.
+ n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{
+ Length: count,
+ Dup: true,
+ }, nonBlocking)
+
+ // See above; inFile is chosen arbitrarily here.
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
new file mode 100644
index 000000000..10fc201ef
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -0,0 +1,259 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Stat implements linux syscall stat(2).
+func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ return stat(t, d, dirPath, statAddr)
+ })
+}
+
+// Fstatat implements linux syscall newfstatat, i.e. fstatat(2).
+func Fstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ statAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ path, dirPath, err := copyInPath(t, addr, flags&linux.AT_EMPTY_PATH != 0)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if path == "" {
+ // Annoying. What's wrong with fstat?
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, fstat(t, file, statAddr)
+ }
+
+ // If the path ends in a slash (i.e. dirPath is true) or if AT_SYMLINK_NOFOLLOW is unset,
+ // then we must resolve the final component.
+ resolve := dirPath || flags&linux.AT_SYMLINK_NOFOLLOW == 0
+
+ return 0, nil, fileOpOn(t, fd, path, resolve, func(root *fs.Dirent, d *fs.Dirent) error {
+ return stat(t, d, dirPath, statAddr)
+ })
+}
+
+// Lstat implements linux syscall lstat(2).
+func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+
+ path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // If the path ends in a slash (i.e. dirPath is true), then we *do*
+ // want to resolve the final component.
+ resolve := dirPath
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolve, func(root *fs.Dirent, d *fs.Dirent) error {
+ return stat(t, d, dirPath, statAddr)
+ })
+}
+
+// Fstat implements linux syscall fstat(2).
+func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ statAddr := args[1].Pointer()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, fstat(t, file, statAddr)
+}
+
+// stat implements stat from the given *fs.Dirent.
+func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+ uattr, err := d.Inode.UnstableAttr(t)
+ if err != nil {
+ return err
+ }
+ return copyOutStat(t, statAddr, d.Inode.StableAttr, uattr)
+}
+
+// fstat implements fstat for the given *fs.File.
+func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
+ uattr, err := f.UnstableAttr(t)
+ if err != nil {
+ return err
+ }
+ return copyOutStat(t, statAddr, f.Dirent.Inode.StableAttr, uattr)
+}
+
+// copyOutStat copies the attributes (sattr, uattr) to the struct stat at
+// address dst in t's address space. It encodes the stat struct to bytes
+// manually, as stat() is a very common syscall for many applications, and
+// t.CopyObjectOut has noticeable performance impact due to its many slice
+// allocations and use of reflection.
+func copyOutStat(t *kernel.Task, dst usermem.Addr, sattr fs.StableAttr, uattr fs.UnstableAttr) error {
+ var mode uint32
+ switch sattr.Type {
+ case fs.RegularFile, fs.SpecialFile:
+ mode |= linux.ModeRegular
+ case fs.Symlink:
+ mode |= linux.ModeSymlink
+ case fs.Directory, fs.SpecialDirectory:
+ mode |= linux.ModeDirectory
+ case fs.Pipe:
+ mode |= linux.ModeNamedPipe
+ case fs.CharacterDevice:
+ mode |= linux.ModeCharacterDevice
+ case fs.BlockDevice:
+ mode |= linux.ModeBlockDevice
+ case fs.Socket:
+ mode |= linux.ModeSocket
+ }
+
+ b := t.CopyScratchBuffer(int(linux.SizeOfStat))[:0]
+
+ // Dev (uint64)
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.DeviceID))
+ // Ino (uint64)
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.InodeID))
+ // Nlink (uint64)
+ b = binary.AppendUint64(b, usermem.ByteOrder, uattr.Links)
+ // Mode (uint32)
+ b = binary.AppendUint32(b, usermem.ByteOrder, mode|uint32(uattr.Perms.LinuxMode()))
+ // UID (uint32)
+ b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
+ // GID (uint32)
+ b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()))
+ // Padding (uint32)
+ b = binary.AppendUint32(b, usermem.ByteOrder, 0)
+ // Rdev (uint64)
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)))
+ // Size (uint64)
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Size))
+ // Blksize (uint64)
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.BlockSize))
+ // Blocks (uint64)
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Usage/512))
+
+ // ATime
+ atime := uattr.AccessTime.Timespec()
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Sec))
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Nsec))
+
+ // MTime
+ mtime := uattr.ModificationTime.Timespec()
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Sec))
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Nsec))
+
+ // CTime
+ ctime := uattr.StatusChangeTime.Timespec()
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Sec))
+ b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Nsec))
+
+ _, err := t.CopyOutBytes(dst, b)
+ return err
+}
+
+// Statfs implements linux syscall statfs(2).
+func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ statfsAddr := args[1].Pointer()
+
+ path, _, err := copyInPath(t, addr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error {
+ return statfsImpl(t, d, statfsAddr)
+ })
+}
+
+// Fstatfs implements linux syscall fstatfs(2).
+func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ statfsAddr := args[1].Pointer()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, statfsImpl(t, file.Dirent, statfsAddr)
+}
+
+// statfsImpl implements the linux syscall statfs and fstatfs based on a Dirent,
+// copying the statfs structure out to addr on success, otherwise an error is
+// returned.
+func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
+ info, err := d.Inode.StatFS(t)
+ if err != nil {
+ return err
+ }
+ // Construct the statfs structure and copy it out.
+ statfs := linux.Statfs{
+ Type: info.Type,
+ // Treat block size and fragment size as the same, as
+ // most consumers of this structure will expect one
+ // or the other to be filled in.
+ BlockSize: d.Inode.StableAttr.BlockSize,
+ Blocks: info.TotalBlocks,
+ // We don't have the concept of reserved blocks, so
+ // report blocks free the same as available blocks.
+ // This is a normal thing for filesystems, to do, see
+ // udf, hugetlbfs, tmpfs, among others.
+ BlocksFree: info.FreeBlocks,
+ BlocksAvailable: info.FreeBlocks,
+ Files: info.TotalFiles,
+ FilesFree: info.FreeFiles,
+ // Same as Linux for simple_statfs, see fs/libfs.c.
+ NameLength: linux.NAME_MAX,
+ FragmentSize: d.Inode.StableAttr.BlockSize,
+ // Leave other fields 0 like simple_statfs does.
+ }
+ if _, err := t.CopyOut(addr, &statfs); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
new file mode 100644
index 000000000..4352482fb
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -0,0 +1,138 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Sync implements linux system call sync(2).
+func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ t.MountNamespace().SyncAll(t)
+ // Sync is always successful.
+ return 0, nil, nil
+}
+
+// Syncfs implements linux system call syncfs(2).
+func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Use "sync-the-world" for now, it's guaranteed that fd is at least
+ // on the root filesystem.
+ return Sync(t, args)
+}
+
+// Fsync implements linux syscall fsync(2).
+func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncAll)
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+// Fdatasync implements linux syscall fdatasync(2).
+//
+// At the moment, it just calls Fsync, which is a big hammer, but correct.
+func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncData)
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
+
+// SyncFileRange implements linux syscall sync_file_rage(2)
+func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ var err error
+
+ offset := args[1].Int64()
+ nbytes := args[2].Int64()
+ uflags := args[3].Uint()
+
+ if offset < 0 || offset+nbytes < offset {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if uflags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|
+ linux.SYNC_FILE_RANGE_WRITE|
+ linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if nbytes == 0 {
+ nbytes = fs.FileMaxOffset
+ }
+
+ fd := kdefs.FD(args[0].Int())
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // SYNC_FILE_RANGE_WAIT_BEFORE waits upon write-out of all pages in the
+ // specified range that have already been submitted to the device
+ // driver for write-out before performing any write.
+ if uflags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 &&
+ uflags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+ // SYNC_FILE_RANGE_WRITE initiates write-out of all dirty pages in the
+ // specified range which are not presently submitted write-out.
+ //
+ // It looks impossible to implement this functionality without a
+ // massive rework of the vfs subsystem. file.Fsync() take a file lock
+ // for the entire operation, so even if it is running in a go routing,
+ // it blocks other file operations instead of flushing data in the
+ // background.
+ //
+ // It should be safe to skipped this flag while nobody uses
+ // SYNC_FILE_RANGE_WAIT_BEFORE.
+
+ // SYNC_FILE_RANGE_WAIT_AFTER waits upon write-out of all pages in the
+ // range after performing any write.
+ //
+ // In Linux, sync_file_range() doesn't writes out the file's
+ // meta-data, but fdatasync() does if a file size is changed.
+ if uflags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 {
+ err = file.Fsync(t, offset, fs.FileMaxOffset, fs.SyncData)
+ }
+
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go
new file mode 100644
index 000000000..ecf88edc1
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go
@@ -0,0 +1,43 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usage"
+)
+
+// Sysinfo implements the sysinfo syscall as described in man 2 sysinfo.
+func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ mf := t.Kernel().MemoryFile()
+ mf.UpdateUsage()
+ _, totalUsage := usage.MemoryAccounting.Copy()
+ totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+
+ // Only a subset of the fields in sysinfo_t make sense to return.
+ si := linux.Sysinfo{
+ Procs: uint16(len(t.PIDNamespace().Tasks())),
+ Uptime: t.Kernel().MonotonicClock().Now().Seconds(),
+ TotalRAM: totalSize,
+ FreeRAM: totalSize - totalUsage,
+ Unit: 1,
+ }
+ _, err := t.CopyOut(addr, si)
+ return 0, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/sys_syslog.go b/pkg/sentry/syscalls/linux/sys_syslog.go
new file mode 100644
index 000000000..9efc58d34
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_syslog.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const (
+ _SYSLOG_ACTION_READ_ALL = 3
+ _SYSLOG_ACTION_SIZE_BUFFER = 10
+)
+
+// logBufLen is the default syslog buffer size on Linux.
+const logBufLen = 1 << 17
+
+// Syslog implements part of Linux syscall syslog.
+//
+// Only the unpriviledged commands are implemented, allowing applications to
+// read a fun dmesg.
+func Syslog(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ command := args[0].Int()
+ buf := args[1].Pointer()
+ size := int(args[2].Int())
+
+ switch command {
+ case _SYSLOG_ACTION_READ_ALL:
+ if size < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if size > logBufLen {
+ size = logBufLen
+ }
+
+ log := t.Kernel().Syslog().Log()
+ if len(log) > size {
+ log = log[:size]
+ }
+
+ n, err := t.CopyOutBytes(buf, log)
+ return uintptr(n), nil, err
+ case _SYSLOG_ACTION_SIZE_BUFFER:
+ return logBufLen, nil, nil
+ default:
+ return 0, nil, syserror.ENOSYS
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
new file mode 100644
index 000000000..26f7e8ead
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -0,0 +1,706 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/sched"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const (
+ // ExecMaxTotalSize is the maximum length of all argv and envv entries.
+ //
+ // N.B. The behavior here is different than Linux. Linux provides a limit on
+ // individual arguments of 32 pages, and an aggregate limit of at least 32 pages
+ // but otherwise bounded by min(stack size / 4, 8 MB * 3 / 4). We don't implement
+ // any behavior based on the stack size, and instead provide a fixed hard-limit of
+ // 2 MB (which should work well given that 8 MB stack limits are common).
+ ExecMaxTotalSize = 2 * 1024 * 1024
+
+ // ExecMaxElemSize is the maximum length of a single argv or envv entry.
+ ExecMaxElemSize = 32 * usermem.PageSize
+
+ // exitSignalMask is the signal mask to be sent at exit. Same as CSIGNAL in linux.
+ exitSignalMask = 0xff
+)
+
+// Getppid implements linux syscall getppid(2).
+func Getppid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ parent := t.Parent()
+ if parent == nil {
+ return 0, nil, nil
+ }
+ return uintptr(t.PIDNamespace().IDOfThreadGroup(parent.ThreadGroup())), nil, nil
+}
+
+// Getpid implements linux syscall getpid(2).
+func Getpid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return uintptr(t.ThreadGroup().ID()), nil, nil
+}
+
+// Gettid implements linux syscall gettid(2).
+func Gettid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return uintptr(t.ThreadID()), nil, nil
+}
+
+// Execve implements linux syscall execve(2).
+func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ filenameAddr := args[0].Pointer()
+ argvAddr := args[1].Pointer()
+ envvAddr := args[2].Pointer()
+
+ // Extract our arguments.
+ filename, err := t.CopyInString(filenameAddr, linux.PATH_MAX)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var argv, envv []string
+ if argvAddr != 0 {
+ var err error
+ argv, err = t.CopyInVector(argvAddr, ExecMaxElemSize, ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if envvAddr != 0 {
+ var err error
+ envv, err = t.CopyInVector(envvAddr, ExecMaxElemSize, ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+ wd := t.FSContext().WorkingDirectory()
+ defer wd.DecRef()
+
+ // Load the new TaskContext.
+ maxTraversals := uint(linux.MaxSymlinkTraversals)
+ tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, argv, envv, t.Arch().FeatureSet())
+ if se != nil {
+ return 0, nil, se.ToError()
+ }
+
+ ctrl, err := t.Execve(tc)
+ return 0, ctrl, err
+}
+
+// Exit implements linux syscall exit(2).
+func Exit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ status := int(args[0].Int())
+ t.PrepareExit(kernel.ExitStatus{Code: status})
+ return 0, kernel.CtrlDoExit, nil
+}
+
+// ExitGroup implements linux syscall exit_group(2).
+func ExitGroup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ status := int(args[0].Int())
+ t.PrepareGroupExit(kernel.ExitStatus{Code: status})
+ return 0, kernel.CtrlDoExit, nil
+}
+
+// clone is used by Clone, Fork, and VFork.
+func clone(t *kernel.Task, flags int, stack usermem.Addr, parentTID usermem.Addr, childTID usermem.Addr, tls usermem.Addr) (uintptr, *kernel.SyscallControl, error) {
+ opts := kernel.CloneOptions{
+ SharingOptions: kernel.SharingOptions{
+ NewAddressSpace: flags&syscall.CLONE_VM == 0,
+ NewSignalHandlers: flags&syscall.CLONE_SIGHAND == 0,
+ NewThreadGroup: flags&syscall.CLONE_THREAD == 0,
+ TerminationSignal: linux.Signal(flags & exitSignalMask),
+ NewPIDNamespace: flags&syscall.CLONE_NEWPID == syscall.CLONE_NEWPID,
+ NewUserNamespace: flags&syscall.CLONE_NEWUSER == syscall.CLONE_NEWUSER,
+ NewNetworkNamespace: flags&syscall.CLONE_NEWNET == syscall.CLONE_NEWNET,
+ NewFiles: flags&syscall.CLONE_FILES == 0,
+ NewFSContext: flags&syscall.CLONE_FS == 0,
+ NewUTSNamespace: flags&syscall.CLONE_NEWUTS == syscall.CLONE_NEWUTS,
+ NewIPCNamespace: flags&syscall.CLONE_NEWIPC == syscall.CLONE_NEWIPC,
+ },
+ Stack: stack,
+ SetTLS: flags&syscall.CLONE_SETTLS == syscall.CLONE_SETTLS,
+ TLS: tls,
+ ChildClearTID: flags&syscall.CLONE_CHILD_CLEARTID == syscall.CLONE_CHILD_CLEARTID,
+ ChildSetTID: flags&syscall.CLONE_CHILD_SETTID == syscall.CLONE_CHILD_SETTID,
+ ChildTID: childTID,
+ ParentSetTID: flags&syscall.CLONE_PARENT_SETTID == syscall.CLONE_PARENT_SETTID,
+ ParentTID: parentTID,
+ Vfork: flags&syscall.CLONE_VFORK == syscall.CLONE_VFORK,
+ Untraced: flags&syscall.CLONE_UNTRACED == syscall.CLONE_UNTRACED,
+ InheritTracer: flags&syscall.CLONE_PTRACE == syscall.CLONE_PTRACE,
+ }
+ ntid, ctrl, err := t.Clone(&opts)
+ return uintptr(ntid), ctrl, err
+}
+
+// Clone implements linux syscall clone(2).
+// sys_clone has so many flavors. We implement the default one in linux 3.11
+// x86_64:
+// sys_clone(clone_flags, newsp, parent_tidptr, child_tidptr, tls_val)
+func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+ stack := args[1].Pointer()
+ parentTID := args[2].Pointer()
+ childTID := args[3].Pointer()
+ tls := args[4].Pointer()
+ return clone(t, flags, stack, parentTID, childTID, tls)
+}
+
+// Fork implements Linux syscall fork(2).
+func Fork(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // "A call to fork() is equivalent to a call to clone(2) specifying flags
+ // as just SIGCHLD." - fork(2)
+ return clone(t, int(syscall.SIGCHLD), 0, 0, 0, 0)
+}
+
+// Vfork implements Linux syscall vfork(2).
+func Vfork(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // """
+ // A call to vfork() is equivalent to calling clone(2) with flags specified as:
+ //
+ // CLONE_VM | CLONE_VFORK | SIGCHLD
+ // """ - vfork(2)
+ return clone(t, syscall.CLONE_VM|syscall.CLONE_VFORK|int(syscall.SIGCHLD), 0, 0, 0, 0)
+}
+
+// parseCommonWaitOptions applies the options common to wait4 and waitid to
+// wopts.
+func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error {
+ switch options & (linux.WCLONE | linux.WALL) {
+ case 0:
+ wopts.NonCloneTasks = true
+ case linux.WCLONE:
+ wopts.CloneTasks = true
+ case linux.WALL:
+ wopts.NonCloneTasks = true
+ wopts.CloneTasks = true
+ default:
+ return syscall.EINVAL
+ }
+ if options&linux.WCONTINUED != 0 {
+ wopts.Events |= kernel.EventGroupContinue
+ }
+ if options&linux.WNOHANG == 0 {
+ wopts.BlockInterruptErr = kernel.ERESTARTSYS
+ }
+ if options&linux.WNOTHREAD == 0 {
+ wopts.SiblingChildren = true
+ }
+ return nil
+}
+
+// wait4 waits for the given child process to exit.
+func wait4(t *kernel.Task, pid int, statusAddr usermem.Addr, options int, rusageAddr usermem.Addr) (uintptr, error) {
+ if options&^(linux.WNOHANG|linux.WUNTRACED|linux.WCONTINUED|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 {
+ return 0, syscall.EINVAL
+ }
+ wopts := kernel.WaitOptions{
+ Events: kernel.EventExit | kernel.EventTraceeStop,
+ ConsumeEvent: true,
+ }
+ // There are four cases to consider:
+ //
+ // pid < -1 any child process whose process group ID is equal to the absolute value of pid
+ // pid == -1 any child process
+ // pid == 0 any child process whose process group ID is equal to that of the calling process
+ // pid > 0 the child whose process ID is equal to the value of pid
+ switch {
+ case pid < -1:
+ wopts.SpecificPGID = kernel.ProcessGroupID(-pid)
+ case pid == -1:
+ // Any process is the default.
+ case pid == 0:
+ wopts.SpecificPGID = t.PIDNamespace().IDOfProcessGroup(t.ThreadGroup().ProcessGroup())
+ default:
+ wopts.SpecificTID = kernel.ThreadID(pid)
+ }
+
+ if err := parseCommonWaitOptions(&wopts, options); err != nil {
+ return 0, err
+ }
+ if options&linux.WUNTRACED != 0 {
+ wopts.Events |= kernel.EventChildGroupStop
+ }
+
+ wr, err := t.Wait(&wopts)
+ if err != nil {
+ if err == kernel.ErrNoWaitableEvent {
+ return 0, nil
+ }
+ return 0, err
+ }
+ if statusAddr != 0 {
+ if _, err := t.CopyOut(statusAddr, wr.Status); err != nil {
+ return 0, err
+ }
+ }
+ if rusageAddr != 0 {
+ ru := getrusage(wr.Task, linux.RUSAGE_BOTH)
+ if _, err := t.CopyOut(rusageAddr, &ru); err != nil {
+ return 0, err
+ }
+ }
+ return uintptr(wr.TID), nil
+}
+
+// Wait4 implements linux syscall wait4(2).
+func Wait4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := int(args[0].Int())
+ statusAddr := args[1].Pointer()
+ options := int(args[2].Uint())
+ rusageAddr := args[3].Pointer()
+
+ n, err := wait4(t, pid, statusAddr, options, rusageAddr)
+ return n, nil, err
+}
+
+// WaitPid implements linux syscall waitpid(2).
+func WaitPid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pid := int(args[0].Int())
+ statusAddr := args[1].Pointer()
+ options := int(args[2].Uint())
+
+ n, err := wait4(t, pid, statusAddr, options, 0)
+ return n, nil, err
+}
+
+// Waitid implements linux syscall waitid(2).
+func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ idtype := args[0].Int()
+ id := args[1].Int()
+ infop := args[2].Pointer()
+ options := int(args[3].Uint())
+ rusageAddr := args[4].Pointer()
+
+ if options&^(linux.WNOHANG|linux.WEXITED|linux.WSTOPPED|linux.WCONTINUED|linux.WNOWAIT|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ if options&(linux.WEXITED|linux.WSTOPPED|linux.WCONTINUED) == 0 {
+ return 0, nil, syscall.EINVAL
+ }
+ wopts := kernel.WaitOptions{
+ Events: kernel.EventTraceeStop,
+ ConsumeEvent: options&linux.WNOWAIT == 0,
+ }
+ switch idtype {
+ case linux.P_ALL:
+ case linux.P_PID:
+ wopts.SpecificTID = kernel.ThreadID(id)
+ case linux.P_PGID:
+ wopts.SpecificPGID = kernel.ProcessGroupID(id)
+ default:
+ return 0, nil, syscall.EINVAL
+ }
+
+ if err := parseCommonWaitOptions(&wopts, options); err != nil {
+ return 0, nil, err
+ }
+ if options&linux.WEXITED != 0 {
+ wopts.Events |= kernel.EventExit
+ }
+ if options&linux.WSTOPPED != 0 {
+ wopts.Events |= kernel.EventChildGroupStop
+ }
+
+ wr, err := t.Wait(&wopts)
+ if err != nil {
+ if err == kernel.ErrNoWaitableEvent {
+ err = nil
+ // "If WNOHANG was specified in options and there were no children
+ // in a waitable state, then waitid() returns 0 immediately and the
+ // state of the siginfo_t structure pointed to by infop is
+ // unspecified." - waitid(2). But Linux's waitid actually zeroes
+ // out the fields it would set for a successful waitid in this case
+ // as well.
+ if infop != 0 {
+ var si arch.SignalInfo
+ _, err = t.CopyOut(infop, &si)
+ }
+ }
+ return 0, nil, err
+ }
+ if rusageAddr != 0 {
+ ru := getrusage(wr.Task, linux.RUSAGE_BOTH)
+ if _, err := t.CopyOut(rusageAddr, &ru); err != nil {
+ return 0, nil, err
+ }
+ }
+ if infop == 0 {
+ return 0, nil, nil
+ }
+ si := arch.SignalInfo{
+ Signo: int32(syscall.SIGCHLD),
+ }
+ si.SetPid(int32(wr.TID))
+ si.SetUid(int32(wr.UID))
+ // TODO(b/73541790): convert kernel.ExitStatus to functions and make
+ // WaitResult.Status a linux.WaitStatus
+ s := syscall.WaitStatus(wr.Status)
+ switch {
+ case s.Exited():
+ si.Code = arch.CLD_EXITED
+ si.SetStatus(int32(s.ExitStatus()))
+ case s.Signaled():
+ si.Code = arch.CLD_KILLED
+ si.SetStatus(int32(s.Signal()))
+ case s.CoreDump():
+ si.Code = arch.CLD_DUMPED
+ si.SetStatus(int32(s.Signal()))
+ case s.Stopped():
+ if wr.Event == kernel.EventTraceeStop {
+ si.Code = arch.CLD_TRAPPED
+ si.SetStatus(int32(s.TrapCause()))
+ } else {
+ si.Code = arch.CLD_STOPPED
+ si.SetStatus(int32(s.StopSignal()))
+ }
+ case s.Continued():
+ si.Code = arch.CLD_CONTINUED
+ si.SetStatus(int32(syscall.SIGCONT))
+ default:
+ t.Warningf("waitid got incomprehensible wait status %d", s)
+ }
+ _, err = t.CopyOut(infop, &si)
+ return 0, nil, err
+}
+
+// SetTidAddress implements linux syscall set_tid_address(2).
+func SetTidAddress(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ // Always succeed, return caller's tid.
+ t.SetClearTID(addr)
+ return uintptr(t.ThreadID()), nil, nil
+}
+
+// Unshare implements linux syscall unshare(2).
+func Unshare(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ opts := kernel.SharingOptions{
+ NewAddressSpace: flags&syscall.CLONE_VM == syscall.CLONE_VM,
+ NewSignalHandlers: flags&syscall.CLONE_SIGHAND == syscall.CLONE_SIGHAND,
+ NewThreadGroup: flags&syscall.CLONE_THREAD == syscall.CLONE_THREAD,
+ NewPIDNamespace: flags&syscall.CLONE_NEWPID == syscall.CLONE_NEWPID,
+ NewUserNamespace: flags&syscall.CLONE_NEWUSER == syscall.CLONE_NEWUSER,
+ NewNetworkNamespace: flags&syscall.CLONE_NEWNET == syscall.CLONE_NEWNET,
+ NewFiles: flags&syscall.CLONE_FILES == syscall.CLONE_FILES,
+ NewFSContext: flags&syscall.CLONE_FS == syscall.CLONE_FS,
+ NewUTSNamespace: flags&syscall.CLONE_NEWUTS == syscall.CLONE_NEWUTS,
+ NewIPCNamespace: flags&syscall.CLONE_NEWIPC == syscall.CLONE_NEWIPC,
+ }
+ // "CLONE_NEWPID automatically implies CLONE_THREAD as well." - unshare(2)
+ if opts.NewPIDNamespace {
+ opts.NewThreadGroup = true
+ }
+ // "... specifying CLONE_NEWUSER automatically implies CLONE_THREAD. Since
+ // Linux 3.9, CLONE_NEWUSER also automatically implies CLONE_FS."
+ if opts.NewUserNamespace {
+ opts.NewThreadGroup = true
+ opts.NewFSContext = true
+ }
+ return 0, nil, t.Unshare(&opts)
+}
+
+// SchedYield implements linux syscall sched_yield(2).
+func SchedYield(t *kernel.Task, _ arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ t.Yield()
+ return 0, nil, nil
+}
+
+// SchedSetaffinity implements linux syscall sched_setaffinity(2).
+func SchedSetaffinity(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := args[0].Int()
+ size := args[1].SizeT()
+ maskAddr := args[2].Pointer()
+
+ var task *kernel.Task
+ if tid == 0 {
+ task = t
+ } else {
+ task = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ mask := sched.NewCPUSet(t.Kernel().ApplicationCores())
+ if size > mask.Size() {
+ size = mask.Size()
+ }
+ if _, err := t.CopyInBytes(maskAddr, mask[:size]); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, task.SetCPUMask(mask)
+}
+
+// SchedGetaffinity implements linux syscall sched_getaffinity(2).
+func SchedGetaffinity(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := args[0].Int()
+ size := args[1].SizeT()
+ maskAddr := args[2].Pointer()
+
+ // This limitation is because linux stores the cpumask
+ // in an array of "unsigned long" so the buffer needs to
+ // be a multiple of the word size.
+ if size&(t.Arch().Width()-1) > 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var task *kernel.Task
+ if tid == 0 {
+ task = t
+ } else {
+ task = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ mask := task.CPUMask()
+ // The buffer needs to be big enough to hold a cpumask with
+ // all possible cpus.
+ if size < mask.Size() {
+ return 0, nil, syserror.EINVAL
+ }
+ _, err := t.CopyOutBytes(maskAddr, mask)
+
+ // NOTE: The syscall interface is slightly different than the glibc
+ // interface. The raw sched_getaffinity syscall returns the number of
+ // bytes used to represent a cpu mask.
+ return uintptr(mask.Size()), nil, err
+}
+
+// Getcpu implements linux syscall getcpu(2).
+func Getcpu(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ cpu := args[0].Pointer()
+ node := args[1].Pointer()
+ // third argument to this system call is nowadays unused.
+
+ if cpu != 0 {
+ buf := t.CopyScratchBuffer(4)
+ usermem.ByteOrder.PutUint32(buf, uint32(t.CPU()))
+ if _, err := t.CopyOutBytes(cpu, buf); err != nil {
+ return 0, nil, err
+ }
+ }
+ // We always return node 0.
+ if node != 0 {
+ if _, err := t.MemoryManager().ZeroOut(t, node, 4, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// Setpgid implements the linux syscall setpgid(2).
+func Setpgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // Note that throughout this function, pgid is interpreted with respect
+ // to t's namespace, not with respect to the selected ThreadGroup's
+ // namespace (which may be different).
+ pid := kernel.ThreadID(args[0].Int())
+ pgid := kernel.ProcessGroupID(args[1].Int())
+
+ // "If pid is zero, then the process ID of the calling process is used."
+ tg := t.ThreadGroup()
+ if pid != 0 {
+ ot := t.PIDNamespace().TaskWithID(pid)
+ if ot == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ tg = ot.ThreadGroup()
+ if tg.Leader() != ot {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Setpgid only operates on child threadgroups.
+ if tg != t.ThreadGroup() && (tg.Leader().Parent() == nil || tg.Leader().Parent().ThreadGroup() != t.ThreadGroup()) {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ // "If pgid is zero, then the PGID of the process specified by pid is made
+ // the same as its process ID."
+ defaultPGID := kernel.ProcessGroupID(t.PIDNamespace().IDOfThreadGroup(tg))
+ if pgid == 0 {
+ pgid = defaultPGID
+ } else if pgid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // If the pgid is the same as the group, then create a new one. Otherwise,
+ // we attempt to join an existing process group.
+ if pgid == defaultPGID {
+ // For convenience, errors line up with Linux syscall API.
+ if err := tg.CreateProcessGroup(); err != nil {
+ // Is the process group already as expected? If so,
+ // just return success. This is the same behavior as
+ // Linux.
+ if t.PIDNamespace().IDOfProcessGroup(tg.ProcessGroup()) == defaultPGID {
+ return 0, nil, nil
+ }
+ return 0, nil, err
+ }
+ } else {
+ // Same as CreateProcessGroup, above.
+ if err := tg.JoinProcessGroup(t.PIDNamespace(), pgid, tg != t.ThreadGroup()); err != nil {
+ // See above.
+ if t.PIDNamespace().IDOfProcessGroup(tg.ProcessGroup()) == pgid {
+ return 0, nil, nil
+ }
+ return 0, nil, err
+ }
+ }
+
+ // Success.
+ return 0, nil, nil
+}
+
+// Getpgrp implements the linux syscall getpgrp(2).
+func Getpgrp(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return uintptr(t.PIDNamespace().IDOfProcessGroup(t.ThreadGroup().ProcessGroup())), nil, nil
+}
+
+// Getpgid implements the linux syscall getpgid(2).
+func Getpgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := kernel.ThreadID(args[0].Int())
+ if tid == 0 {
+ return Getpgrp(t, args)
+ }
+
+ target := t.PIDNamespace().TaskWithID(tid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ return uintptr(t.PIDNamespace().IDOfProcessGroup(target.ThreadGroup().ProcessGroup())), nil, nil
+}
+
+// Setsid implements the linux syscall setsid(2).
+func Setsid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.ThreadGroup().CreateSession()
+}
+
+// Getsid implements the linux syscall getsid(2).
+func Getsid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tid := kernel.ThreadID(args[0].Int())
+ if tid == 0 {
+ return uintptr(t.PIDNamespace().IDOfSession(t.ThreadGroup().Session())), nil, nil
+ }
+
+ target := t.PIDNamespace().TaskWithID(tid)
+ if target == nil {
+ return 0, nil, syserror.ESRCH
+ }
+
+ return uintptr(t.PIDNamespace().IDOfSession(target.ThreadGroup().Session())), nil, nil
+}
+
+// Getpriority pretends to implement the linux syscall getpriority(2).
+//
+// This is a stub; real priorities require a full scheduler.
+func Getpriority(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ which := args[0].Int()
+ who := kernel.ThreadID(args[1].Int())
+
+ switch which {
+ case syscall.PRIO_PROCESS:
+ // Look for who, return ESRCH if not found.
+ var task *kernel.Task
+ if who == 0 {
+ task = t
+ } else {
+ task = t.PIDNamespace().TaskWithID(who)
+ }
+
+ if task == nil {
+ return 0, nil, syscall.ESRCH
+ }
+
+ // From kernel/sys.c:getpriority:
+ // "To avoid negative return values, 'getpriority()'
+ // will not return the normal nice-value, but a negated
+ // value that has been offset by 20"
+ return uintptr(20 - task.Niceness()), nil, nil
+ case syscall.PRIO_USER:
+ fallthrough
+ case syscall.PRIO_PGRP:
+ // PRIO_USER and PRIO_PGRP have no further implementation yet.
+ return 0, nil, nil
+ default:
+ return 0, nil, syscall.EINVAL
+ }
+}
+
+// Setpriority pretends to implement the linux syscall setpriority(2).
+//
+// This is a stub; real priorities require a full scheduler.
+func Setpriority(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ which := args[0].Int()
+ who := kernel.ThreadID(args[1].Int())
+ niceval := int(args[2].Int())
+
+ // In the kernel's implementation, values outside the range
+ // of [-20, 19] are truncated to these minimum and maximum
+ // values.
+ if niceval < -20 /* min niceval */ {
+ niceval = -20
+ } else if niceval > 19 /* max niceval */ {
+ niceval = 19
+ }
+
+ switch which {
+ case syscall.PRIO_PROCESS:
+ // Look for who, return ESRCH if not found.
+ var task *kernel.Task
+ if who == 0 {
+ task = t
+ } else {
+ task = t.PIDNamespace().TaskWithID(who)
+ }
+
+ if task == nil {
+ return 0, nil, syscall.ESRCH
+ }
+
+ task.SetNiceness(niceval)
+ case syscall.PRIO_USER:
+ fallthrough
+ case syscall.PRIO_PGRP:
+ // PRIO_USER and PRIO_PGRP have no further implementation yet.
+ return 0, nil, nil
+ default:
+ return 0, nil, syscall.EINVAL
+ }
+
+ return 0, nil, nil
+}
+
+// Ptrace implements linux system call ptrace(2).
+func Ptrace(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ req := args[0].Int64()
+ pid := kernel.ThreadID(args[1].Int())
+ addr := args[2].Pointer()
+ data := args[3].Pointer()
+
+ return 0, nil, t.Ptrace(req, pid, addr, data)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
new file mode 100644
index 000000000..b4f2609c0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -0,0 +1,340 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// The most significant 29 bits hold either a pid or a file descriptor.
+func pidOfClockID(c int32) kernel.ThreadID {
+ return kernel.ThreadID(^(c >> 3))
+}
+
+// whichCPUClock returns one of CPUCLOCK_PERF, CPUCLOCK_VIRT, CPUCLOCK_SCHED or
+// CLOCK_FD.
+func whichCPUClock(c int32) int32 {
+ return c & linux.CPUCLOCK_CLOCK_MASK
+}
+
+// isCPUClockPerThread returns true if the CPUCLOCK_PERTHREAD bit is set in the
+// clock id.
+func isCPUClockPerThread(c int32) bool {
+ return c&linux.CPUCLOCK_PERTHREAD_MASK != 0
+}
+
+// isValidCPUClock returns checks that the cpu clock id is valid.
+func isValidCPUClock(c int32) bool {
+ // Bits 0, 1, and 2 cannot all be set.
+ if c&7 == 7 {
+ return false
+ }
+ if whichCPUClock(c) >= linux.CPUCLOCK_MAX {
+ return false
+ }
+ return true
+}
+
+// targetTask returns the kernel.Task for the given clock id.
+func targetTask(t *kernel.Task, c int32) *kernel.Task {
+ pid := pidOfClockID(c)
+ if pid == 0 {
+ return t
+ }
+ return t.PIDNamespace().TaskWithID(pid)
+}
+
+// ClockGetres implements linux syscall clock_getres(2).
+func ClockGetres(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := int32(args[0].Int())
+ addr := args[1].Pointer()
+ r := linux.Timespec{
+ Sec: 0,
+ Nsec: 1,
+ }
+
+ if _, err := getClock(t, clockID); err != nil {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if addr == 0 {
+ // Don't need to copy out.
+ return 0, nil, nil
+ }
+
+ return 0, nil, copyTimespecOut(t, addr, &r)
+}
+
+type cpuClocker interface {
+ UserCPUClock() ktime.Clock
+ CPUClock() ktime.Clock
+}
+
+func getClock(t *kernel.Task, clockID int32) (ktime.Clock, error) {
+ if clockID < 0 {
+ if !isValidCPUClock(clockID) {
+ return nil, syserror.EINVAL
+ }
+
+ targetTask := targetTask(t, clockID)
+ if targetTask == nil {
+ return nil, syserror.EINVAL
+ }
+
+ var target cpuClocker
+ if isCPUClockPerThread(clockID) {
+ target = targetTask
+ } else {
+ target = targetTask.ThreadGroup()
+ }
+
+ switch whichCPUClock(clockID) {
+ case linux.CPUCLOCK_VIRT:
+ return target.UserCPUClock(), nil
+ case linux.CPUCLOCK_PROF, linux.CPUCLOCK_SCHED:
+ // CPUCLOCK_SCHED is approximated by CPUCLOCK_PROF.
+ return target.CPUClock(), nil
+ default:
+ return nil, syserror.EINVAL
+ }
+ }
+
+ switch clockID {
+ case linux.CLOCK_REALTIME, linux.CLOCK_REALTIME_COARSE:
+ return t.Kernel().RealtimeClock(), nil
+ case linux.CLOCK_MONOTONIC, linux.CLOCK_MONOTONIC_COARSE, linux.CLOCK_MONOTONIC_RAW:
+ // CLOCK_MONOTONIC approximates CLOCK_MONOTONIC_RAW.
+ return t.Kernel().MonotonicClock(), nil
+ case linux.CLOCK_PROCESS_CPUTIME_ID:
+ return t.ThreadGroup().CPUClock(), nil
+ case linux.CLOCK_THREAD_CPUTIME_ID:
+ return t.CPUClock(), nil
+ default:
+ return nil, syserror.EINVAL
+ }
+}
+
+// ClockGettime implements linux syscall clock_gettime(2).
+func ClockGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := int32(args[0].Int())
+ addr := args[1].Pointer()
+
+ c, err := getClock(t, clockID)
+ if err != nil {
+ return 0, nil, err
+ }
+ ts := c.Now().Timespec()
+ return 0, nil, copyTimespecOut(t, addr, &ts)
+}
+
+// ClockSettime implements linux syscall clock_settime(2).
+func ClockSettime(*kernel.Task, arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.EPERM
+}
+
+// Time implements linux syscall time(2).
+func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ r := t.Kernel().RealtimeClock().Now().TimeT()
+ if addr == usermem.Addr(0) {
+ return uintptr(r), nil, nil
+ }
+
+ if _, err := t.CopyOut(addr, r); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r), nil, nil
+}
+
+// clockNanosleepRestartBlock encapsulates the state required to restart
+// clock_nanosleep(2) via restart_syscall(2).
+//
+// +stateify savable
+type clockNanosleepRestartBlock struct {
+ c ktime.Clock
+ duration time.Duration
+ rem usermem.Addr
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (n *clockNanosleepRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return 0, clockNanosleepFor(t, n.c, n.duration, n.rem)
+}
+
+// clockNanosleepUntil blocks until a specified time.
+//
+// If blocking is interrupted, the syscall is restarted with the original
+// arguments.
+func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error {
+ notifier, tchan := ktime.NewChannelNotifier()
+ timer := ktime.NewTimer(c, notifier)
+
+ // Turn on the timer.
+ timer.Swap(ktime.Setting{
+ Period: 0,
+ Enabled: true,
+ Next: ktime.FromTimespec(ts),
+ })
+
+ err := t.BlockWithTimer(nil, tchan)
+
+ timer.Destroy()
+
+ // Did we just block until the timeout happened?
+ if err == syserror.ETIMEDOUT {
+ return nil
+ }
+
+ return syserror.ConvertIntr(err, kernel.ERESTARTNOHAND)
+}
+
+// clockNanosleepFor blocks for a specified duration.
+//
+// If blocking is interrupted, the syscall is restarted with the remaining
+// duration timeout.
+func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem usermem.Addr) error {
+ timer, start, tchan := ktime.After(c, dur)
+
+ err := t.BlockWithTimer(nil, tchan)
+
+ after := c.Now()
+
+ timer.Destroy()
+
+ var remaining time.Duration
+ // Did we just block for the entire duration?
+ if err == syserror.ETIMEDOUT {
+ remaining = 0
+ } else {
+ remaining = dur - after.Sub(start)
+ if remaining < 0 {
+ remaining = time.Duration(0)
+ }
+ }
+
+ // Copy out remaining time.
+ if err != nil && rem != usermem.Addr(0) {
+ timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
+ if err := copyTimespecOut(t, rem, &timeleft); err != nil {
+ return err
+ }
+ }
+
+ // Did we just block for the entire duration?
+ if err == syserror.ETIMEDOUT {
+ return nil
+ }
+
+ // If interrupted, arrange for a restart with the remaining duration.
+ if err == syserror.ErrInterrupted {
+ t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
+ c: c,
+ duration: remaining,
+ rem: rem,
+ })
+ return kernel.ERESTART_RESTARTBLOCK
+ }
+
+ return err
+}
+
+// Nanosleep implements linux syscall Nanosleep(2).
+func Nanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ rem := args[1].Pointer()
+
+ ts, err := copyTimespecIn(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !ts.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Just like linux, we cap the timeout with the max number that int64 can
+ // represent which is roughly 292 years.
+ dur := time.Duration(ts.ToNsecCapped()) * time.Nanosecond
+ return 0, nil, clockNanosleepFor(t, t.Kernel().MonotonicClock(), dur, rem)
+}
+
+// ClockNanosleep implements linux syscall clock_nanosleep(2).
+func ClockNanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := int32(args[0].Int())
+ flags := args[1].Int()
+ addr := args[2].Pointer()
+ rem := args[3].Pointer()
+
+ req, err := copyTimespecIn(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !req.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Only allow clock constants also allowed by Linux.
+ if clockID > 0 {
+ if clockID != linux.CLOCK_REALTIME &&
+ clockID != linux.CLOCK_MONOTONIC &&
+ clockID != linux.CLOCK_PROCESS_CPUTIME_ID {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ c, err := getClock(t, clockID)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if flags&linux.TIMER_ABSTIME != 0 {
+ return 0, nil, clockNanosleepUntil(t, c, req)
+ }
+
+ dur := time.Duration(req.ToNsecCapped()) * time.Nanosecond
+ return 0, nil, clockNanosleepFor(t, c, dur, rem)
+}
+
+// Gettimeofday implements linux syscall gettimeofday(2).
+func Gettimeofday(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ tv := args[0].Pointer()
+ tz := args[1].Pointer()
+
+ if tv != usermem.Addr(0) {
+ nowTv := t.Kernel().RealtimeClock().Now().Timeval()
+ if err := copyTimevalOut(t, tv, &nowTv); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if tz != usermem.Addr(0) {
+ // Ask the time package for the timezone.
+ _, offset := time.Now().Zone()
+ // This int32 array mimics linux's struct timezone.
+ timezone := [2]int32{-int32(offset) / 60, 0}
+ _, err := t.CopyOut(tz, timezone)
+ return 0, nil, err
+ }
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go
new file mode 100644
index 000000000..04ea7a4e9
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_timer.go
@@ -0,0 +1,203 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+const nsecPerSec = int64(time.Second)
+
+// copyItimerValIn copies an ItimerVal from the untrusted app range to the
+// kernel. The ItimerVal may be either 32 or 64 bits.
+// A NULL address is allowed because because Linux allows
+// setitimer(which, NULL, &old_value) which disables the timer.
+// There is a KERN_WARN message saying this misfeature will be removed.
+// However, that hasn't happened as of 3.19, so we continue to support it.
+func copyItimerValIn(t *kernel.Task, addr usermem.Addr) (linux.ItimerVal, error) {
+ if addr == usermem.Addr(0) {
+ return linux.ItimerVal{}, nil
+ }
+
+ switch t.Arch().Width() {
+ case 8:
+ // Native size, just copy directly.
+ var itv linux.ItimerVal
+ if _, err := t.CopyIn(addr, &itv); err != nil {
+ return linux.ItimerVal{}, err
+ }
+
+ return itv, nil
+ default:
+ return linux.ItimerVal{}, syscall.ENOSYS
+ }
+}
+
+// copyItimerValOut copies an ItimerVal to the untrusted app range.
+// The ItimerVal may be either 32 or 64 bits.
+// A NULL address is allowed, in which case no copy takes place
+func copyItimerValOut(t *kernel.Task, addr usermem.Addr, itv *linux.ItimerVal) error {
+ if addr == usermem.Addr(0) {
+ return nil
+ }
+
+ switch t.Arch().Width() {
+ case 8:
+ // Native size, just copy directly.
+ _, err := t.CopyOut(addr, itv)
+ return err
+ default:
+ return syscall.ENOSYS
+ }
+}
+
+// Getitimer implements linux syscall getitimer(2).
+func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := args[0].Int()
+ val := args[1].Pointer()
+
+ olditv, err := t.Getitimer(timerID)
+ if err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, copyItimerValOut(t, val, &olditv)
+}
+
+// Setitimer implements linux syscall setitimer(2).
+func Setitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := args[0].Int()
+ newVal := args[1].Pointer()
+ oldVal := args[2].Pointer()
+
+ newitv, err := copyItimerValIn(t, newVal)
+ if err != nil {
+ return 0, nil, err
+ }
+ olditv, err := t.Setitimer(timerID, newitv)
+ if err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, copyItimerValOut(t, oldVal, &olditv)
+}
+
+// Alarm implements linux syscall alarm(2).
+func Alarm(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ duration := time.Duration(args[0].Uint()) * time.Second
+
+ olditv, err := t.Setitimer(linux.ITIMER_REAL, linux.ItimerVal{
+ Value: linux.DurationToTimeval(duration),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ olddur := olditv.Value.ToDuration()
+ secs := olddur.Round(time.Second).Nanoseconds() / nsecPerSec
+ if secs == 0 && olddur != 0 {
+ // We can't return 0 if an alarm was previously scheduled.
+ secs = 1
+ }
+ return uintptr(secs), nil, nil
+}
+
+// TimerCreate implements linux syscall timer_create(2).
+func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := args[0].Int()
+ sevp := args[1].Pointer()
+ timerIDp := args[2].Pointer()
+
+ c, err := getClock(t, clockID)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var sev *linux.Sigevent
+ if sevp != 0 {
+ sev = &linux.Sigevent{}
+ if _, err = t.CopyIn(sevp, sev); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ id, err := t.IntervalTimerCreate(c, sev)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if _, err := t.CopyOut(timerIDp, &id); err != nil {
+ t.IntervalTimerDelete(id)
+ return 0, nil, err
+ }
+
+ return uintptr(id), nil, nil
+}
+
+// TimerSettime implements linux syscall timer_settime(2).
+func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := linux.TimerID(args[0].Value)
+ flags := args[1].Int()
+ newValAddr := args[2].Pointer()
+ oldValAddr := args[3].Pointer()
+
+ var newVal linux.Itimerspec
+ if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ return 0, nil, err
+ }
+ oldVal, err := t.IntervalTimerSettime(timerID, newVal, flags&linux.TIMER_ABSTIME != 0)
+ if err != nil {
+ return 0, nil, err
+ }
+ if oldValAddr != 0 {
+ if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// TimerGettime implements linux syscall timer_gettime(2).
+func TimerGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := linux.TimerID(args[0].Value)
+ curValAddr := args[1].Pointer()
+
+ curVal, err := t.IntervalTimerGettime(timerID)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = t.CopyOut(curValAddr, &curVal)
+ return 0, nil, err
+}
+
+// TimerGetoverrun implements linux syscall timer_getoverrun(2).
+func TimerGetoverrun(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := linux.TimerID(args[0].Value)
+
+ o, err := t.IntervalTimerGetoverrun(timerID)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(o), nil, nil
+}
+
+// TimerDelete implements linux syscall timer_delete(2).
+func TimerDelete(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ timerID := linux.TimerID(args[0].Value)
+ return 0, nil, t.IntervalTimerDelete(timerID)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go
new file mode 100644
index 000000000..ec0155cbb
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_timerfd.go
@@ -0,0 +1,122 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/timerfd"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// TimerfdCreate implements Linux syscall timerfd_create(2).
+func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := args[0].Int()
+ flags := args[1].Int()
+
+ if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var c ktime.Clock
+ switch clockID {
+ case linux.CLOCK_REALTIME:
+ c = t.Kernel().RealtimeClock()
+ case linux.CLOCK_MONOTONIC:
+ c = t.Kernel().MonotonicClock()
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ f := timerfd.NewFile(t, c)
+ defer f.DecRef()
+ f.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags&linux.TFD_NONBLOCK != 0,
+ })
+
+ fd, err := t.FDMap().NewFDFrom(0, f, kernel.FDFlags{
+ CloseOnExec: flags&linux.TFD_CLOEXEC != 0,
+ }, t.ThreadGroup().Limits())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// TimerfdSettime implements Linux syscall timerfd_settime(2).
+func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ flags := args[1].Int()
+ newValAddr := args[2].Pointer()
+ oldValAddr := args[3].Pointer()
+
+ if flags&^(linux.TFD_TIMER_ABSTIME) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ f := t.FDMap().GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ tf, ok := f.FileOperations.(*timerfd.TimerOperations)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var newVal linux.Itimerspec
+ if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ return 0, nil, err
+ }
+ newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tf.Clock())
+ if err != nil {
+ return 0, nil, err
+ }
+ tm, oldS := tf.SetTime(newS)
+ if oldValAddr != 0 {
+ oldVal := ktime.ItimerspecFromSetting(tm, oldS)
+ if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// TimerfdGettime implements Linux syscall timerfd_gettime(2).
+func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ curValAddr := args[1].Pointer()
+
+ f := t.FDMap().GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ tf, ok := f.FileOperations.(*timerfd.TimerOperations)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ tm, s := tf.GetTime()
+ curVal := ktime.ItimerspecFromSetting(tm, s)
+ _, err := t.CopyOut(curValAddr, &curVal)
+ return 0, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/sys_tls.go b/pkg/sentry/syscalls/linux/sys_tls.go
new file mode 100644
index 000000000..1e8312e00
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_tls.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//+build amd64
+
+package linux
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+)
+
+// ArchPrctl implements linux syscall arch_prctl(2).
+// It sets architecture-specific process or thread state for t.
+func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ switch args[0].Int() {
+ case linux.ARCH_GET_FS:
+ addr := args[1].Pointer()
+ fsbase := t.Arch().TLS()
+ _, err := t.CopyOut(addr, uint64(fsbase))
+ if err != nil {
+ return 0, nil, err
+ }
+
+ case linux.ARCH_SET_FS:
+ fsbase := args[1].Uint64()
+ if !t.Arch().SetTLS(uintptr(fsbase)) {
+ return 0, nil, syscall.EPERM
+ }
+
+ case linux.ARCH_GET_GS, linux.ARCH_SET_GS:
+ t.Kernel().EmitUnimplementedEvent(t)
+ fallthrough
+ default:
+ return 0, nil, syscall.EINVAL
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
new file mode 100644
index 000000000..fa81fe10e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Uname implements linux syscall uname.
+func Uname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ version := t.SyscallTable().Version
+
+ uts := t.UTSNamespace()
+
+ // Fill in structure fields.
+ var u linux.UtsName
+ copy(u.Sysname[:], version.Sysname)
+ copy(u.Nodename[:], uts.HostName())
+ copy(u.Release[:], version.Release)
+ copy(u.Version[:], version.Version)
+ copy(u.Machine[:], "x86_64") // build tag above.
+ copy(u.Domainname[:], uts.DomainName())
+
+ // Copy out the result.
+ va := args[0].Pointer()
+ _, err := t.CopyOut(va, u)
+ return 0, nil, err
+}
+
+// Setdomainname implements Linux syscall setdomainname.
+func Setdomainname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nameAddr := args[0].Pointer()
+ size := args[1].Int()
+
+ utsns := t.UTSNamespace()
+ if !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, utsns.UserNamespace()) {
+ return 0, nil, syserror.EPERM
+ }
+ if size < 0 || size > linux.UTSLen {
+ return 0, nil, syserror.EINVAL
+ }
+
+ name, err := t.CopyInString(nameAddr, int(size))
+ if err != nil {
+ return 0, nil, err
+ }
+
+ utsns.SetDomainName(name)
+ return 0, nil, nil
+}
+
+// Sethostname implements Linux syscall sethostname.
+func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nameAddr := args[0].Pointer()
+ size := args[1].Int()
+
+ utsns := t.UTSNamespace()
+ if !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, utsns.UserNamespace()) {
+ return 0, nil, syserror.EPERM
+ }
+ if size < 0 || size > linux.UTSLen {
+ return 0, nil, syserror.EINVAL
+ }
+
+ name, err := t.CopyInString(nameAddr, int(size))
+ if err != nil {
+ return 0, nil, err
+ }
+
+ utsns.SetHostName(name)
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
new file mode 100644
index 000000000..1da72d606
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -0,0 +1,361 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // EventMaskWrite contains events that can be triggered on writes.
+ //
+ // Note that EventHUp is not going to happen for pipes but may for
+ // implementations of poll on some sockets, see net/core/datagram.c.
+ EventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr
+)
+
+// Write implements linux syscall write(2).
+func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is writable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := writev(t, file, src)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "write", file)
+}
+
+// Pwrite64 implements linux syscall pwrite64(2).
+func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is writing at an offset supported?
+ if !file.Flags().Pwrite {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is writable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwritev(t, file, src, offset)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file)
+}
+
+// Writev implements linux syscall writev(2).
+func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is writable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := writev(t, file, src)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "writev", file)
+}
+
+// Pwritev implements linux syscall pwritev(2).
+func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is writing at an offset supported?
+ if !file.Flags().Pwrite {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ // Check that the file is writable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwritev(t, file, src, offset)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file)
+}
+
+// Pwritev2 implements linux syscall pwritev2(2).
+// TODO(b/120162627): Implement RWF_HIPRI functionality.
+// TODO(b/120161091): Implement O_SYNC and D_SYNC functionality.
+func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the syscall is
+ // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the linux internal call
+ // (https://elixir.bootlin.com/linux/v4.18/source/fs/read_write.c#L1354)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 5th argument.
+
+ fd := kdefs.FD(args[0].Int())
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := int(args[5].Int())
+
+ if int(args[4].Int())&0x4 == 1 {
+ return 0, nil, syserror.EACCES
+ }
+
+ file := t.FDMap().GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is writing at an offset supported?
+ if offset > -1 && !file.Flags().Pwrite {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ if flags&^linux.RWF_VALID != 0 {
+ return uintptr(flags), nil, syserror.EOPNOTSUPP
+ }
+
+ // Check that the file is writeable.
+ if !file.Flags().Write {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Read the iovecs that specify the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // If pwritev2 is called with an offset of -1, writev is called.
+ if offset == -1 {
+ n, err := writev(t, file, src)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+ }
+
+ n, err := pwritev(t, file, src, offset)
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+}
+
+func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) {
+ n, err := f.Writev(t, src)
+ if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if n > 0 {
+ // Queue notification if we wrote anything.
+ f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
+ return n, err
+ }
+
+ // Sockets support write timeouts.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if s, ok := f.FileOperations.(socket.Socket); ok {
+ dl := s.SendTimeout()
+ if dl < 0 && err == syserror.ErrWouldBlock {
+ return n, err
+ }
+ if dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ }
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ f.EventRegister(&w, EventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ // Issue the request and break out if it completes with
+ // anything other than "would block".
+ n, err = f.Writev(t, src)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+
+ f.EventUnregister(&w)
+
+ if total > 0 {
+ // Queue notification if we wrote anything.
+ f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
+
+ return total, err
+}
+
+func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ n, err := f.Pwritev(t, src, offset)
+ if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if n > 0 {
+ // Queue notification if we wrote anything.
+ f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ f.EventRegister(&w, EventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ // Issue the request and break out if it completes with
+ // anything other than "would block".
+ n, err = f.Pwritev(t, src, offset+total)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.Block(ch); err != nil {
+ break
+ }
+ }
+
+ f.EventUnregister(&w)
+
+ if total > 0 {
+ // Queue notification if we wrote anything.
+ f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
+
+ return total, err
+}
diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go
new file mode 100644
index 000000000..fa6fcdc0b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/timespec.go
@@ -0,0 +1,112 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "syscall"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// copyTimespecIn copies a Timespec from the untrusted app range to the kernel.
+func copyTimespecIn(t *kernel.Task, addr usermem.Addr) (linux.Timespec, error) {
+ switch t.Arch().Width() {
+ case 8:
+ ts := linux.Timespec{}
+ in := t.CopyScratchBuffer(16)
+ _, err := t.CopyInBytes(addr, in)
+ if err != nil {
+ return ts, err
+ }
+ ts.Sec = int64(usermem.ByteOrder.Uint64(in[0:]))
+ ts.Nsec = int64(usermem.ByteOrder.Uint64(in[8:]))
+ return ts, nil
+ default:
+ return linux.Timespec{}, syserror.ENOSYS
+ }
+}
+
+// copyTimespecOut copies a Timespec to the untrusted app range.
+func copyTimespecOut(t *kernel.Task, addr usermem.Addr, ts *linux.Timespec) error {
+ switch t.Arch().Width() {
+ case 8:
+ out := t.CopyScratchBuffer(16)
+ usermem.ByteOrder.PutUint64(out[0:], uint64(ts.Sec))
+ usermem.ByteOrder.PutUint64(out[8:], uint64(ts.Nsec))
+ _, err := t.CopyOutBytes(addr, out)
+ return err
+ default:
+ return syserror.ENOSYS
+ }
+}
+
+// copyTimevalIn copies a Timeval from the untrusted app range to the kernel.
+func copyTimevalIn(t *kernel.Task, addr usermem.Addr) (linux.Timeval, error) {
+ switch t.Arch().Width() {
+ case 8:
+ tv := linux.Timeval{}
+ in := t.CopyScratchBuffer(16)
+ _, err := t.CopyInBytes(addr, in)
+ if err != nil {
+ return tv, err
+ }
+ tv.Sec = int64(usermem.ByteOrder.Uint64(in[0:]))
+ tv.Usec = int64(usermem.ByteOrder.Uint64(in[8:]))
+ return tv, nil
+ default:
+ return linux.Timeval{}, syscall.ENOSYS
+ }
+}
+
+// copyTimevalOut copies a Timeval to the untrusted app range.
+func copyTimevalOut(t *kernel.Task, addr usermem.Addr, tv *linux.Timeval) error {
+ switch t.Arch().Width() {
+ case 8:
+ out := t.CopyScratchBuffer(16)
+ usermem.ByteOrder.PutUint64(out[0:], uint64(tv.Sec))
+ usermem.ByteOrder.PutUint64(out[8:], uint64(tv.Usec))
+ _, err := t.CopyOutBytes(addr, out)
+ return err
+ default:
+ return syscall.ENOSYS
+ }
+}
+
+// copyTimespecInToDuration copies a Timespec from the untrusted app range,
+// validates it and converts it to a Duration.
+//
+// If the Timespec is larger than what can be represented in a Duration, the
+// returned value is the maximum that Duration will allow.
+//
+// If timespecAddr is NULL, the returned value is negative.
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timespecAddr != 0 {
+ timespec, err := copyTimespecIn(t, timespecAddr)
+ if err != nil {
+ return 0, err
+ }
+ if !timespec.Valid() {
+ return 0, syscall.EINVAL
+ }
+ timeout = time.Duration(timespec.ToNsecCapped())
+ }
+ return timeout, nil
+}
diff --git a/pkg/sentry/syscalls/syscalls.go b/pkg/sentry/syscalls/syscalls.go
new file mode 100644
index 000000000..5d10b3824
--- /dev/null
+++ b/pkg/sentry/syscalls/syscalls.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syscalls is the interface from the application to the kernel.
+// Traditionally, syscalls is the interface that is used by applications to
+// request services from the kernel of a operating system. We provide a
+// user-mode kernel that needs to handle those requests coming from unmodified
+// applications. Therefore, we still use the term "syscalls" to denote this
+// interface.
+//
+// Note that the stubs in this package may merely provide the interface, not
+// the actual implementation. It just makes writing syscall stubs
+// straightforward.
+package syscalls
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Error returns a syscall handler that will always give the passed error.
+func Error(err error) kernel.SyscallFn {
+ return func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, err
+ }
+}
+
+// ErrorWithEvent gives a syscall function that sends an unimplemented
+// syscall event via the event channel and returns the passed error.
+func ErrorWithEvent(err error) kernel.SyscallFn {
+ return func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, err
+ }
+}
+
+// CapError gives a syscall function that checks for capability c. If the task
+// has the capability, it returns ENOSYS, otherwise EPERM. To unprivileged
+// tasks, it will seem like there is an implementation.
+func CapError(c linux.Capability) kernel.SyscallFn {
+ return func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ if !t.HasCapability(c) {
+ return 0, nil, syserror.EPERM
+ }
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+}
diff --git a/pkg/sentry/syscalls/syscalls_state_autogen.go b/pkg/sentry/syscalls/syscalls_state_autogen.go
new file mode 100755
index 000000000..c114e7989
--- /dev/null
+++ b/pkg/sentry/syscalls/syscalls_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package syscalls
+
diff --git a/pkg/sentry/time/arith_arm64.go b/pkg/sentry/time/arith_arm64.go
new file mode 100644
index 000000000..b94740c2a
--- /dev/null
+++ b/pkg/sentry/time/arith_arm64.go
@@ -0,0 +1,70 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file provides a generic Go implementation of uint128 divided by uint64.
+
+// The code is derived from Go's generic math/big.divWW_g
+// (src/math/big/arith.go), but is only used on ARM64.
+
+package time
+
+import "math/bits"
+
+type word uint
+
+const (
+ _W = bits.UintSize // word size in bits
+ _W2 = _W / 2 // half word size in bits
+ _B2 = 1 << _W2 // half digit base
+ _M2 = _B2 - 1 // half digit mask
+)
+
+// nlz returns the number of leading zeros in x.
+// Wraps bits.LeadingZeros call for convenience.
+func nlz(x word) uint {
+ return uint(bits.LeadingZeros(uint(x)))
+}
+
+// q = (u1<<_W + u0 - r)/y
+// Adapted from Warren, Hacker's Delight, p. 152.
+func divWW(u1, u0, v word) (q, r word) {
+ if u1 >= v {
+ return 1<<_W - 1, 1<<_W - 1
+ }
+
+ s := nlz(v)
+ v <<= s
+
+ vn1 := v >> _W2
+ vn0 := v & _M2
+ un32 := u1<<s | u0>>(_W-s)
+ un10 := u0 << s
+ un1 := un10 >> _W2
+ un0 := un10 & _M2
+ q1 := un32 / vn1
+ rhat := un32 - q1*vn1
+
+ for q1 >= _B2 || q1*vn0 > _B2*rhat+un1 {
+ q1--
+ rhat += vn1
+
+ if rhat >= _B2 {
+ break
+ }
+ }
+
+ un21 := un32*_B2 + un1 - q1*v
+ q0 := un21 / vn1
+ rhat = un21 - q0*vn1
+
+ for q0 >= _B2 || q0*vn0 > _B2*rhat+un0 {
+ q0--
+ rhat += vn1
+ if rhat >= _B2 {
+ break
+ }
+ }
+
+ return q1*_B2 + q0, (un21*_B2 + un0 - q0*v) >> s
+}
diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go
new file mode 100644
index 000000000..c27e391c9
--- /dev/null
+++ b/pkg/sentry/time/calibrated_clock.go
@@ -0,0 +1,269 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package time provides a calibrated clock synchronized to a system reference
+// clock.
+package time
+
+import (
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/metric"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// fallbackMetric tracks failed updates. It is not sync, as it is not critical
+// that all occurrences are captured and CalibratedClock may fallback many
+// times.
+var fallbackMetric = metric.MustCreateNewUint64Metric("/time/fallback", false /* sync */, "Incremented when a clock falls back to system calls due to a failed update")
+
+// CalibratedClock implements a clock that tracks a reference clock.
+//
+// Users should call Update at regular intervals of around approxUpdateInterval
+// to ensure that the clock does not drift significantly from the reference
+// clock.
+type CalibratedClock struct {
+ // mu protects the fields below.
+ // TODO(mpratt): consider a sequence counter for read locking.
+ mu sync.RWMutex
+
+ // ref sample the reference clock that this clock is calibrated
+ // against.
+ ref *sampler
+
+ // ready indicates that the fields below are ready for use calculating
+ // time.
+ ready bool
+
+ // params are the current timekeeping parameters.
+ params Parameters
+
+ // errorNS is the estimated clock error in nanoseconds.
+ errorNS ReferenceNS
+}
+
+// NewCalibratedClock creates a CalibratedClock that tracks the given ClockID.
+func NewCalibratedClock(c ClockID) *CalibratedClock {
+ return &CalibratedClock{
+ ref: newSampler(c),
+ }
+}
+
+// Debugf logs at debug level.
+func (c *CalibratedClock) Debugf(format string, v ...interface{}) {
+ if log.IsLogging(log.Debug) {
+ args := []interface{}{c.ref.clockID}
+ args = append(args, v...)
+ log.Debugf("CalibratedClock(%v): "+format, args...)
+ }
+}
+
+// Infof logs at debug level.
+func (c *CalibratedClock) Infof(format string, v ...interface{}) {
+ if log.IsLogging(log.Info) {
+ args := []interface{}{c.ref.clockID}
+ args = append(args, v...)
+ log.Infof("CalibratedClock(%v): "+format, args...)
+ }
+}
+
+// Warningf logs at debug level.
+func (c *CalibratedClock) Warningf(format string, v ...interface{}) {
+ if log.IsLogging(log.Warning) {
+ args := []interface{}{c.ref.clockID}
+ args = append(args, v...)
+ log.Warningf("CalibratedClock(%v): "+format, args...)
+ }
+}
+
+// reset forces the clock to restart the calibration process, logging the
+// passed message.
+func (c *CalibratedClock) reset(str string, v ...interface{}) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.resetLocked(str, v...)
+}
+
+// resetLocked is equivalent to reset with c.mu already held for writing.
+func (c *CalibratedClock) resetLocked(str string, v ...interface{}) {
+ c.Warningf(str+" Resetting clock; time may jump.", v...)
+ c.ready = false
+ c.ref.Reset()
+ fallbackMetric.Increment()
+}
+
+// updateParams updates the timekeeping parameters based on the passed
+// parameters.
+//
+// actual is the actual estimated timekeeping parameters. The stored parameters
+// may need to be adjusted slightly from these values to compensate for error.
+//
+// Preconditions: c.mu must be held for writing.
+func (c *CalibratedClock) updateParams(actual Parameters) {
+ if !c.ready {
+ // At initial calibration there is nothing to correct.
+ c.params = actual
+ c.ready = true
+
+ c.Infof("ready")
+
+ return
+ }
+
+ // Otherwise, adjust the params to correct for errors.
+ newParams, errorNS, err := errorAdjust(c.params, actual, actual.BaseCycles)
+ if err != nil {
+ // Something is very wrong. Reset and try again from the
+ // beginning.
+ c.resetLocked("Unable to update params: %v.", err)
+ return
+ }
+ logErrorAdjustment(c.ref.clockID, errorNS, c.params, newParams)
+
+ if errorNS.Magnitude() >= MaxClockError {
+ // We should never get such extreme error, something is very
+ // wrong. Reset everything and start again.
+ //
+ // N.B. logErrorAdjustment will have already logged the error
+ // at warning level.
+ //
+ // TODO(mpratt): We could allow Realtime clock jumps here.
+ c.resetLocked("Extreme clock error.")
+ return
+ }
+
+ c.params = newParams
+ c.errorNS = errorNS
+}
+
+// Update runs the update step of the clock, updating its synchronization with
+// the reference clock.
+//
+// Update returns timekeeping and true with the new timekeeping parameters if
+// the clock is calibrated. Update should be called regularly to prevent the
+// clock from getting significantly out of sync from the reference clock.
+//
+// The returned timekeeping parameters are invalidated on the next call to
+// Update.
+func (c *CalibratedClock) Update() (Parameters, bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := c.ref.Sample(); err != nil {
+ c.resetLocked("Unable to update calibrated clock: %v.", err)
+ return Parameters{}, false
+ }
+
+ oldest, newest, ok := c.ref.Range()
+ if !ok {
+ // Not ready yet.
+ return Parameters{}, false
+ }
+
+ minCount := uint64(newest.before - oldest.after)
+ maxCount := uint64(newest.after - oldest.before)
+ refInterval := uint64(newest.ref - oldest.ref)
+
+ // freq hz = count / (interval ns) * (nsPerS ns) / (1 s)
+ nsPerS := uint64(time.Second.Nanoseconds())
+
+ minHz, ok := muldiv64(minCount, nsPerS, refInterval)
+ if !ok {
+ c.resetLocked("Unable to update calibrated clock: (%v - %v) * %v / %v overflows.", newest.before, oldest.after, nsPerS, refInterval)
+ return Parameters{}, false
+ }
+
+ maxHz, ok := muldiv64(maxCount, nsPerS, refInterval)
+ if !ok {
+ c.resetLocked("Unable to update calibrated clock: (%v - %v) * %v / %v overflows.", newest.after, oldest.before, nsPerS, refInterval)
+ return Parameters{}, false
+ }
+
+ c.updateParams(Parameters{
+ Frequency: (minHz + maxHz) / 2,
+ BaseRef: newest.ref,
+ BaseCycles: newest.after,
+ })
+
+ return c.params, true
+}
+
+// GetTime returns the current time based on the clock calibration.
+func (c *CalibratedClock) GetTime() (int64, error) {
+ c.mu.RLock()
+
+ if !c.ready {
+ // Fallback to a syscall.
+ now, err := c.ref.Syscall()
+ c.mu.RUnlock()
+ return int64(now), err
+ }
+
+ now := c.ref.Cycles()
+ v, ok := c.params.ComputeTime(now)
+ if !ok {
+ // Something is seriously wrong with the clock. Try
+ // again with syscalls.
+ c.resetLocked("Time computation overflowed. params = %+v, now = %v.", c.params, now)
+ now, err := c.ref.Syscall()
+ c.mu.RUnlock()
+ return int64(now), err
+ }
+
+ c.mu.RUnlock()
+ return v, nil
+}
+
+// CalibratedClocks contains calibrated monotonic and realtime clocks.
+//
+// TODO(mpratt): We know that Linux runs the monotonic and realtime clocks at
+// the same rate, so rather than tracking both individually, we could do one
+// calibration for both clocks.
+type CalibratedClocks struct {
+ // monotonic is the clock tracking the system monotonic clock.
+ monotonic *CalibratedClock
+
+ // realtime is the realtime equivalent of monotonic.
+ realtime *CalibratedClock
+}
+
+// NewCalibratedClocks creates a CalibratedClocks.
+func NewCalibratedClocks() *CalibratedClocks {
+ return &CalibratedClocks{
+ monotonic: NewCalibratedClock(Monotonic),
+ realtime: NewCalibratedClock(Realtime),
+ }
+}
+
+// Update implements Clocks.Update.
+func (c *CalibratedClocks) Update() (Parameters, bool, Parameters, bool) {
+ monotonicParams, monotonicOk := c.monotonic.Update()
+ realtimeParams, realtimeOk := c.realtime.Update()
+
+ return monotonicParams, monotonicOk, realtimeParams, realtimeOk
+}
+
+// GetTime implements Clocks.GetTime.
+func (c *CalibratedClocks) GetTime(id ClockID) (int64, error) {
+ switch id {
+ case Monotonic:
+ return c.monotonic.GetTime()
+ case Realtime:
+ return c.realtime.GetTime()
+ default:
+ return 0, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/time/clock_id.go b/pkg/sentry/time/clock_id.go
new file mode 100644
index 000000000..724f59dd9
--- /dev/null
+++ b/pkg/sentry/time/clock_id.go
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "strconv"
+)
+
+// ClockID is a Linux clock identifier.
+type ClockID int32
+
+// These are the supported Linux clock identifiers.
+const (
+ Realtime ClockID = iota
+ Monotonic
+)
+
+// String implements fmt.Stringer.String.
+func (c ClockID) String() string {
+ switch c {
+ case Realtime:
+ return "Realtime"
+ case Monotonic:
+ return "Monotonic"
+ default:
+ return strconv.Itoa(int(c))
+ }
+}
diff --git a/pkg/sentry/time/clocks.go b/pkg/sentry/time/clocks.go
new file mode 100644
index 000000000..837e86094
--- /dev/null
+++ b/pkg/sentry/time/clocks.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+// Clocks represents a clock source that contains both a monotonic and realtime
+// clock.
+type Clocks interface {
+ // Update performs an update step, keeping the clocks in sync with the
+ // reference host clocks, and returning the new timekeeping parameters.
+ //
+ // Update should be called at approximately ApproxUpdateInterval.
+ Update() (monotonicParams Parameters, monotonicOk bool, realtimeParam Parameters, realtimeOk bool)
+
+ // GetTime returns the current time in nanoseconds for the given clock.
+ //
+ // Clocks implementations must support at least Monotonic and
+ // Realtime.
+ GetTime(c ClockID) (int64, error)
+}
diff --git a/pkg/sentry/time/muldiv_amd64.s b/pkg/sentry/time/muldiv_amd64.s
new file mode 100644
index 000000000..028c6684e
--- /dev/null
+++ b/pkg/sentry/time/muldiv_amd64.s
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// Documentation is available in parameters.go.
+//
+// func muldiv64(value, multiplier, divisor uint64) (uint64, bool)
+TEXT ·muldiv64(SB),NOSPLIT,$0-33
+ MOVQ value+0(FP), AX
+ MOVQ multiplier+8(FP), BX
+ MOVQ divisor+16(FP), CX
+
+ // Multiply AX*BX and store result in DX:AX.
+ MULQ BX
+
+ // If divisor <= (value*multiplier) / 2^64, then the division will overflow.
+ //
+ // (value*multiplier) / 2^64 is DX:AX >> 64, or simply DX.
+ CMPQ CX, DX
+ JLE overflow
+
+ // Divide DX:AX by CX.
+ DIVQ CX
+
+ MOVQ AX, result+24(FP)
+ MOVB $1, ok+32(FP)
+ RET
+
+overflow:
+ MOVQ $0, result+24(FP)
+ MOVB $0, ok+32(FP)
+ RET
diff --git a/pkg/sentry/time/muldiv_arm64.s b/pkg/sentry/time/muldiv_arm64.s
new file mode 100644
index 000000000..5ad57a8a3
--- /dev/null
+++ b/pkg/sentry/time/muldiv_arm64.s
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// Documentation is available in parameters.go.
+//
+// func muldiv64(value, multiplier, divisor uint64) (uint64, bool)
+TEXT ·muldiv64(SB),NOSPLIT,$40-33
+ MOVD value+0(FP), R0
+ MOVD multiplier+8(FP), R1
+ MOVD divisor+16(FP), R2
+
+ UMULH R0, R1, R3
+ MUL R0, R1, R4
+
+ CMP R2, R3
+ BHS overflow
+
+ MOVD R3, 8(RSP)
+ MOVD R4, 16(RSP)
+ MOVD R2, 24(RSP)
+ CALL ·divWW(SB)
+ MOVD 32(RSP), R0
+ MOVD R0, result+24(FP)
+ MOVD $1, R0
+ MOVB R0, ok+32(FP)
+ RET
+
+overflow:
+ MOVD ZR, result+24(FP)
+ MOVB ZR, ok+32(FP)
+ RET
diff --git a/pkg/sentry/time/parameters.go b/pkg/sentry/time/parameters.go
new file mode 100644
index 000000000..63cf7c4a3
--- /dev/null
+++ b/pkg/sentry/time/parameters.go
@@ -0,0 +1,239 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+)
+
+const (
+ // ApproxUpdateInterval is the approximate interval that parameters
+ // should be updated at.
+ //
+ // Error correction assumes that the next update will occur after this
+ // much time.
+ //
+ // If an update occurs before ApproxUpdateInterval passes, it has no
+ // adverse effect on error correction behavior.
+ //
+ // If an update occurs after ApproxUpdateInterval passes, the clock
+ // will overshoot its error correction target and begin accumulating
+ // error in the other direction.
+ //
+ // If updates occur after more than 2*ApproxUpdateInterval passes, the
+ // clock becomes unstable, accumulating more error than it had
+ // originally. Repeated updates after more than 2*ApproxUpdateInterval
+ // will cause unbounded increases in error.
+ //
+ // These statements assume that the host clock does not change. Actual
+ // error will depend upon host clock changes.
+ //
+ // TODO(b/68779214): make error correction more robust to delayed
+ // updates.
+ ApproxUpdateInterval = 1 * time.Second
+
+ // MaxClockError is the maximum amount of error that the clocks will
+ // try to correct.
+ //
+ // This limit:
+ //
+ // * Puts a limit on cases of otherwise unbounded increases in error.
+ //
+ // * Avoids unreasonably large frequency adjustments required to
+ // correct large errors over a single update interval.
+ MaxClockError = ReferenceNS(ApproxUpdateInterval) / 4
+)
+
+// Parameters are the timekeeping parameters needed to compute the current
+// time.
+type Parameters struct {
+ // BaseCycles was the TSC counter value when the time was BaseRef.
+ BaseCycles TSCValue
+
+ // BaseRef is the reference clock time in nanoseconds corresponding to
+ // BaseCycles.
+ BaseRef ReferenceNS
+
+ // Frequency is the frequency of the cycle clock in Hertz.
+ Frequency uint64
+}
+
+// muldiv64 multiplies two 64-bit numbers, then divides the result by another
+// 64-bit number.
+//
+// It requires that the result fit in 64 bits, but doesn't require that
+// intermediate values do; in particular, the result of the multiplication may
+// require 128 bits.
+//
+// It returns !ok if divisor is zero or the result does not fit in 64 bits.
+func muldiv64(value, multiplier, divisor uint64) (uint64, bool)
+
+// ComputeTime calculates the current time from a "now" TSC value.
+//
+// time = ref + (now - base) / f
+func (p Parameters) ComputeTime(nowCycles TSCValue) (int64, bool) {
+ diffCycles := nowCycles - p.BaseCycles
+ if diffCycles < 0 {
+ log.Warningf("now cycles %v < base cycles %v", nowCycles, p.BaseCycles)
+ diffCycles = 0
+ }
+
+ // Overflow "won't ever happen". If diffCycles is the max value
+ // (2^63 - 1), then to overflow,
+ //
+ // frequency <= ((2^63 - 1) * 10^9) / 2^64 = 500Mhz
+ //
+ // A TSC running at 2GHz takes 201 years to reach 2^63-1. 805 years at
+ // 500MHz.
+ diffNS, ok := muldiv64(uint64(diffCycles), uint64(time.Second.Nanoseconds()), p.Frequency)
+ return int64(uint64(p.BaseRef) + diffNS), ok
+}
+
+// errorAdjust returns a new Parameters struct "adjusted" that satisfies:
+//
+// 1. adjusted.ComputeTime(now) = prevParams.ComputeTime(now)
+// * i.e., the current time does not jump.
+//
+// 2. adjusted.ComputeTime(TSC at next update) = newParams.ComputeTime(TSC at next update)
+// * i.e., Any error between prevParams and newParams will be corrected over
+// the course of the next update period.
+//
+// errorAdjust also returns the current clock error.
+//
+// Preconditions:
+// * newParams.BaseCycles >= prevParams.BaseCycles; i.e., TSC must not go
+// backwards.
+// * newParams.BaseCycles <= now; i.e., the new parameters be computed at or
+// before now.
+func errorAdjust(prevParams Parameters, newParams Parameters, now TSCValue) (Parameters, ReferenceNS, error) {
+ if newParams.BaseCycles < prevParams.BaseCycles {
+ // Oh dear! Something is very wrong.
+ return Parameters{}, 0, fmt.Errorf("TSC went backwards in updated clock params: %v < %v", newParams.BaseCycles, prevParams.BaseCycles)
+ }
+ if newParams.BaseCycles > now {
+ return Parameters{}, 0, fmt.Errorf("parameters contain base cycles later than now: %v > %v", newParams.BaseCycles, now)
+ }
+
+ intervalNS := int64(ApproxUpdateInterval.Nanoseconds())
+ nsPerSec := uint64(time.Second.Nanoseconds())
+
+ // Current time as computed by prevParams.
+ oldNowNS, ok := prevParams.ComputeTime(now)
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("old now time computation overflowed. params = %+v, now = %v", prevParams, now)
+ }
+
+ // We expect the update ticker to run based on this clock (i.e., it has
+ // been using prevParams and will use the returned adjusted
+ // parameters). Hence it will decide to fire intervalNS from the
+ // current (oldNowNS) "now".
+ nextNS := oldNowNS + intervalNS
+
+ if nextNS <= int64(newParams.BaseRef) {
+ // The next update time already passed before the new
+ // parameters were created! We definitely can't correct the
+ // error by then.
+ return Parameters{}, 0, fmt.Errorf("unable to correct error in single period. oldNowNS = %v, nextNS = %v, p = %v", oldNowNS, nextNS, newParams)
+ }
+
+ // For what TSC value next will newParams.ComputeTime(next) = nextNS?
+ //
+ // Solve ComputeTime for next:
+ //
+ // next = newParams.Frequency * (nextNS - newParams.BaseRef) + newParams.BaseCycles
+ c, ok := muldiv64(newParams.Frequency, uint64(nextNS-int64(newParams.BaseRef)), nsPerSec)
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("%v * (%v - %v) / %v overflows", newParams.Frequency, nextNS, newParams.BaseRef, nsPerSec)
+ }
+
+ cycles := TSCValue(c)
+ next := cycles + newParams.BaseCycles
+
+ if next <= now {
+ // The next update time already passed now with the new
+ // parameters! We can't correct the error in a single period.
+ return Parameters{}, 0, fmt.Errorf("unable to correct error in single period. oldNowNS = %v, nextNS = %v, now = %v, next = %v", oldNowNS, nextNS, now, next)
+ }
+
+ // We want to solve for parameters that satisfy:
+ //
+ // adjusted.ComputeTime(now) = oldNowNS
+ //
+ // adjusted.ComputeTime(next) = nextNS
+ //
+ // i.e., the current time does not change, but by the time we reach
+ // next we reach the same time as newParams.
+
+ // We choose to keep BaseCycles fixed.
+ adjusted := Parameters{
+ BaseCycles: newParams.BaseCycles,
+ }
+
+ // We want a slope such that time goes from oldNowNS to nextNS when
+ // we reach next.
+ //
+ // In other words, cycles should increase by next - now in the next
+ // interval.
+
+ cycles = next - now
+ ns := intervalNS
+
+ // adjusted.Frequency = cycles / ns
+ adjusted.Frequency, ok = muldiv64(uint64(cycles), nsPerSec, uint64(ns))
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("(%v - %v) * %v / %v overflows", next, now, nsPerSec, ns)
+ }
+
+ // Now choose a base reference such that the current time remains the
+ // same. Note that this is just ComputeTime, solving for BaseRef:
+ //
+ // oldNowNS = BaseRef + (now - BaseCycles) / Frequency
+ // BaseRef = oldNowNS - (now - BaseCycles) / Frequency
+ diffNS, ok := muldiv64(uint64(now-adjusted.BaseCycles), nsPerSec, adjusted.Frequency)
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("(%v - %v) * %v / %v overflows", now, adjusted.BaseCycles, nsPerSec, adjusted.Frequency)
+ }
+
+ adjusted.BaseRef = ReferenceNS(oldNowNS - int64(diffNS))
+
+ // The error is the difference between the current time and what the
+ // new parameters say the current time should be.
+ newNowNS, ok := newParams.ComputeTime(now)
+ if !ok {
+ return Parameters{}, 0, fmt.Errorf("new now time computation overflowed. params = %+v, now = %v", newParams, now)
+ }
+
+ errorNS := ReferenceNS(oldNowNS - newNowNS)
+
+ return adjusted, errorNS, nil
+}
+
+// logErrorAdjustment logs the clock error and associated error correction
+// frequency adjustment.
+//
+// The log level is determined by the error severity.
+func logErrorAdjustment(clock ClockID, errorNS ReferenceNS, orig, adjusted Parameters) {
+ fn := log.Debugf
+ if int64(errorNS.Magnitude()) > time.Millisecond.Nanoseconds() {
+ fn = log.Warningf
+ } else if int64(errorNS.Magnitude()) > 10*time.Microsecond.Nanoseconds() {
+ fn = log.Infof
+ }
+
+ fn("Clock(%v): error: %v ns, adjusted frequency from %v Hz to %v Hz", clock, errorNS, orig.Frequency, adjusted.Frequency)
+}
diff --git a/pkg/sentry/time/sampler.go b/pkg/sentry/time/sampler.go
new file mode 100644
index 000000000..2140a99b7
--- /dev/null
+++ b/pkg/sentry/time/sampler.go
@@ -0,0 +1,225 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "errors"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+)
+
+const (
+ // defaultOverheadTSC is the default estimated syscall overhead in TSC cycles.
+ // It is further refined as syscalls are made.
+ defaultOverheadCycles = 1 * 1000
+
+ // maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles.
+ maxOverheadCycles = 100 * defaultOverheadCycles
+
+ // maxSampleLoops is the maximum number of times to try to get a clock sample
+ // under the expected overhead.
+ maxSampleLoops = 5
+
+ // maxSamples is the maximum number of samples to collect.
+ maxSamples = 10
+)
+
+// errOverheadTooHigh is returned from sampler.Sample if the syscall
+// overhead is too high.
+var errOverheadTooHigh = errors.New("time syscall overhead exceeds maximum")
+
+// TSCValue is a value from the TSC.
+type TSCValue int64
+
+// Rdtsc reads the TSC.
+//
+// Intel SDM, Vol 3, Ch 17.15:
+// "The RDTSC instruction reads the time-stamp counter and is guaranteed to
+// return a monotonically increasing unique value whenever executed, except for
+// a 64-bit counter wraparound. Intel guarantees that the time-stamp counter
+// will not wraparound within 10 years after being reset."
+//
+// We use int64, so we have 5 years before wrap-around.
+func Rdtsc() TSCValue
+
+// ReferenceNS are nanoseconds in the reference clock domain.
+// int64 gives us ~290 years before this overflows.
+type ReferenceNS int64
+
+// Magnitude returns the absolute value of r.
+func (r ReferenceNS) Magnitude() ReferenceNS {
+ if r < 0 {
+ return -r
+ }
+ return r
+}
+
+// cycleClock is a TSC-based cycle clock.
+type cycleClock interface {
+ // Cycles returns a count value from the TSC.
+ Cycles() TSCValue
+}
+
+// tscCycleClock is a cycleClock that uses the real TSC.
+type tscCycleClock struct{}
+
+// Cycles implements cycleClock.Cycles.
+func (tscCycleClock) Cycles() TSCValue {
+ return Rdtsc()
+}
+
+// sample contains a sample from the reference clock, with TSC values from
+// before and after the reference clock value was captured.
+type sample struct {
+ before TSCValue
+ after TSCValue
+ ref ReferenceNS
+}
+
+// Overhead returns the sample overhead in TSC cycles.
+func (s *sample) Overhead() TSCValue {
+ return s.after - s.before
+}
+
+// referenceClocks collects individual samples from a reference clock ID and
+// TSC.
+type referenceClocks interface {
+ cycleClock
+
+ // Sample returns a single sample from the reference clock ID.
+ Sample(c ClockID) (sample, error)
+}
+
+// sampler collects samples from a reference system clock, minimizing
+// the overhead in each sample.
+type sampler struct {
+ // clockID is the reference clock ID (e.g., CLOCK_MONOTONIC).
+ clockID ClockID
+
+ // clocks provides raw samples.
+ clocks referenceClocks
+
+ // overhead is the estimated sample overhead in TSC cycles.
+ overhead TSCValue
+
+ // samples is a ring buffer of the latest samples collected.
+ samples []sample
+}
+
+// newSampler creates a sampler for clockID.
+func newSampler(c ClockID) *sampler {
+ return &sampler{
+ clockID: c,
+ clocks: syscallTSCReferenceClocks{},
+ overhead: defaultOverheadCycles,
+ }
+}
+
+// Reset discards previously collected clock samples.
+func (s *sampler) Reset() {
+ s.overhead = defaultOverheadCycles
+ s.samples = []sample{}
+}
+
+// lowOverheadSample returns a reference clock sample with minimized syscall overhead.
+func (s *sampler) lowOverheadSample() (sample, error) {
+ for {
+ for i := 0; i < maxSampleLoops; i++ {
+ samp, err := s.clocks.Sample(s.clockID)
+ if err != nil {
+ return sample{}, err
+ }
+
+ if samp.before > samp.after {
+ log.Warningf("TSC went backwards: %v > %v", samp.before, samp.after)
+ continue
+ }
+
+ if samp.Overhead() <= s.overhead {
+ return samp, nil
+ }
+ }
+
+ // Couldn't get a sample with the current overhead. Increase it.
+ newOverhead := 2 * s.overhead
+ if newOverhead > maxOverheadCycles {
+ // We'll give it one more shot with the max overhead.
+
+ if s.overhead == maxOverheadCycles {
+ return sample{}, errOverheadTooHigh
+ }
+
+ newOverhead = maxOverheadCycles
+ }
+
+ s.overhead = newOverhead
+ log.Debugf("Time: Adjusting syscall overhead up to %v", s.overhead)
+ }
+}
+
+// Sample collects a reference clock sample.
+func (s *sampler) Sample() error {
+ sample, err := s.lowOverheadSample()
+ if err != nil {
+ return err
+ }
+
+ s.samples = append(s.samples, sample)
+ if len(s.samples) > maxSamples {
+ s.samples = s.samples[1:]
+ }
+
+ // If the 4 most recent samples all have an overhead less than half the
+ // expected overhead, adjust downwards.
+ if len(s.samples) < 4 {
+ return nil
+ }
+
+ for _, sample := range s.samples[len(s.samples)-4:] {
+ if sample.Overhead() > s.overhead/2 {
+ return nil
+ }
+ }
+
+ s.overhead -= s.overhead / 8
+ log.Debugf("Time: Adjusting syscall overhead down to %v", s.overhead)
+
+ return nil
+}
+
+// Syscall returns the current raw reference time without storing TSC
+// samples.
+func (s *sampler) Syscall() (ReferenceNS, error) {
+ sample, err := s.clocks.Sample(s.clockID)
+ if err != nil {
+ return 0, err
+ }
+
+ return sample.ref, nil
+}
+
+// Cycles returns a raw TSC value.
+func (s *sampler) Cycles() TSCValue {
+ return s.clocks.Cycles()
+}
+
+// Range returns the widest range of clock samples available.
+func (s *sampler) Range() (sample, sample, bool) {
+ if len(s.samples) < 2 {
+ return sample{}, sample{}, false
+ }
+
+ return s.samples[0], s.samples[len(s.samples)-1], true
+}
diff --git a/pkg/sentry/time/sampler_unsafe.go b/pkg/sentry/time/sampler_unsafe.go
new file mode 100644
index 000000000..e76180217
--- /dev/null
+++ b/pkg/sentry/time/sampler_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// syscallTSCReferenceClocks is the standard referenceClocks, collecting
+// samples using CLOCK_GETTIME and RDTSC.
+type syscallTSCReferenceClocks struct {
+ tscCycleClock
+}
+
+// Sample implements sampler.Sample.
+func (syscallTSCReferenceClocks) Sample(c ClockID) (sample, error) {
+ var s sample
+
+ s.before = Rdtsc()
+
+ // Don't call clockGettime to avoid a call which may call morestack.
+ var ts syscall.Timespec
+ _, _, e := syscall.RawSyscall(syscall.SYS_CLOCK_GETTIME, uintptr(c), uintptr(unsafe.Pointer(&ts)), 0)
+ if e != 0 {
+ return sample{}, e
+ }
+
+ s.after = Rdtsc()
+ s.ref = ReferenceNS(ts.Nano())
+
+ return s, nil
+}
+
+// clockGettime calls SYS_CLOCK_GETTIME, returning time in nanoseconds.
+func clockGettime(c ClockID) (ReferenceNS, error) {
+ var ts syscall.Timespec
+ _, _, e := syscall.RawSyscall(syscall.SYS_CLOCK_GETTIME, uintptr(c), uintptr(unsafe.Pointer(&ts)), 0)
+ if e != 0 {
+ return 0, e
+ }
+
+ return ReferenceNS(ts.Nano()), nil
+}
diff --git a/pkg/sentry/time/seqatomic_parameters.go b/pkg/sentry/time/seqatomic_parameters.go
new file mode 100755
index 000000000..ecbea4d94
--- /dev/null
+++ b/pkg/sentry/time/seqatomic_parameters.go
@@ -0,0 +1,55 @@
+package time
+
+import (
+ "reflect"
+ "strings"
+ "unsafe"
+
+ "fmt"
+ "gvisor.googlesource.com/gvisor/third_party/gvsync"
+)
+
+// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
+// with any writer critical sections in sc.
+func SeqAtomicLoadParameters(sc *gvsync.SeqCount, ptr *Parameters) Parameters {
+ // This function doesn't use SeqAtomicTryLoad because doing so is
+ // measurably, significantly (~20%) slower; Go is awful at inlining.
+ var val Parameters
+ for {
+ epoch := sc.BeginRead()
+ if gvsync.RaceEnabled {
+
+ gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ } else {
+
+ val = *ptr
+ }
+ if sc.ReadOk(epoch) {
+ break
+ }
+ }
+ return val
+}
+
+// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section
+// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read
+// would race with a writer critical section, SeqAtomicTryLoad returns
+// (unspecified, false).
+func SeqAtomicTryLoadParameters(sc *gvsync.SeqCount, epoch gvsync.SeqCountEpoch, ptr *Parameters) (Parameters, bool) {
+ var val Parameters
+ if gvsync.RaceEnabled {
+ gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ } else {
+ val = *ptr
+ }
+ return val, sc.ReadOk(epoch)
+}
+
+func initParameters() {
+ var val Parameters
+ typ := reflect.TypeOf(val)
+ name := typ.Name()
+ if ptrs := gvsync.PointersInType(typ, name); len(ptrs) != 0 {
+ panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
+ }
+}
diff --git a/pkg/sentry/time/time_state_autogen.go b/pkg/sentry/time/time_state_autogen.go
new file mode 100755
index 000000000..ea614b056
--- /dev/null
+++ b/pkg/sentry/time/time_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package time
+
diff --git a/pkg/sentry/time/tsc_amd64.s b/pkg/sentry/time/tsc_amd64.s
new file mode 100644
index 000000000..6a8eed664
--- /dev/null
+++ b/pkg/sentry/time/tsc_amd64.s
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+TEXT ·Rdtsc(SB),NOSPLIT,$0-8
+ // N.B. We need LFENCE on Intel, AMD is more complicated.
+ // Modern AMD CPUs with modern kernels make LFENCE behave like it does
+ // on Intel with MSR_F10H_DECFG_LFENCE_SERIALIZE_BIT. MFENCE is
+ // otherwise needed on AMD.
+ LFENCE
+ RDTSC
+ SHLQ $32, DX
+ ADDQ DX, AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/sentry/time/tsc_arm64.s b/pkg/sentry/time/tsc_arm64.s
new file mode 100644
index 000000000..da9fa4112
--- /dev/null
+++ b/pkg/sentry/time/tsc_arm64.s
@@ -0,0 +1,22 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+TEXT ·Rdtsc(SB),NOSPLIT,$0-8
+ // Get the virtual counter.
+ ISB $15
+ WORD $0xd53be040 //MRS CNTVCT_EL0, R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/unimpl/events.go b/pkg/sentry/unimpl/events.go
new file mode 100644
index 000000000..d92766e2d
--- /dev/null
+++ b/pkg/sentry/unimpl/events.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package unimpl contains interface to emit events about unimplemented
+// features.
+package unimpl
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// contextID is the events package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxEvents is a Context.Value key for a Events.
+ CtxEvents contextID = iota
+)
+
+// Events interface defines method to emit unsupported events.
+type Events interface {
+ EmitUnimplementedEvent(context.Context)
+}
+
+// EmitUnimplementedEvent emits unsupported syscall event to the context.
+func EmitUnimplementedEvent(ctx context.Context) {
+ e := ctx.Value(CtxEvents)
+ if e == nil {
+ log.Warningf("Context.Value(CtxEvents) not present, unimplemented syscall event not reported.")
+ return
+ }
+ e.(Events).EmitUnimplementedEvent(ctx)
+}
diff --git a/pkg/sentry/unimpl/unimpl_state_autogen.go b/pkg/sentry/unimpl/unimpl_state_autogen.go
new file mode 100755
index 000000000..b9d1116f3
--- /dev/null
+++ b/pkg/sentry/unimpl/unimpl_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package unimpl
+
diff --git a/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go b/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go
new file mode 100755
index 000000000..bf30914dc
--- /dev/null
+++ b/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go
@@ -0,0 +1,91 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: pkg/sentry/unimpl/unimplemented_syscall.proto
+
+package gvisor
+
+import (
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ registers_go_proto "gvisor.googlesource.com/gvisor/pkg/sentry/arch/registers_go_proto"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type UnimplementedSyscall struct {
+ Tid int32 `protobuf:"varint,1,opt,name=tid,proto3" json:"tid,omitempty"`
+ Registers *registers_go_proto.Registers `protobuf:"bytes,2,opt,name=registers,proto3" json:"registers,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *UnimplementedSyscall) Reset() { *m = UnimplementedSyscall{} }
+func (m *UnimplementedSyscall) String() string { return proto.CompactTextString(m) }
+func (*UnimplementedSyscall) ProtoMessage() {}
+func (*UnimplementedSyscall) Descriptor() ([]byte, []int) {
+ return fileDescriptor_ddc2fcd2bea3c75d, []int{0}
+}
+
+func (m *UnimplementedSyscall) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_UnimplementedSyscall.Unmarshal(m, b)
+}
+func (m *UnimplementedSyscall) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_UnimplementedSyscall.Marshal(b, m, deterministic)
+}
+func (m *UnimplementedSyscall) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_UnimplementedSyscall.Merge(m, src)
+}
+func (m *UnimplementedSyscall) XXX_Size() int {
+ return xxx_messageInfo_UnimplementedSyscall.Size(m)
+}
+func (m *UnimplementedSyscall) XXX_DiscardUnknown() {
+ xxx_messageInfo_UnimplementedSyscall.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_UnimplementedSyscall proto.InternalMessageInfo
+
+func (m *UnimplementedSyscall) GetTid() int32 {
+ if m != nil {
+ return m.Tid
+ }
+ return 0
+}
+
+func (m *UnimplementedSyscall) GetRegisters() *registers_go_proto.Registers {
+ if m != nil {
+ return m.Registers
+ }
+ return nil
+}
+
+func init() {
+ proto.RegisterType((*UnimplementedSyscall)(nil), "gvisor.UnimplementedSyscall")
+}
+
+func init() {
+ proto.RegisterFile("pkg/sentry/unimpl/unimplemented_syscall.proto", fileDescriptor_ddc2fcd2bea3c75d)
+}
+
+var fileDescriptor_ddc2fcd2bea3c75d = []byte{
+ // 149 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x2d, 0xc8, 0x4e, 0xd7,
+ 0x2f, 0x4e, 0xcd, 0x2b, 0x29, 0xaa, 0xd4, 0x2f, 0xcd, 0xcb, 0xcc, 0x2d, 0xc8, 0x81, 0x52, 0xa9,
+ 0xb9, 0xa9, 0x79, 0x25, 0xa9, 0x29, 0xf1, 0xc5, 0x95, 0xc5, 0xc9, 0x89, 0x39, 0x39, 0x7a, 0x05,
+ 0x45, 0xf9, 0x25, 0xf9, 0x42, 0x6c, 0xe9, 0x65, 0x99, 0xc5, 0xf9, 0x45, 0x52, 0xf2, 0x48, 0xda,
+ 0x12, 0x8b, 0x92, 0x33, 0xf4, 0x8b, 0x52, 0xd3, 0x33, 0x8b, 0x4b, 0x52, 0x8b, 0x8a, 0x21, 0x0a,
+ 0x95, 0x22, 0xb9, 0x44, 0x42, 0x91, 0xcd, 0x09, 0x86, 0x18, 0x23, 0x24, 0xc0, 0xc5, 0x5c, 0x92,
+ 0x99, 0x22, 0xc1, 0xa8, 0xc0, 0xa8, 0xc1, 0x1a, 0x04, 0x62, 0x0a, 0xe9, 0x73, 0x71, 0xc2, 0x35,
+ 0x4b, 0x30, 0x29, 0x30, 0x6a, 0x70, 0x1b, 0x09, 0xea, 0x41, 0xac, 0xd1, 0x0b, 0x82, 0x49, 0x04,
+ 0x21, 0xd4, 0x24, 0xb1, 0x81, 0x6d, 0x30, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x51, 0x4a, 0x47,
+ 0x79, 0xbb, 0x00, 0x00, 0x00,
+}
diff --git a/pkg/sentry/uniqueid/context.go b/pkg/sentry/uniqueid/context.go
new file mode 100644
index 000000000..e55b89689
--- /dev/null
+++ b/pkg/sentry/uniqueid/context.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package uniqueid defines context.Context keys for obtaining system-wide
+// unique identifiers.
+package uniqueid
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+// contextID is the kernel package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxGlobalUniqueID is a Context.Value key for a system-wide
+ // unique identifier.
+ CtxGlobalUniqueID contextID = iota
+
+ // CtxGlobalUniqueIDProvider is a Context.Value key for a
+ // system-wide unique identifier generator.
+ CtxGlobalUniqueIDProvider
+
+ // CtxInotifyCookie is a Context.Value key for a unique inotify
+ // event cookie.
+ CtxInotifyCookie
+)
+
+// GlobalFromContext returns a system-wide unique identifier from ctx.
+func GlobalFromContext(ctx context.Context) uint64 {
+ return ctx.Value(CtxGlobalUniqueID).(uint64)
+}
+
+// GlobalProviderFromContext returns a system-wide unique identifier from ctx.
+func GlobalProviderFromContext(ctx context.Context) transport.UniqueIDProvider {
+ return ctx.Value(CtxGlobalUniqueIDProvider).(transport.UniqueIDProvider)
+}
+
+// InotifyCookie generates a unique inotify event cookie from ctx.
+func InotifyCookie(ctx context.Context) uint32 {
+ return ctx.Value(CtxInotifyCookie).(uint32)
+}
diff --git a/pkg/sentry/uniqueid/uniqueid_state_autogen.go b/pkg/sentry/uniqueid/uniqueid_state_autogen.go
new file mode 100755
index 000000000..09e4327e4
--- /dev/null
+++ b/pkg/sentry/uniqueid/uniqueid_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package uniqueid
+
diff --git a/pkg/sentry/usage/cpu.go b/pkg/sentry/usage/cpu.go
new file mode 100644
index 000000000..bfc282d69
--- /dev/null
+++ b/pkg/sentry/usage/cpu.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usage
+
+import (
+ "time"
+)
+
+// CPUStats contains the subset of struct rusage fields that relate to CPU
+// scheduling.
+//
+// +stateify savable
+type CPUStats struct {
+ // UserTime is the amount of time spent executing application code.
+ UserTime time.Duration
+
+ // SysTime is the amount of time spent executing sentry code.
+ SysTime time.Duration
+
+ // VoluntarySwitches is the number of times control has been voluntarily
+ // ceded due to blocking, etc.
+ VoluntarySwitches uint64
+
+ // InvoluntarySwitches (struct rusage::ru_nivcsw) is unsupported, since
+ // "preemptive" scheduling is managed by the Go runtime, which doesn't
+ // provide this information.
+}
+
+// Accumulate adds s2 to s.
+func (s *CPUStats) Accumulate(s2 CPUStats) {
+ s.UserTime += s2.UserTime
+ s.SysTime += s2.SysTime
+ s.VoluntarySwitches += s2.VoluntarySwitches
+}
diff --git a/pkg/sentry/usage/io.go b/pkg/sentry/usage/io.go
new file mode 100644
index 000000000..dfcd3a49d
--- /dev/null
+++ b/pkg/sentry/usage/io.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usage
+
+import (
+ "sync/atomic"
+)
+
+// IO contains I/O-related statistics.
+//
+// +stateify savable
+type IO struct {
+ // CharsRead is the number of bytes read by read syscalls.
+ CharsRead uint64
+
+ // CharsWritten is the number of bytes written by write syscalls.
+ CharsWritten uint64
+
+ // ReadSyscalls is the number of read syscalls.
+ ReadSyscalls uint64
+
+ // WriteSyscalls is the number of write syscalls.
+ WriteSyscalls uint64
+
+ // The following counter is only meaningful when Sentry has internal
+ // pagecache.
+
+ // BytesRead is the number of bytes actually read into pagecache.
+ BytesRead uint64
+
+ // BytesWritten is the number of bytes actually written from pagecache.
+ BytesWritten uint64
+
+ // BytesWriteCancelled is the number of bytes not written out due to
+ // truncation.
+ BytesWriteCancelled uint64
+}
+
+// AccountReadSyscall does the accounting for a read syscall.
+func (i *IO) AccountReadSyscall(bytes int64) {
+ atomic.AddUint64(&i.ReadSyscalls, 1)
+ if bytes > 0 {
+ atomic.AddUint64(&i.CharsRead, uint64(bytes))
+ }
+}
+
+// AccountWriteSyscall does the accounting for a write syscall.
+func (i *IO) AccountWriteSyscall(bytes int64) {
+ atomic.AddUint64(&i.WriteSyscalls, 1)
+ if bytes > 0 {
+ atomic.AddUint64(&i.CharsWritten, uint64(bytes))
+ }
+}
+
+// AccountReadIO does the accounting for a read IO into the file system.
+func (i *IO) AccountReadIO(bytes int64) {
+ if bytes > 0 {
+ atomic.AddUint64(&i.BytesRead, uint64(bytes))
+ }
+}
+
+// AccountWriteIO does the accounting for a write IO into the file system.
+func (i *IO) AccountWriteIO(bytes int64) {
+ if bytes > 0 {
+ atomic.AddUint64(&i.BytesWritten, uint64(bytes))
+ }
+}
+
+// Accumulate adds up io usages.
+func (i *IO) Accumulate(io *IO) {
+ atomic.AddUint64(&i.CharsRead, atomic.LoadUint64(&io.CharsRead))
+ atomic.AddUint64(&i.CharsWritten, atomic.LoadUint64(&io.CharsWritten))
+ atomic.AddUint64(&i.ReadSyscalls, atomic.LoadUint64(&io.ReadSyscalls))
+ atomic.AddUint64(&i.WriteSyscalls, atomic.LoadUint64(&io.WriteSyscalls))
+ atomic.AddUint64(&i.BytesRead, atomic.LoadUint64(&io.BytesRead))
+ atomic.AddUint64(&i.BytesWritten, atomic.LoadUint64(&io.BytesWritten))
+ atomic.AddUint64(&i.BytesWriteCancelled, atomic.LoadUint64(&io.BytesWriteCancelled))
+}
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
new file mode 100644
index 000000000..c316f1597
--- /dev/null
+++ b/pkg/sentry/usage/memory.go
@@ -0,0 +1,284 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usage
+
+import (
+ "fmt"
+ "os"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/bits"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/memutil"
+)
+
+// MemoryKind represents a type of memory used by the application.
+//
+// For efficiency reasons, it is assumed that the Memory implementation is
+// responsible for specific stats (documented below), and those may be reported
+// in aggregate independently. See the platform.Memory interface as well as the
+// control.Usage.Collect method for more information.
+type MemoryKind int
+
+const (
+ // System represents miscellaneous system memory. This may include
+ // memory that is in the process of being reclaimed, system caches,
+ // page tables, swap, etc.
+ //
+ // This memory kind is backed by platform memory.
+ System MemoryKind = iota
+
+ // Anonymous represents anonymous application memory.
+ //
+ // This memory kind is backed by platform memory.
+ Anonymous
+
+ // PageCache represents memory allocated to back sandbox-visible files that
+ // do not have a local fd. The contents of these files are buffered in
+ // memory to support application mmaps.
+ //
+ // This memory kind is backed by platform memory.
+ PageCache
+
+ // Tmpfs represents memory used by the sandbox-visible tmpfs.
+ //
+ // This memory kind is backed by platform memory.
+ Tmpfs
+
+ // Ramdiskfs represents memory used by the ramdiskfs.
+ //
+ // This memory kind is backed by platform memory.
+ Ramdiskfs
+
+ // Mapped represents memory related to files which have a local fd on the
+ // host, and thus can be directly mapped. Typically these are files backed
+ // by gofers with donated-fd support. Note that this value may not track the
+ // exact amount of memory used by mapping on the host, because we don't have
+ // any visibility into the host kernel memory management. In particular,
+ // once we map some part of a host file, the host kernel is free to
+ // abitrarily populate/decommit the pages, which it may do for various
+ // reasons (ex. host memory reclaim, NUMA balancing).
+ //
+ // This memory kind is backed by the host pagecache, via host mmaps.
+ Mapped
+)
+
+// MemoryStats tracks application memory usage in bytes. All fields correspond to the
+// memory category with the same name. This object is thread-safe if accessed
+// through the provided methods. The public fields may be safely accessed
+// directly on a copy of the object obtained from Memory.Copy().
+type MemoryStats struct {
+ System uint64
+ Anonymous uint64
+ PageCache uint64
+ Tmpfs uint64
+ // Lazily updated based on the value in RTMapped.
+ Mapped uint64
+ Ramdiskfs uint64
+}
+
+// RTMemoryStats contains the memory usage values that need to be directly
+// exposed through a shared memory file for real-time access. These are
+// categories not backed by platform memory. For details about how this works,
+// see the memory accounting docs.
+//
+// N.B. Please keep the struct in sync with the API. Notably, changes to this
+// struct requires a version bump and addition of compatibility logic in the
+// control server. As a special-case, adding fields without re-ordering existing
+// ones do not require a version bump because the mapped page we use is
+// initially zeroed. Any added field will be ignored by an older API and will be
+// zero if read by a newer API.
+type RTMemoryStats struct {
+ RTMapped uint64
+}
+
+// MemoryLocked is Memory with access methods.
+type MemoryLocked struct {
+ mu sync.RWMutex
+ // MemoryStats records the memory stats.
+ MemoryStats
+ // RTMemoryStats records the memory stats that need to be exposed through
+ // shared page.
+ *RTMemoryStats
+ // File is the backing file storing the memory stats.
+ File *os.File
+}
+
+// Init initializes global 'MemoryAccounting'.
+func Init() error {
+ const name = "memory-usage"
+ fd, err := memutil.CreateMemFD(name, 0)
+ if err != nil {
+ return fmt.Errorf("error creating usage file: %v", err)
+ }
+ file := os.NewFile(uintptr(fd), name)
+ if err := file.Truncate(int64(RTMemoryStatsSize)); err != nil {
+ return fmt.Errorf("error truncating usage file: %v", err)
+ }
+ // Note: We rely on the returned page being initially zeroed. This will
+ // always be the case for a newly mapped page from /dev/shm. If we obtain
+ // the shared memory through some other means in the future, we may have to
+ // explicitly zero the page.
+ mmap, err := syscall.Mmap(int(file.Fd()), 0, int(RTMemoryStatsSize), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
+ if err != nil {
+ return fmt.Errorf("error mapping usage file: %v", err)
+ }
+
+ MemoryAccounting = &MemoryLocked{
+ File: file,
+ RTMemoryStats: RTMemoryStatsPointer(mmap),
+ }
+ return nil
+}
+
+// MemoryAccounting is the global memory stats.
+//
+// There is no need to save or restore the global memory accounting object,
+// because individual frame kinds are saved and charged only when they become
+// resident.
+var MemoryAccounting *MemoryLocked
+
+func (m *MemoryLocked) incLocked(val uint64, kind MemoryKind) {
+ switch kind {
+ case System:
+ atomic.AddUint64(&m.System, val)
+ case Anonymous:
+ atomic.AddUint64(&m.Anonymous, val)
+ case PageCache:
+ atomic.AddUint64(&m.PageCache, val)
+ case Mapped:
+ atomic.AddUint64(&m.RTMapped, val)
+ case Tmpfs:
+ atomic.AddUint64(&m.Tmpfs, val)
+ case Ramdiskfs:
+ atomic.AddUint64(&m.Ramdiskfs, val)
+ default:
+ panic(fmt.Sprintf("invalid memory kind: %v", kind))
+ }
+}
+
+// Inc adds an additional usage of 'val' bytes to memory category 'kind'.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Inc(val uint64, kind MemoryKind) {
+ m.mu.RLock()
+ m.incLocked(val, kind)
+ m.mu.RUnlock()
+}
+
+func (m *MemoryLocked) decLocked(val uint64, kind MemoryKind) {
+ switch kind {
+ case System:
+ atomic.AddUint64(&m.System, ^(val - 1))
+ case Anonymous:
+ atomic.AddUint64(&m.Anonymous, ^(val - 1))
+ case PageCache:
+ atomic.AddUint64(&m.PageCache, ^(val - 1))
+ case Mapped:
+ atomic.AddUint64(&m.RTMapped, ^(val - 1))
+ case Tmpfs:
+ atomic.AddUint64(&m.Tmpfs, ^(val - 1))
+ case Ramdiskfs:
+ atomic.AddUint64(&m.Ramdiskfs, ^(val - 1))
+ default:
+ panic(fmt.Sprintf("invalid memory kind: %v", kind))
+ }
+}
+
+// Dec remove a usage of 'val' bytes from memory category 'kind'.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Dec(val uint64, kind MemoryKind) {
+ m.mu.RLock()
+ m.decLocked(val, kind)
+ m.mu.RUnlock()
+}
+
+// Move moves a usage of 'val' bytes from 'from' to 'to'.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Move(val uint64, to MemoryKind, from MemoryKind) {
+ m.mu.RLock()
+ // Just call decLocked and incLocked directly. We held the RLock to
+ // protect against concurrent callers to Total().
+ m.decLocked(val, from)
+ m.incLocked(val, to)
+ m.mu.RUnlock()
+}
+
+// totalLocked returns a total usage.
+//
+// Precondition: must be called when locked.
+func (m *MemoryLocked) totalLocked() (total uint64) {
+ total += atomic.LoadUint64(&m.System)
+ total += atomic.LoadUint64(&m.Anonymous)
+ total += atomic.LoadUint64(&m.PageCache)
+ total += atomic.LoadUint64(&m.RTMapped)
+ total += atomic.LoadUint64(&m.Tmpfs)
+ total += atomic.LoadUint64(&m.Ramdiskfs)
+ return
+}
+
+// Total returns a total memory usage.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Total() uint64 {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.totalLocked()
+}
+
+// Copy returns a copy of the structure with a total.
+//
+// This method is thread-safe.
+func (m *MemoryLocked) Copy() (MemoryStats, uint64) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ ms := m.MemoryStats
+ ms.Mapped = m.RTMapped
+ return ms, m.totalLocked()
+}
+
+// MinimumTotalMemoryBytes is the minimum reported total system memory.
+var MinimumTotalMemoryBytes uint64 = 2 << 30 // 2 GB
+
+// TotalMemory returns the "total usable memory" available.
+//
+// This number doesn't really have a true value so it's based on the following
+// inputs and further bounded to be above some minimum guaranteed value (2GB),
+// additionally ensuring that total memory reported is always less than used.
+//
+// memSize should be the platform.Memory size reported by platform.Memory.TotalSize()
+// used is the total memory reported by MemoryLocked.Total()
+func TotalMemory(memSize, used uint64) uint64 {
+ if memSize < MinimumTotalMemoryBytes {
+ memSize = MinimumTotalMemoryBytes
+ }
+ if memSize < used {
+ memSize = used
+ // Bump totalSize to the next largest power of 2, if one exists, so
+ // that MemFree isn't 0.
+ if msb := bits.MostSignificantOne64(memSize); msb < 63 {
+ memSize = uint64(1) << (uint(msb) + 1)
+ }
+ }
+ return memSize
+}
+
+// IncrementalMappedAccounting controls whether host mapped memory is accounted
+// incrementally during map translation. This may be modified during early
+// initialization, and is read-only afterward.
+var IncrementalMappedAccounting = false
diff --git a/pkg/sentry/usage/memory_unsafe.go b/pkg/sentry/usage/memory_unsafe.go
new file mode 100644
index 000000000..9e0014ca0
--- /dev/null
+++ b/pkg/sentry/usage/memory_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usage
+
+import (
+ "unsafe"
+)
+
+// RTMemoryStatsSize is the size of the RTMemoryStats struct.
+var RTMemoryStatsSize = unsafe.Sizeof(RTMemoryStats{})
+
+// RTMemoryStatsPointer casts the address of the byte slice into a RTMemoryStats pointer.
+func RTMemoryStatsPointer(b []byte) *RTMemoryStats {
+ return (*RTMemoryStats)(unsafe.Pointer(&b[0]))
+}
diff --git a/pkg/sentry/usage/usage.go b/pkg/sentry/usage/usage.go
new file mode 100644
index 000000000..e3d33a965
--- /dev/null
+++ b/pkg/sentry/usage/usage.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package usage provides representations of resource usage.
+package usage
diff --git a/pkg/sentry/usage/usage_state_autogen.go b/pkg/sentry/usage/usage_state_autogen.go
new file mode 100755
index 000000000..38411db2e
--- /dev/null
+++ b/pkg/sentry/usage/usage_state_autogen.go
@@ -0,0 +1,50 @@
+// automatically generated by stateify.
+
+package usage
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *CPUStats) beforeSave() {}
+func (x *CPUStats) save(m state.Map) {
+ x.beforeSave()
+ m.Save("UserTime", &x.UserTime)
+ m.Save("SysTime", &x.SysTime)
+ m.Save("VoluntarySwitches", &x.VoluntarySwitches)
+}
+
+func (x *CPUStats) afterLoad() {}
+func (x *CPUStats) load(m state.Map) {
+ m.Load("UserTime", &x.UserTime)
+ m.Load("SysTime", &x.SysTime)
+ m.Load("VoluntarySwitches", &x.VoluntarySwitches)
+}
+
+func (x *IO) beforeSave() {}
+func (x *IO) save(m state.Map) {
+ x.beforeSave()
+ m.Save("CharsRead", &x.CharsRead)
+ m.Save("CharsWritten", &x.CharsWritten)
+ m.Save("ReadSyscalls", &x.ReadSyscalls)
+ m.Save("WriteSyscalls", &x.WriteSyscalls)
+ m.Save("BytesRead", &x.BytesRead)
+ m.Save("BytesWritten", &x.BytesWritten)
+ m.Save("BytesWriteCancelled", &x.BytesWriteCancelled)
+}
+
+func (x *IO) afterLoad() {}
+func (x *IO) load(m state.Map) {
+ m.Load("CharsRead", &x.CharsRead)
+ m.Load("CharsWritten", &x.CharsWritten)
+ m.Load("ReadSyscalls", &x.ReadSyscalls)
+ m.Load("WriteSyscalls", &x.WriteSyscalls)
+ m.Load("BytesRead", &x.BytesRead)
+ m.Load("BytesWritten", &x.BytesWritten)
+ m.Load("BytesWriteCancelled", &x.BytesWriteCancelled)
+}
+
+func init() {
+ state.Register("usage.CPUStats", (*CPUStats)(nil), state.Fns{Save: (*CPUStats).save, Load: (*CPUStats).load})
+ state.Register("usage.IO", (*IO)(nil), state.Fns{Save: (*IO).save, Load: (*IO).load})
+}
diff --git a/pkg/sentry/usermem/access_type.go b/pkg/sentry/usermem/access_type.go
new file mode 100644
index 000000000..9c1742a59
--- /dev/null
+++ b/pkg/sentry/usermem/access_type.go
@@ -0,0 +1,128 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "syscall"
+)
+
+// AccessType specifies memory access types. This is used for
+// setting mapping permissions, as well as communicating faults.
+//
+// +stateify savable
+type AccessType struct {
+ // Read is read access.
+ Read bool
+
+ // Write is write access.
+ Write bool
+
+ // Execute is executable access.
+ Execute bool
+}
+
+// String returns a pretty representation of access. This looks like the
+// familiar r-x, rw-, etc. and can be relied on as such.
+func (a AccessType) String() string {
+ bits := [3]byte{'-', '-', '-'}
+ if a.Read {
+ bits[0] = 'r'
+ }
+ if a.Write {
+ bits[1] = 'w'
+ }
+ if a.Execute {
+ bits[2] = 'x'
+ }
+ return string(bits[:])
+}
+
+// Any returns true iff at least one of Read, Write or Execute is true.
+func (a AccessType) Any() bool {
+ return a.Read || a.Write || a.Execute
+}
+
+// Prot returns the system prot (syscall.PROT_READ, etc.) for this access.
+func (a AccessType) Prot() int {
+ var prot int
+ if a.Read {
+ prot |= syscall.PROT_READ
+ }
+ if a.Write {
+ prot |= syscall.PROT_WRITE
+ }
+ if a.Execute {
+ prot |= syscall.PROT_EXEC
+ }
+ return prot
+}
+
+// SupersetOf returns true iff the access types in a are a superset of the
+// access types in other.
+func (a AccessType) SupersetOf(other AccessType) bool {
+ if !a.Read && other.Read {
+ return false
+ }
+ if !a.Write && other.Write {
+ return false
+ }
+ if !a.Execute && other.Execute {
+ return false
+ }
+ return true
+}
+
+// Intersect returns the access types set in both a and other.
+func (a AccessType) Intersect(other AccessType) AccessType {
+ return AccessType{
+ Read: a.Read && other.Read,
+ Write: a.Write && other.Write,
+ Execute: a.Execute && other.Execute,
+ }
+}
+
+// Union returns the access types set in either a or other.
+func (a AccessType) Union(other AccessType) AccessType {
+ return AccessType{
+ Read: a.Read || other.Read,
+ Write: a.Write || other.Write,
+ Execute: a.Execute || other.Execute,
+ }
+}
+
+// Effective returns the set of effective access types allowed by a, even if
+// some types are not explicitly allowed.
+func (a AccessType) Effective() AccessType {
+ // In Linux, Write and Execute access generally imply Read access. See
+ // mm/mmap.c:protection_map.
+ //
+ // The notable exception is get_user_pages, which only checks against
+ // the original vma flags. That said, most user memory accesses do not
+ // use GUP.
+ if a.Write || a.Execute {
+ a.Read = true
+ }
+ return a
+}
+
+// Convenient access types.
+var (
+ NoAccess = AccessType{}
+ Read = AccessType{Read: true}
+ Write = AccessType{Write: true}
+ Execute = AccessType{Execute: true}
+ ReadWrite = AccessType{Read: true, Write: true}
+ AnyAccess = AccessType{Read: true, Write: true, Execute: true}
+)
diff --git a/pkg/sentry/usermem/addr.go b/pkg/sentry/usermem/addr.go
new file mode 100644
index 000000000..e79210804
--- /dev/null
+++ b/pkg/sentry/usermem/addr.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "fmt"
+)
+
+// Addr represents a generic virtual address.
+//
+// +stateify savable
+type Addr uintptr
+
+// AddLength adds the given length to start and returns the result. ok is true
+// iff adding the length did not overflow the range of Addr.
+//
+// Note: This function is usually used to get the end of an address range
+// defined by its start address and length. Since the resulting end is
+// exclusive, end == 0 is technically valid, and corresponds to a range that
+// extends to the end of the address space, but ok will be false. This isn't
+// expected to ever come up in practice.
+func (v Addr) AddLength(length uint64) (end Addr, ok bool) {
+ end = v + Addr(length)
+ // The second half of the following check is needed in case uintptr is
+ // smaller than 64 bits.
+ ok = end >= v && length <= uint64(^Addr(0))
+ return
+}
+
+// RoundDown returns the address rounded down to the nearest page boundary.
+func (v Addr) RoundDown() Addr {
+ return v & ^Addr(PageSize-1)
+}
+
+// RoundUp returns the address rounded up to the nearest page boundary. ok is
+// true iff rounding up did not wrap around.
+func (v Addr) RoundUp() (addr Addr, ok bool) {
+ addr = Addr(v + PageSize - 1).RoundDown()
+ ok = addr >= v
+ return
+}
+
+// MustRoundUp is equivalent to RoundUp, but panics if rounding up wraps
+// around.
+func (v Addr) MustRoundUp() Addr {
+ addr, ok := v.RoundUp()
+ if !ok {
+ panic(fmt.Sprintf("usermem.Addr(%d).RoundUp() wraps", v))
+ }
+ return addr
+}
+
+// HugeRoundDown returns the address rounded down to the nearest huge page
+// boundary.
+func (v Addr) HugeRoundDown() Addr {
+ return v & ^Addr(HugePageSize-1)
+}
+
+// HugeRoundUp returns the address rounded up to the nearest huge page boundary.
+// ok is true iff rounding up did not wrap around.
+func (v Addr) HugeRoundUp() (addr Addr, ok bool) {
+ addr = Addr(v + HugePageSize - 1).HugeRoundDown()
+ ok = addr >= v
+ return
+}
+
+// PageOffset returns the offset of v into the current page.
+func (v Addr) PageOffset() uint64 {
+ return uint64(v & Addr(PageSize-1))
+}
+
+// IsPageAligned returns true if v.PageOffset() == 0.
+func (v Addr) IsPageAligned() bool {
+ return v.PageOffset() == 0
+}
+
+// AddrRange is a range of Addrs.
+//
+// type AddrRange <generated by go_generics>
+
+// ToRange returns [v, v+length).
+func (v Addr) ToRange(length uint64) (AddrRange, bool) {
+ end, ok := v.AddLength(length)
+ return AddrRange{v, end}, ok
+}
+
+// IsPageAligned returns true if ar.Start.IsPageAligned() and
+// ar.End.IsPageAligned().
+func (ar AddrRange) IsPageAligned() bool {
+ return ar.Start.IsPageAligned() && ar.End.IsPageAligned()
+}
+
+// String implements fmt.Stringer.String.
+func (ar AddrRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", ar.Start, ar.End)
+}
diff --git a/pkg/sentry/usermem/addr_range.go b/pkg/sentry/usermem/addr_range.go
new file mode 100755
index 000000000..152ed1434
--- /dev/null
+++ b/pkg/sentry/usermem/addr_range.go
@@ -0,0 +1,62 @@
+package usermem
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type AddrRange struct {
+ // Start is the inclusive start of the range.
+ Start Addr
+
+ // End is the exclusive end of the range.
+ End Addr
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r AddrRange) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r AddrRange) Length() Addr {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r AddrRange) Contains(x Addr) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r AddrRange) Overlaps(r2 AddrRange) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r AddrRange) IsSupersetOf(r2 AddrRange) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+func (r AddrRange) Intersect(r2 AddrRange) AddrRange {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+func (r AddrRange) CanSplitAt(x Addr) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/sentry/usermem/addr_range_seq_unsafe.go b/pkg/sentry/usermem/addr_range_seq_unsafe.go
new file mode 100644
index 000000000..c09337c15
--- /dev/null
+++ b/pkg/sentry/usermem/addr_range_seq_unsafe.go
@@ -0,0 +1,277 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "unsafe"
+)
+
+// An AddrRangeSeq represents a sequence of AddrRanges.
+//
+// AddrRangeSeqs are immutable and may be copied by value. The zero value of
+// AddrRangeSeq represents an empty sequence.
+//
+// An AddrRangeSeq may contain AddrRanges with a length of 0. This is necessary
+// since zero-length AddrRanges are significant to MM bounds checks.
+type AddrRangeSeq struct {
+ // If length is 0, then the AddrRangeSeq represents no AddrRanges.
+ // Invariants: data == 0; offset == 0; limit == 0.
+ //
+ // If length is 1, then the AddrRangeSeq represents the single
+ // AddrRange{offset, offset+limit}. Invariants: data == 0.
+ //
+ // Otherwise, length >= 2, and the AddrRangeSeq represents the `length`
+ // AddrRanges in the array of AddrRanges starting at address `data`,
+ // starting at `offset` bytes into the first AddrRange and limited to the
+ // following `limit` bytes. (AddrRanges after `limit` are still iterated,
+ // but are truncated to a length of 0.) Invariants: data != 0; offset <=
+ // data[0].Length(); limit > 0; offset+limit <= the combined length of all
+ // AddrRanges in the array.
+ data unsafe.Pointer
+ length int
+ offset Addr
+ limit Addr
+}
+
+// AddrRangeSeqOf returns an AddrRangeSeq representing the single AddrRange ar.
+func AddrRangeSeqOf(ar AddrRange) AddrRangeSeq {
+ return AddrRangeSeq{
+ length: 1,
+ offset: ar.Start,
+ limit: ar.Length(),
+ }
+}
+
+// AddrRangeSeqFromSlice returns an AddrRangeSeq representing all AddrRanges in
+// slice.
+//
+// Whether the returned AddrRangeSeq shares memory with slice is unspecified;
+// clients should avoid mutating slices passed to AddrRangeSeqFromSlice.
+//
+// Preconditions: The combined length of all AddrRanges in slice <=
+// math.MaxInt64.
+func AddrRangeSeqFromSlice(slice []AddrRange) AddrRangeSeq {
+ var limit int64
+ for _, ar := range slice {
+ len64 := int64(ar.Length())
+ if len64 < 0 {
+ panic(fmt.Sprintf("Length of AddrRange %v overflows int64", ar))
+ }
+ sum := limit + len64
+ if sum < limit {
+ panic(fmt.Sprintf("Total length of AddrRanges %v overflows int64", slice))
+ }
+ limit = sum
+ }
+ return addrRangeSeqFromSliceLimited(slice, limit)
+}
+
+// Preconditions: The combined length of all AddrRanges in slice <= limit.
+// limit >= 0. If len(slice) != 0, then limit > 0.
+func addrRangeSeqFromSliceLimited(slice []AddrRange, limit int64) AddrRangeSeq {
+ switch len(slice) {
+ case 0:
+ return AddrRangeSeq{}
+ case 1:
+ return AddrRangeSeq{
+ length: 1,
+ offset: slice[0].Start,
+ limit: Addr(limit),
+ }
+ default:
+ return AddrRangeSeq{
+ data: unsafe.Pointer(&slice[0]),
+ length: len(slice),
+ limit: Addr(limit),
+ }
+ }
+}
+
+// IsEmpty returns true if ars.NumRanges() == 0.
+//
+// Note that since AddrRangeSeq may contain AddrRanges with a length of zero,
+// an AddrRange representing 0 bytes (AddrRangeSeq.NumBytes() == 0) is not
+// necessarily empty.
+func (ars AddrRangeSeq) IsEmpty() bool {
+ return ars.length == 0
+}
+
+// NumRanges returns the number of AddrRanges in ars.
+func (ars AddrRangeSeq) NumRanges() int {
+ return ars.length
+}
+
+// NumBytes returns the number of bytes represented by ars.
+func (ars AddrRangeSeq) NumBytes() int64 {
+ return int64(ars.limit)
+}
+
+// Head returns the first AddrRange in ars.
+//
+// Preconditions: !ars.IsEmpty().
+func (ars AddrRangeSeq) Head() AddrRange {
+ if ars.length == 0 {
+ panic("empty AddrRangeSeq")
+ }
+ if ars.length == 1 {
+ return AddrRange{ars.offset, ars.offset + ars.limit}
+ }
+ ar := *(*AddrRange)(ars.data)
+ ar.Start += ars.offset
+ if ar.Length() > ars.limit {
+ ar.End = ar.Start + ars.limit
+ }
+ return ar
+}
+
+// Tail returns an AddrRangeSeq consisting of all AddrRanges in ars after the
+// first.
+//
+// Preconditions: !ars.IsEmpty().
+func (ars AddrRangeSeq) Tail() AddrRangeSeq {
+ if ars.length == 0 {
+ panic("empty AddrRangeSeq")
+ }
+ if ars.length == 1 {
+ return AddrRangeSeq{}
+ }
+ return ars.externalTail()
+}
+
+// Preconditions: ars.length >= 2.
+func (ars AddrRangeSeq) externalTail() AddrRangeSeq {
+ headLen := (*AddrRange)(ars.data).Length() - ars.offset
+ var tailLimit int64
+ if ars.limit > headLen {
+ tailLimit = int64(ars.limit - headLen)
+ }
+ var extSlice []AddrRange
+ extSliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&extSlice))
+ extSliceHdr.Data = uintptr(ars.data)
+ extSliceHdr.Len = ars.length
+ extSliceHdr.Cap = ars.length
+ return addrRangeSeqFromSliceLimited(extSlice[1:], tailLimit)
+}
+
+// DropFirst returns an AddrRangeSeq equivalent to ars, but with the first n
+// bytes omitted. If n > ars.NumBytes(), DropFirst returns an empty
+// AddrRangeSeq.
+//
+// If !ars.IsEmpty() and ars.Head().Length() == 0, DropFirst will always omit
+// at least ars.Head(), even if n == 0. This guarantees that the basic pattern
+// of:
+//
+// for !ars.IsEmpty() {
+// n, err = doIOWith(ars.Head())
+// if err != nil {
+// return err
+// }
+// ars = ars.DropFirst(n)
+// }
+//
+// works even in the presence of zero-length AddrRanges.
+//
+// Preconditions: n >= 0.
+func (ars AddrRangeSeq) DropFirst(n int) AddrRangeSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return ars.DropFirst64(int64(n))
+}
+
+// DropFirst64 is equivalent to DropFirst but takes an int64.
+func (ars AddrRangeSeq) DropFirst64(n int64) AddrRangeSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ if Addr(n) > ars.limit {
+ return AddrRangeSeq{}
+ }
+ // Handle initial empty AddrRange.
+ switch ars.length {
+ case 0:
+ return AddrRangeSeq{}
+ case 1:
+ if ars.limit == 0 {
+ return AddrRangeSeq{}
+ }
+ default:
+ if rawHeadLen := (*AddrRange)(ars.data).Length(); ars.offset == rawHeadLen {
+ ars = ars.externalTail()
+ }
+ }
+ for n != 0 {
+ // Calling ars.Head() here is surprisingly expensive, so inline getting
+ // the head's length.
+ var headLen Addr
+ if ars.length == 1 {
+ headLen = ars.limit
+ } else {
+ headLen = (*AddrRange)(ars.data).Length() - ars.offset
+ }
+ if Addr(n) < headLen {
+ // Dropping ends partway through the head AddrRange.
+ ars.offset += Addr(n)
+ ars.limit -= Addr(n)
+ return ars
+ }
+ n -= int64(headLen)
+ ars = ars.Tail()
+ }
+ return ars
+}
+
+// TakeFirst returns an AddrRangeSeq equivalent to ars, but iterating at most n
+// bytes. TakeFirst never removes AddrRanges from ars; AddrRanges beyond the
+// first n bytes are reduced to a length of zero, but will still be iterated.
+//
+// Preconditions: n >= 0.
+func (ars AddrRangeSeq) TakeFirst(n int) AddrRangeSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ return ars.TakeFirst64(int64(n))
+}
+
+// TakeFirst64 is equivalent to TakeFirst but takes an int64.
+func (ars AddrRangeSeq) TakeFirst64(n int64) AddrRangeSeq {
+ if n < 0 {
+ panic(fmt.Sprintf("invalid n: %d", n))
+ }
+ if ars.limit > Addr(n) {
+ ars.limit = Addr(n)
+ }
+ return ars
+}
+
+// String implements fmt.Stringer.String.
+func (ars AddrRangeSeq) String() string {
+ // This is deliberately chosen to be the same as fmt's automatic stringer
+ // for []AddrRange.
+ var buf bytes.Buffer
+ buf.WriteByte('[')
+ var sep string
+ for !ars.IsEmpty() {
+ buf.WriteString(sep)
+ sep = " "
+ buf.WriteString(ars.Head().String())
+ ars = ars.Tail()
+ }
+ buf.WriteByte(']')
+ return buf.String()
+}
diff --git a/pkg/sentry/usermem/bytes_io.go b/pkg/sentry/usermem/bytes_io.go
new file mode 100644
index 000000000..f98d82168
--- /dev/null
+++ b/pkg/sentry/usermem/bytes_io.go
@@ -0,0 +1,126 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+const maxInt = int(^uint(0) >> 1)
+
+// BytesIO implements IO using a byte slice. Addresses are interpreted as
+// offsets into the slice. Reads and writes beyond the end of the slice return
+// EFAULT.
+type BytesIO struct {
+ Bytes []byte
+}
+
+// CopyOut implements IO.CopyOut.
+func (b *BytesIO) CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error) {
+ rngN, rngErr := b.rangeCheck(addr, len(src))
+ if rngN == 0 {
+ return 0, rngErr
+ }
+ return copy(b.Bytes[int(addr):], src[:rngN]), rngErr
+}
+
+// CopyIn implements IO.CopyIn.
+func (b *BytesIO) CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error) {
+ rngN, rngErr := b.rangeCheck(addr, len(dst))
+ if rngN == 0 {
+ return 0, rngErr
+ }
+ return copy(dst[:rngN], b.Bytes[int(addr):]), rngErr
+}
+
+// ZeroOut implements IO.ZeroOut.
+func (b *BytesIO) ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) {
+ if toZero > int64(maxInt) {
+ return 0, syserror.EINVAL
+ }
+ rngN, rngErr := b.rangeCheck(addr, int(toZero))
+ if rngN == 0 {
+ return 0, rngErr
+ }
+ zeroSlice := b.Bytes[int(addr) : int(addr)+rngN]
+ for i := range zeroSlice {
+ zeroSlice[i] = 0
+ }
+ return int64(rngN), rngErr
+}
+
+// CopyOutFrom implements IO.CopyOutFrom.
+func (b *BytesIO) CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) {
+ dsts, rngErr := b.blocksFromAddrRanges(ars)
+ n, err := src.ReadToBlocks(dsts)
+ if err != nil {
+ return int64(n), err
+ }
+ return int64(n), rngErr
+}
+
+// CopyInTo implements IO.CopyInTo.
+func (b *BytesIO) CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) {
+ srcs, rngErr := b.blocksFromAddrRanges(ars)
+ n, err := dst.WriteFromBlocks(srcs)
+ if err != nil {
+ return int64(n), err
+ }
+ return int64(n), rngErr
+}
+
+func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) {
+ if length == 0 {
+ return 0, nil
+ }
+ if length < 0 {
+ return 0, syserror.EINVAL
+ }
+ max := Addr(len(b.Bytes))
+ if addr >= max {
+ return 0, syserror.EFAULT
+ }
+ end, ok := addr.AddLength(uint64(length))
+ if !ok || end > max {
+ return int(max - addr), syserror.EFAULT
+ }
+ return length, nil
+}
+
+func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) {
+ blocks := make([]safemem.Block, 0, ars.NumRanges())
+ for !ars.IsEmpty() {
+ ar := ars.Head()
+ n, err := b.rangeCheck(ar.Start, int(ar.Length()))
+ if n != 0 {
+ blocks = append(blocks, safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start):int(ar.Start)+n]))
+ }
+ if err != nil {
+ return safemem.BlockSeqFromSlice(blocks), err
+ }
+ ars = ars.Tail()
+ }
+ return safemem.BlockSeqFromSlice(blocks), nil
+}
+
+// BytesIOSequence returns an IOSequence representing the given byte slice.
+func BytesIOSequence(buf []byte) IOSequence {
+ return IOSequence{
+ IO: &BytesIO{buf},
+ Addrs: AddrRangeSeqOf(AddrRange{0, Addr(len(buf))}),
+ }
+}
diff --git a/pkg/sentry/usermem/bytes_io_unsafe.go b/pkg/sentry/usermem/bytes_io_unsafe.go
new file mode 100644
index 000000000..bb49d2ff3
--- /dev/null
+++ b/pkg/sentry/usermem/bytes_io_unsafe.go
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/atomicbitops"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+)
+
+// SwapUint32 implements IO.SwapUint32.
+func (b *BytesIO) SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) {
+ if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil {
+ return 0, rngErr
+ }
+ return atomic.SwapUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)])), new), nil
+}
+
+// CompareAndSwapUint32 implements IO.CompareAndSwapUint32.
+func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) {
+ if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil {
+ return 0, rngErr
+ }
+ return atomicbitops.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)])), old, new), nil
+}
+
+// LoadUint32 implements IO.LoadUint32.
+func (b *BytesIO) LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) {
+ if _, err := b.rangeCheck(addr, 4); err != nil {
+ return 0, err
+ }
+ return atomic.LoadUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)]))), nil
+}
diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go
new file mode 100644
index 000000000..31e4d6ada
--- /dev/null
+++ b/pkg/sentry/usermem/usermem.go
@@ -0,0 +1,587 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package usermem governs access to user memory.
+package usermem
+
+import (
+ "errors"
+ "io"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// IO provides access to the contents of a virtual memory space.
+//
+// FIXME(b/38173783): Implementations of IO cannot expect ctx to contain any
+// meaningful data.
+type IO interface {
+ // CopyOut copies len(src) bytes from src to the memory mapped at addr. It
+ // returns the number of bytes copied. If the number of bytes copied is <
+ // len(src), it returns a non-nil error explaining why.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order.
+ //
+ // Postconditions: CopyOut does not retain src.
+ CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error)
+
+ // CopyIn copies len(dst) bytes from the memory mapped at addr to dst.
+ // It returns the number of bytes copied. If the number of bytes copied is
+ // < len(dst), it returns a non-nil error explaining why.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order.
+ //
+ // Postconditions: CopyIn does not retain dst.
+ CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error)
+
+ // ZeroOut sets toZero bytes to 0, starting at addr. It returns the number
+ // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a
+ // non-nil error explaining why.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. toZero >= 0.
+ ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error)
+
+ // CopyOutFrom copies ars.NumBytes() bytes from src to the memory mapped at
+ // ars. It returns the number of bytes copied, which may be less than the
+ // number of bytes read from src if copying fails. CopyOutFrom may return a
+ // partial copy without an error iff src.ReadToBlocks returns a partial
+ // read without an error.
+ //
+ // CopyOutFrom calls src.ReadToBlocks at most once.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. src.ReadToBlocks must not block
+ // on mm.MemoryManager.activeMu or any preceding locks in the lock order.
+ CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error)
+
+ // CopyInTo copies ars.NumBytes() bytes from the memory mapped at ars to
+ // dst. It returns the number of bytes copied. CopyInTo may return a
+ // partial copy without an error iff dst.WriteFromBlocks returns a partial
+ // write without an error.
+ //
+ // CopyInTo calls dst.WriteFromBlocks at most once.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. dst.WriteFromBlocks must not
+ // block on mm.MemoryManager.activeMu or any preceding locks in the lock
+ // order.
+ CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error)
+
+ // TODO(jamieliu): The requirement that CopyOutFrom/CopyInTo call src/dst
+ // at most once, which is unnecessary in most cases, forces implementations
+ // to gather safemem.Blocks into a single slice to pass to src/dst. Add
+ // CopyOutFromIter/CopyInToIter, which relaxes this restriction, to avoid
+ // this allocation.
+
+ // SwapUint32 atomically sets the uint32 value at addr to new and
+ // returns the previous value.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. addr must be aligned to a 4-byte
+ // boundary.
+ SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error)
+
+ // CompareAndSwapUint32 atomically compares the uint32 value at addr to
+ // old; if they are equal, the value in memory is replaced by new. In
+ // either case, the previous value stored in memory is returned.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. addr must be aligned to a 4-byte
+ // boundary.
+ CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error)
+
+ // LoadUint32 atomically loads the uint32 value at addr and returns it.
+ //
+ // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or
+ // any following locks in the lock order. addr must be aligned to a 4-byte
+ // boundary.
+ LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error)
+}
+
+// IOOpts contains options applicable to all IO methods.
+type IOOpts struct {
+ // If IgnorePermissions is true, application-defined memory protections set
+ // by mmap(2) or mprotect(2) will be ignored. (Memory protections required
+ // by the target of the mapping are never ignored.)
+ IgnorePermissions bool
+
+ // If AddressSpaceActive is true, the IO implementation may assume that it
+ // has an active AddressSpace and can therefore use AddressSpace copying
+ // without performing activation. See mm/io.go for details.
+ AddressSpaceActive bool
+}
+
+// IOReadWriter is an io.ReadWriter that reads from / writes to addresses
+// starting at addr in IO. The preconditions that apply to IO.CopyIn and
+// IO.CopyOut also apply to IOReadWriter.Read and IOReadWriter.Write
+// respectively.
+type IOReadWriter struct {
+ Ctx context.Context
+ IO IO
+ Addr Addr
+ Opts IOOpts
+}
+
+// Read implements io.Reader.Read.
+//
+// Note that an address space does not have an "end of file", so Read can only
+// return io.EOF if IO.CopyIn returns io.EOF. Attempts to read unmapped or
+// unreadable memory, or beyond the end of the address space, should return
+// EFAULT.
+func (rw *IOReadWriter) Read(dst []byte) (int, error) {
+ n, err := rw.IO.CopyIn(rw.Ctx, rw.Addr, dst, rw.Opts)
+ end, ok := rw.Addr.AddLength(uint64(n))
+ if ok {
+ rw.Addr = end
+ } else {
+ // Disallow wraparound.
+ rw.Addr = ^Addr(0)
+ if err != nil {
+ err = syserror.EFAULT
+ }
+ }
+ return n, err
+}
+
+// Writer implements io.Writer.Write.
+func (rw *IOReadWriter) Write(src []byte) (int, error) {
+ n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts)
+ end, ok := rw.Addr.AddLength(uint64(n))
+ if ok {
+ rw.Addr = end
+ } else {
+ // Disallow wraparound.
+ rw.Addr = ^Addr(0)
+ if err != nil {
+ err = syserror.EFAULT
+ }
+ }
+ return n, err
+}
+
+// CopyObjectOut copies a fixed-size value or slice of fixed-size values from
+// src to the memory mapped at addr in uio. It returns the number of bytes
+// copied.
+//
+// CopyObjectOut must use reflection to encode src; performance-sensitive
+// clients should do encoding manually and use uio.CopyOut directly.
+//
+// Preconditions: As for IO.CopyOut.
+func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) {
+ w := &IOReadWriter{
+ Ctx: ctx,
+ IO: uio,
+ Addr: addr,
+ Opts: opts,
+ }
+ // Allocate a byte slice the size of the object being marshaled. This
+ // adds an extra reflection call, but avoids needing to grow the slice
+ // during encoding, which can result in many heap-allocated slices.
+ b := make([]byte, 0, binary.Size(src))
+ return w.Write(binary.Marshal(b, ByteOrder, src))
+}
+
+// CopyObjectIn copies a fixed-size value or slice of fixed-size values from
+// the memory mapped at addr in uio to dst. It returns the number of bytes
+// copied.
+//
+// CopyObjectIn must use reflection to decode dst; performance-sensitive
+// clients should use uio.CopyIn directly and do decoding manually.
+//
+// Preconditions: As for IO.CopyIn.
+func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) {
+ r := &IOReadWriter{
+ Ctx: ctx,
+ IO: uio,
+ Addr: addr,
+ Opts: opts,
+ }
+ buf := make([]byte, binary.Size(dst))
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return 0, err
+ }
+ binary.Unmarshal(buf, ByteOrder, dst)
+ return int(r.Addr - addr), nil
+}
+
+// copyStringIncrement is the maximum number of bytes that are copied from
+// virtual memory at a time by CopyStringIn.
+const copyStringIncrement = 64
+
+// CopyStringIn copies a NUL-terminated string of unknown length from the
+// memory mapped at addr in uio and returns it as a string (not including the
+// trailing NUL). If the length of the string, including the terminating NUL,
+// would exceed maxlen, CopyStringIn returns the string truncated to maxlen and
+// ENAMETOOLONG.
+//
+// Preconditions: As for IO.CopyFromUser. maxlen >= 0.
+func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) {
+ buf := make([]byte, maxlen)
+ var done int
+ for done < maxlen {
+ start, ok := addr.AddLength(uint64(done))
+ if !ok {
+ // Last page of kernel memory. The application can't use this
+ // anyway.
+ return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
+ }
+ // Read up to copyStringIncrement bytes at a time.
+ readlen := copyStringIncrement
+ if readlen > maxlen-done {
+ readlen = maxlen - done
+ }
+ end, ok := start.AddLength(uint64(readlen))
+ if !ok {
+ return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
+ }
+ // Shorten the read to avoid crossing page boundaries, since faulting
+ // in a page unnecessarily is expensive. This also ensures that partial
+ // copies up to the end of application-mappable memory succeed.
+ if start.RoundDown() != end.RoundDown() {
+ end = end.RoundDown()
+ }
+ n, err := uio.CopyIn(ctx, start, buf[done:done+int(end-start)], opts)
+ // Look for the terminating zero byte, which may have occurred before
+ // hitting err.
+ for i, c := range buf[done : done+n] {
+ if c == 0 {
+ return stringFromImmutableBytes(buf[:done+i]), nil
+ }
+ }
+ done += n
+ if err != nil {
+ return stringFromImmutableBytes(buf[:done]), err
+ }
+ }
+ return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG
+}
+
+// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The
+// maximum number of bytes copied is ars.NumBytes() or len(src), whichever is
+// less. CopyOutVec returns the number of bytes copied; if this is less than
+// the maximum, it returns a non-nil error explaining why.
+//
+// Preconditions: As for IO.CopyOut.
+func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts IOOpts) (int, error) {
+ var done int
+ for !ars.IsEmpty() && done < len(src) {
+ ar := ars.Head()
+ cplen := len(src) - done
+ if Addr(cplen) >= ar.Length() {
+ cplen = int(ar.Length())
+ }
+ n, err := uio.CopyOut(ctx, ar.Start, src[done:done+cplen], opts)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ ars = ars.DropFirst(n)
+ }
+ return done, nil
+}
+
+// CopyInVec copies bytes from the memory mapped at ars in uio to dst. The
+// maximum number of bytes copied is ars.NumBytes() or len(dst), whichever is
+// less. CopyInVec returns the number of bytes copied; if this is less than the
+// maximum, it returns a non-nil error explaining why.
+//
+// Preconditions: As for IO.CopyIn.
+func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts IOOpts) (int, error) {
+ var done int
+ for !ars.IsEmpty() && done < len(dst) {
+ ar := ars.Head()
+ cplen := len(dst) - done
+ if Addr(cplen) >= ar.Length() {
+ cplen = int(ar.Length())
+ }
+ n, err := uio.CopyIn(ctx, ar.Start, dst[done:done+cplen], opts)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ ars = ars.DropFirst(n)
+ }
+ return done, nil
+}
+
+// ZeroOutVec writes zeroes to the memory mapped at ars in uio. The maximum
+// number of bytes written is ars.NumBytes() or toZero, whichever is less.
+// ZeroOutVec returns the number of bytes written; if this is less than the
+// maximum, it returns a non-nil error explaining why.
+//
+// Preconditions: As for IO.ZeroOut.
+func ZeroOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) {
+ var done int64
+ for !ars.IsEmpty() && done < toZero {
+ ar := ars.Head()
+ cplen := toZero - done
+ if Addr(cplen) >= ar.Length() {
+ cplen = int64(ar.Length())
+ }
+ n, err := uio.ZeroOut(ctx, ar.Start, cplen, opts)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ ars = ars.DropFirst64(n)
+ }
+ return done, nil
+}
+
+func isASCIIWhitespace(b byte) bool {
+ // Compare Linux include/linux/ctype.h, lib/ctype.c.
+ // 9 => horizontal tab '\t'
+ // 10 => line feed '\n'
+ // 11 => vertical tab '\v'
+ // 12 => form feed '\c'
+ // 13 => carriage return '\r'
+ return b == ' ' || (b >= 9 && b <= 13)
+}
+
+// CopyInt32StringsInVec copies up to len(dsts) whitespace-separated decimal
+// strings from the memory mapped at ars in uio and converts them to int32
+// values in dsts. It returns the number of bytes read.
+//
+// CopyInt32StringsInVec shares the following properties with Linux's
+// kernel/sysctl.c:proc_dointvec(write=1):
+//
+// - If any read value overflows the range of int32, or any invalid characters
+// are encountered during the read, CopyInt32StringsInVec returns EINVAL.
+//
+// - If, upon reaching the end of ars, fewer than len(dsts) values have been
+// read, CopyInt32StringsInVec returns no error if at least 1 value was read
+// and EINVAL otherwise.
+//
+// - Trailing whitespace after the last successfully read value is counted in
+// the number of bytes read.
+//
+// Unlike proc_dointvec():
+//
+// - CopyInt32StringsInVec does not implicitly limit ars.NumBytes() to
+// PageSize-1; callers that require this must do so explicitly.
+//
+// - CopyInt32StringsInVec returns EINVAL if ars.NumBytes() == 0.
+//
+// Preconditions: As for CopyInVec.
+func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) {
+ if len(dsts) == 0 {
+ return 0, nil
+ }
+
+ buf := make([]byte, ars.NumBytes())
+ n, cperr := CopyInVec(ctx, uio, ars, buf, opts)
+ buf = buf[:n]
+
+ var i, j int
+ for ; j < len(dsts); j++ {
+ // Skip leading whitespace.
+ for i < len(buf) && isASCIIWhitespace(buf[i]) {
+ i++
+ }
+ if i == len(buf) {
+ break
+ }
+
+ // Find the end of the value to be parsed (next whitespace or end of string).
+ nextI := i + 1
+ for nextI < len(buf) && !isASCIIWhitespace(buf[nextI]) {
+ nextI++
+ }
+
+ // Parse a single value.
+ val, err := strconv.ParseInt(string(buf[i:nextI]), 10, 32)
+ if err != nil {
+ return int64(i), syserror.EINVAL
+ }
+ dsts[j] = int32(val)
+
+ i = nextI
+ }
+
+ // Skip trailing whitespace.
+ for i < len(buf) && isASCIIWhitespace(buf[i]) {
+ i++
+ }
+
+ if cperr != nil {
+ return int64(i), cperr
+ }
+ if j == 0 {
+ return int64(i), syserror.EINVAL
+ }
+ return int64(i), nil
+}
+
+// CopyInt32StringInVec is equivalent to CopyInt32StringsInVec, but copies at
+// most one int32.
+func CopyInt32StringInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst *int32, opts IOOpts) (int64, error) {
+ dsts := [1]int32{*dst}
+ n, err := CopyInt32StringsInVec(ctx, uio, ars, dsts[:], opts)
+ *dst = dsts[0]
+ return n, err
+}
+
+// IOSequence holds arguments to IO methods.
+type IOSequence struct {
+ IO IO
+ Addrs AddrRangeSeq
+ Opts IOOpts
+}
+
+// NumBytes returns s.Addrs.NumBytes().
+//
+// Note that NumBytes() may return 0 even if !s.Addrs.IsEmpty(), since
+// s.Addrs may contain a non-zero number of zero-length AddrRanges.
+// Many clients of
+// IOSequence currently do something like:
+//
+// if ioseq.NumBytes() == 0 {
+// return 0, nil
+// }
+// if f.availableBytes == 0 {
+// return 0, syserror.ErrWouldBlock
+// }
+// return ioseq.CopyOutFrom(..., reader)
+//
+// In such cases, using s.Addrs.IsEmpty() will cause them to have the wrong
+// behavior for zero-length I/O. However, using s.NumBytes() == 0 instead means
+// that we will return success for zero-length I/O in cases where Linux would
+// return EFAULT due to a failed access_ok() check, so in the long term we
+// should move checks for ErrWouldBlock etc. into the body of
+// reader.ReadToBlocks and use s.Addrs.IsEmpty() instead.
+func (s IOSequence) NumBytes() int64 {
+ return s.Addrs.NumBytes()
+}
+
+// DropFirst returns a copy of s with s.Addrs.DropFirst(n).
+//
+// Preconditions: As for AddrRangeSeq.DropFirst.
+func (s IOSequence) DropFirst(n int) IOSequence {
+ return IOSequence{s.IO, s.Addrs.DropFirst(n), s.Opts}
+}
+
+// DropFirst64 returns a copy of s with s.Addrs.DropFirst64(n).
+//
+// Preconditions: As for AddrRangeSeq.DropFirst64.
+func (s IOSequence) DropFirst64(n int64) IOSequence {
+ return IOSequence{s.IO, s.Addrs.DropFirst64(n), s.Opts}
+}
+
+// TakeFirst returns a copy of s with s.Addrs.TakeFirst(n).
+//
+// Preconditions: As for AddrRangeSeq.TakeFirst.
+func (s IOSequence) TakeFirst(n int) IOSequence {
+ return IOSequence{s.IO, s.Addrs.TakeFirst(n), s.Opts}
+}
+
+// TakeFirst64 returns a copy of s with s.Addrs.TakeFirst64(n).
+//
+// Preconditions: As for AddrRangeSeq.TakeFirst64.
+func (s IOSequence) TakeFirst64(n int64) IOSequence {
+ return IOSequence{s.IO, s.Addrs.TakeFirst64(n), s.Opts}
+}
+
+// CopyOut invokes CopyOutVec over s.Addrs.
+//
+// As with CopyOutVec, if s.NumBytes() < len(src), the copy will be truncated
+// to s.NumBytes(), and a nil error will be returned.
+//
+// Preconditions: As for CopyOutVec.
+func (s IOSequence) CopyOut(ctx context.Context, src []byte) (int, error) {
+ return CopyOutVec(ctx, s.IO, s.Addrs, src, s.Opts)
+}
+
+// CopyIn invokes CopyInVec over s.Addrs.
+//
+// As with CopyInVec, if s.NumBytes() < len(dst), the copy will be truncated to
+// s.NumBytes(), and a nil error will be returned.
+//
+// Preconditions: As for CopyInVec.
+func (s IOSequence) CopyIn(ctx context.Context, dst []byte) (int, error) {
+ return CopyInVec(ctx, s.IO, s.Addrs, dst, s.Opts)
+}
+
+// ZeroOut invokes ZeroOutVec over s.Addrs.
+//
+// As with ZeroOutVec, if s.NumBytes() < toZero, the write will be truncated
+// to s.NumBytes(), and a nil error will be returned.
+//
+// Preconditions: As for ZeroOutVec.
+func (s IOSequence) ZeroOut(ctx context.Context, toZero int64) (int64, error) {
+ return ZeroOutVec(ctx, s.IO, s.Addrs, toZero, s.Opts)
+}
+
+// CopyOutFrom invokes s.CopyOutFrom over s.Addrs.
+//
+// Preconditions: As for IO.CopyOutFrom.
+func (s IOSequence) CopyOutFrom(ctx context.Context, src safemem.Reader) (int64, error) {
+ return s.IO.CopyOutFrom(ctx, s.Addrs, src, s.Opts)
+}
+
+// CopyInTo invokes s.CopyInTo over s.Addrs.
+//
+// Preconditions: As for IO.CopyInTo.
+func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, error) {
+ return s.IO.CopyInTo(ctx, s.Addrs, dst, s.Opts)
+}
+
+// Reader returns an io.Reader that reads from s. Reads beyond the end of s
+// return io.EOF. The preconditions that apply to s.CopyIn also apply to the
+// returned io.Reader.Read.
+func (s IOSequence) Reader(ctx context.Context) io.Reader {
+ return &ioSequenceReadWriter{ctx, s}
+}
+
+// Writer returns an io.Writer that writes to s. Writes beyond the end of s
+// return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also
+// apply to the returned io.Writer.Write.
+func (s IOSequence) Writer(ctx context.Context) io.Writer {
+ return &ioSequenceReadWriter{ctx, s}
+}
+
+// ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when
+// attempting to write beyond the end of the IOSequence.
+var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence")
+
+type ioSequenceReadWriter struct {
+ ctx context.Context
+ s IOSequence
+}
+
+// Read implements io.Reader.Read.
+func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) {
+ n, err := rw.s.CopyIn(rw.ctx, dst)
+ rw.s = rw.s.DropFirst(n)
+ if err == nil && rw.s.NumBytes() == 0 {
+ err = io.EOF
+ }
+ return n, err
+}
+
+// Write implements io.Writer.Write.
+func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) {
+ n, err := rw.s.CopyOut(rw.ctx, src)
+ rw.s = rw.s.DropFirst(n)
+ if err == nil && n < len(src) {
+ err = ErrEndOfIOSequence
+ }
+ return n, err
+}
diff --git a/pkg/sentry/usermem/usermem_arm64.go b/pkg/sentry/usermem/usermem_arm64.go
new file mode 100644
index 000000000..fdfc30a66
--- /dev/null
+++ b/pkg/sentry/usermem/usermem_arm64.go
@@ -0,0 +1,53 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package usermem
+
+import (
+ "encoding/binary"
+ "syscall"
+)
+
+const (
+ // PageSize is the system page size.
+ // arm64 support 4K/16K/64K page size,
+ // which can be get by syscall.Getpagesize().
+ // Currently, only 4K page size is supported.
+ PageSize = 1 << PageShift
+
+ // HugePageSize is the system huge page size.
+ HugePageSize = 1 << HugePageShift
+
+ // PageShift is the binary log of the system page size.
+ PageShift = 12
+
+ // HugePageShift is the binary log of the system huge page size.
+ // Should be calculated by "PageShift + (PageShift - 3)"
+ // when multiple page size support is ready.
+ HugePageShift = 21
+)
+
+var (
+ // ByteOrder is the native byte order (little endian).
+ ByteOrder = binary.LittleEndian
+)
+
+func init() {
+ // Make sure the page size is 4K on arm64 platform.
+ if size := syscall.Getpagesize(); size != PageSize {
+ panic("Only 4K page size is supported on arm64!")
+ }
+}
diff --git a/pkg/sentry/usermem/usermem_state_autogen.go b/pkg/sentry/usermem/usermem_state_autogen.go
new file mode 100755
index 000000000..bc728eab3
--- /dev/null
+++ b/pkg/sentry/usermem/usermem_state_autogen.go
@@ -0,0 +1,49 @@
+// automatically generated by stateify.
+
+package usermem
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *AccessType) beforeSave() {}
+func (x *AccessType) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Read", &x.Read)
+ m.Save("Write", &x.Write)
+ m.Save("Execute", &x.Execute)
+}
+
+func (x *AccessType) afterLoad() {}
+func (x *AccessType) load(m state.Map) {
+ m.Load("Read", &x.Read)
+ m.Load("Write", &x.Write)
+ m.Load("Execute", &x.Execute)
+}
+
+func (x *Addr) save(m state.Map) {
+ m.SaveValue("", (uintptr)(*x))
+}
+
+func (x *Addr) load(m state.Map) {
+ m.LoadValue("", new(uintptr), func(y interface{}) { *x = (Addr)(y.(uintptr)) })
+}
+
+func (x *AddrRange) beforeSave() {}
+func (x *AddrRange) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *AddrRange) afterLoad() {}
+func (x *AddrRange) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func init() {
+ state.Register("usermem.AccessType", (*AccessType)(nil), state.Fns{Save: (*AccessType).save, Load: (*AccessType).load})
+ state.Register("usermem.Addr", (*Addr)(nil), state.Fns{Save: (*Addr).save, Load: (*Addr).load})
+ state.Register("usermem.AddrRange", (*AddrRange)(nil), state.Fns{Save: (*AddrRange).save, Load: (*AddrRange).load})
+}
diff --git a/pkg/sentry/usermem/usermem_unsafe.go b/pkg/sentry/usermem/usermem_unsafe.go
new file mode 100644
index 000000000..876783e78
--- /dev/null
+++ b/pkg/sentry/usermem/usermem_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "unsafe"
+)
+
+// stringFromImmutableBytes is equivalent to string(bs), except that it never
+// copies even if escape analysis can't prove that bs does not escape. This is
+// only valid if bs is never mutated after stringFromImmutableBytes returns.
+func stringFromImmutableBytes(bs []byte) string {
+ // Compare strings.Builder.String().
+ return *(*string)(unsafe.Pointer(&bs))
+}
diff --git a/pkg/sentry/usermem/usermem_x86.go b/pkg/sentry/usermem/usermem_x86.go
new file mode 100644
index 000000000..8059b72d2
--- /dev/null
+++ b/pkg/sentry/usermem/usermem_x86.go
@@ -0,0 +1,38 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 i386
+
+package usermem
+
+import "encoding/binary"
+
+const (
+ // PageSize is the system page size.
+ PageSize = 1 << PageShift
+
+ // HugePageSize is the system huge page size.
+ HugePageSize = 1 << HugePageShift
+
+ // PageShift is the binary log of the system page size.
+ PageShift = 12
+
+ // HugePageShift is the binary log of the system huge page size.
+ HugePageShift = 21
+)
+
+var (
+ // ByteOrder is the native byte order (little endian).
+ ByteOrder = binary.LittleEndian
+)
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
new file mode 100644
index 000000000..2fc4472dd
--- /dev/null
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -0,0 +1,305 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package watchdog is responsible for monitoring the sentry for tasks that may
+// potentially be stuck or looping inderterminally causing hard to debug hungs in
+// the untrusted app.
+//
+// It works by periodically querying all tasks to check whether they are in user
+// mode (RunUser), kernel mode (RunSys), or blocked in the kernel (OffCPU). Tasks
+// that have been running in kernel mode for a long time in the same syscall
+// without blocking are considered stuck and are reported.
+//
+// When a stuck task is detected, the watchdog can take one of the following actions:
+// 1. LogWarning: Logs a warning message followed by a stack dump of all goroutines.
+// If a tasks continues to be stuck, the message will repeat every minute, unless
+// a new stuck task is detected
+// 2. Panic: same as above, followed by panic()
+//
+package watchdog
+
+import (
+ "bytes"
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/metric"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+)
+
+// DefaultTimeout is a resonable timeout value for most applications.
+const DefaultTimeout = 3 * time.Minute
+
+// descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period
+// is discounted from task's last update time. It's set high enough that small scheduling delays won't
+// trigger it.
+const descheduleThreshold = 1 * time.Second
+
+var stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected")
+
+// Amount of time to wait before dumping the stack to the log again when the same task(s) remains stuck.
+var stackDumpSameTaskPeriod = time.Minute
+
+// Action defines what action to take when a stuck task is detected.
+type Action int
+
+const (
+ // LogWarning logs warning message followed by stack trace.
+ LogWarning Action = iota
+ // Panic will do the same logging as LogWarning and panic().
+ Panic
+)
+
+// String returns Action's string representation.
+func (a Action) String() string {
+ switch a {
+ case LogWarning:
+ return "LogWarning"
+ case Panic:
+ return "Panic"
+ default:
+ panic(fmt.Sprintf("Invalid action: %d", a))
+ }
+}
+
+// Watchdog is the main watchdog class. It controls a goroutine that periodically
+// analyses all tasks and reports if any of them appear to be stuck.
+type Watchdog struct {
+ // period indicates how often to check all tasks. It's calculated based on
+ // 'taskTimeout'.
+ period time.Duration
+
+ // taskTimeout is the amount of time to allow a task to execute the same syscall
+ // without blocking before it's declared stuck.
+ taskTimeout time.Duration
+
+ // timeoutAction indicates what action to take when a stuck tasks is detected.
+ timeoutAction Action
+
+ // k is where the tasks come from.
+ k *kernel.Kernel
+
+ // stop is used to notify to watchdog should stop.
+ stop chan struct{}
+
+ // done is used to notify when the watchdog has stopped.
+ done chan struct{}
+
+ // offenders map contains all tasks that are currently stuck.
+ offenders map[*kernel.Task]*offender
+
+ // lastStackDump tracks the last time a stack dump was generated to prevent
+ // spamming the log.
+ lastStackDump time.Time
+
+ // lastRun is set to the last time the watchdog executed a monitoring loop.
+ lastRun ktime.Time
+
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // started is true if the watchdog has been started before.
+ started bool
+}
+
+type offender struct {
+ lastUpdateTime ktime.Time
+}
+
+// New creates a new watchdog.
+func New(k *kernel.Kernel, taskTimeout time.Duration, a Action) *Watchdog {
+ // 4 is arbitrary, just don't want to prolong 'taskTimeout' too much.
+ period := taskTimeout / 4
+ return &Watchdog{
+ k: k,
+ period: period,
+ taskTimeout: taskTimeout,
+ timeoutAction: a,
+ offenders: make(map[*kernel.Task]*offender),
+ stop: make(chan struct{}),
+ done: make(chan struct{}),
+ }
+}
+
+// Start starts the watchdog.
+func (w *Watchdog) Start() {
+ if w.taskTimeout == 0 {
+ log.Infof("Watchdog disabled")
+ return
+ }
+
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.started {
+ return
+ }
+
+ w.lastRun = w.k.MonotonicClock().Now()
+
+ log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.taskTimeout, w.timeoutAction)
+ go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore.
+ w.started = true
+}
+
+// Stop requests the watchdog to stop and wait for it.
+func (w *Watchdog) Stop() {
+ if w.taskTimeout == 0 {
+ return
+ }
+
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if !w.started {
+ return
+ }
+ log.Infof("Stopping watchdog")
+ w.stop <- struct{}{}
+ <-w.done
+ w.started = false
+ log.Infof("Watchdog stopped")
+}
+
+// loop is the main watchdog routine. It only returns when 'Stop()' is called.
+func (w *Watchdog) loop() {
+ // Loop until someone stops it.
+ for {
+ select {
+ case <-w.stop:
+ w.done <- struct{}{}
+ return
+ case <-time.After(w.period):
+ w.runTurn()
+ }
+ }
+}
+
+// runTurn runs a single pass over all tasks and reports anything it finds.
+func (w *Watchdog) runTurn() {
+ // Someone needs to watch the watchdog. The call below can get stuck if there
+ // is a deadlock affecting root's PID namespace mutex. Run it in a goroutine
+ // and report if it takes too long to return.
+ var tasks []*kernel.Task
+ done := make(chan struct{})
+ go func() { // S/R-SAFE: watchdog is stopped and restarted during S/R.
+ tasks = w.k.TaskSet().Root.Tasks()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(w.taskTimeout):
+ // Report if the watchdog is not making progress.
+ // No one is wathching the watchdog watcher though.
+ w.reportStuckWatchdog()
+ <-done
+ }
+
+ newOffenders := make(map[*kernel.Task]*offender)
+ newTaskFound := false
+ now := ktime.FromNanoseconds(int64(w.k.CPUClockNow() * uint64(linux.ClockTick)))
+
+ // The process may be running with low CPU limit making tasks appear stuck because
+ // are starved of CPU cycles. An estimate is that Tasks could have been starved
+ // since the last time the watchdog run. If the watchdog detects that scheduling
+ // is off, it will discount the entire duration since last run from 'lastUpdateTime'.
+ discount := time.Duration(0)
+ if now.Sub(w.lastRun.Add(w.period)) > descheduleThreshold {
+ discount = now.Sub(w.lastRun)
+ }
+ w.lastRun = now
+
+ log.Infof("Watchdog starting loop, tasks: %d, discount: %v", len(tasks), discount)
+ for _, t := range tasks {
+ tsched := t.TaskGoroutineSchedInfo()
+
+ // An offender is a task running inside the kernel for longer than the specified timeout.
+ if tsched.State == kernel.TaskGoroutineRunningSys {
+ lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick)))
+ elapsed := now.Sub(lastUpdateTime) - discount
+ if elapsed > w.taskTimeout {
+ tc, ok := w.offenders[t]
+ if !ok {
+ // New stuck task detected.
+ //
+ // TODO(b/65849403): Tasks blocked doing IO may be considered stuck in kernel.
+ tc = &offender{lastUpdateTime: lastUpdateTime}
+ stuckTasks.Increment()
+ newTaskFound = true
+ }
+ newOffenders[t] = tc
+ }
+ }
+ }
+ if len(newOffenders) > 0 {
+ w.report(newOffenders, newTaskFound, now)
+ }
+
+ // Remember which tasks have been reported.
+ w.offenders = newOffenders
+}
+
+// report takes appropriate action when a stuck task is detected.
+func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound bool, now ktime.Time) {
+ var buf bytes.Buffer
+ buf.WriteString(fmt.Sprintf("Sentry detected %d stuck task(s):\n", len(offenders)))
+ for t, o := range offenders {
+ tid := w.k.TaskSet().Root.IDOfTask(t)
+ buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime)))
+ }
+ buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine")
+ w.onStuckTask(newTaskFound, &buf)
+}
+
+func (w *Watchdog) reportStuckWatchdog() {
+ var buf bytes.Buffer
+ buf.WriteString("Watchdog goroutine is stuck:\n")
+ w.onStuckTask(true, &buf)
+}
+
+func (w *Watchdog) onStuckTask(newTaskFound bool, buf *bytes.Buffer) {
+ switch w.timeoutAction {
+ case LogWarning:
+ // Dump stack only if a new task is detected or if it sometime has passed since
+ // the last time a stack dump was generated.
+ if !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod {
+ buf.WriteString("\n...[stack dump skipped]...")
+ log.Warningf(buf.String())
+ } else {
+ log.TracebackAll(buf.String())
+ w.lastStackDump = time.Now()
+ }
+
+ case Panic:
+ // Panic will skip over running tasks, which is likely the culprit here. So manually
+ // dump all stacks before panic'ing.
+ log.TracebackAll(buf.String())
+
+ // Attempt to flush metrics, timeout and move on in case metrics are stuck as well.
+ metricsEmitted := make(chan struct{}, 1)
+ go func() { // S/R-SAFE: watchdog is stopped during save and restarted after restore.
+ // Flush metrics before killing process.
+ metric.EmitMetricUpdate()
+ metricsEmitted <- struct{}{}
+ }()
+ select {
+ case <-metricsEmitted:
+ case <-time.After(1 * time.Second):
+ }
+ panic("Sentry detected stuck task(s). See stack trace and message above for more details")
+ }
+}
diff --git a/pkg/sentry/watchdog/watchdog_state_autogen.go b/pkg/sentry/watchdog/watchdog_state_autogen.go
new file mode 100755
index 000000000..530ac6a07
--- /dev/null
+++ b/pkg/sentry/watchdog/watchdog_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package watchdog
+
diff --git a/pkg/sleep/commit_amd64.s b/pkg/sleep/commit_amd64.s
new file mode 100644
index 000000000..bc4ac2c3c
--- /dev/null
+++ b/pkg/sleep/commit_amd64.s
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+#define preparingG 1
+
+// See commit_noasm.go for a description of commitSleep.
+//
+// func commitSleep(g uintptr, waitingG *uintptr) bool
+TEXT ·commitSleep(SB),NOSPLIT,$0-24
+ MOVQ waitingG+8(FP), CX
+ MOVQ g+0(FP), DX
+
+ // Store the G in waitingG if it's still preparingG. If it's anything
+ // else it means a waker has aborted the sleep.
+ MOVQ $preparingG, AX
+ LOCK
+ CMPXCHGQ DX, 0(CX)
+
+ SETEQ AX
+ MOVB AX, ret+16(FP)
+
+ RET
diff --git a/pkg/sleep/commit_asm.go b/pkg/sleep/commit_asm.go
new file mode 100644
index 000000000..35e2cc337
--- /dev/null
+++ b/pkg/sleep/commit_asm.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package sleep
+
+// See commit_noasm.go for a description of commitSleep.
+func commitSleep(g uintptr, waitingG *uintptr) bool
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
new file mode 100644
index 000000000..686b1da3d
--- /dev/null
+++ b/pkg/sleep/commit_noasm.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+// +build !amd64
+
+package sleep
+
+import "sync/atomic"
+
+// commitSleep signals to wakers that the given g is now sleeping. Wakers can
+// then fetch it and wake it.
+//
+// The commit may fail if wakers have been asserted after our last check, in
+// which case they will have set s.waitingG to zero.
+//
+// It is written in assembly because it is called from g0, so it doesn't have
+// a race context.
+func commitSleep(g uintptr, waitingG *uintptr) bool {
+ for {
+ // Check if the wait was aborted.
+ if atomic.LoadUintptr(waitingG) == 0 {
+ return false
+ }
+
+ // Try to store the G so that wakers know who to wake.
+ if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) {
+ return true
+ }
+ }
+}
diff --git a/pkg/sleep/sleep_state_autogen.go b/pkg/sleep/sleep_state_autogen.go
new file mode 100755
index 000000000..e444aa91a
--- /dev/null
+++ b/pkg/sleep/sleep_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package sleep
+
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
new file mode 100644
index 000000000..8f5e60a25
--- /dev/null
+++ b/pkg/sleep/sleep_unsafe.go
@@ -0,0 +1,403 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.11
+// +build !go1.14
+
+// Check go:linkname function signatures when updating Go version.
+
+// Package sleep allows goroutines to efficiently sleep on multiple sources of
+// notifications (wakers). It offers O(1) complexity, which is different from
+// multi-channel selects which have O(n) complexity (where n is the number of
+// channels) and a considerable constant factor.
+//
+// It is similar to edge-triggered epoll waits, where the user registers each
+// object of interest once, and then can repeatedly wait on all of them.
+//
+// A Waker object is used to wake a sleeping goroutine (G) up, or prevent it
+// from going to sleep next. A Sleeper object is used to receive notifications
+// from wakers, and if no notifications are available, to optionally sleep until
+// one becomes available.
+//
+// A Waker can be associated with at most one Sleeper, but a Sleeper can be
+// associated with multiple Wakers. A Sleeper has a list of asserted (ready)
+// wakers; when Fetch() is called repeatedly, elements from this list are
+// returned until the list becomes empty in which case the goroutine goes to
+// sleep. When Assert() is called on a Waker, it adds itself to the Sleeper's
+// asserted list and wakes the G up from its sleep if needed.
+//
+// Sleeper objects are expected to be used as follows, with just one goroutine
+// executing this code:
+//
+// // One time set-up.
+// s := sleep.Sleeper{}
+// s.AddWaker(&w1, constant1)
+// s.AddWaker(&w2, constant2)
+//
+// // Called repeatedly.
+// for {
+// switch id, _ := s.Fetch(true); id {
+// case constant1:
+// // Do work triggered by w1 being asserted.
+// case constant2:
+// // Do work triggered by w2 being asserted.
+// }
+// }
+//
+// And Waker objects are expected to call w.Assert() when they want the sleeper
+// to wake up and perform work.
+//
+// The notifications are edge-triggered, which means that if a Waker calls
+// Assert() several times before the sleeper has the chance to wake up, it will
+// only be notified once and should perform all pending work (alternatively, it
+// can also call Assert() on the waker, to ensure that it will wake up again).
+//
+// The "unsafeness" here is in the casts to/from unsafe.Pointer, which is safe
+// when only one type is used for each unsafe.Pointer (which is the case here),
+// we should just make sure that this remains the case in the future. The usage
+// of unsafe package could be confined to sharedWaker and sharedSleeper types
+// that would hold pointers in atomic.Pointers, but the go compiler currently
+// can't optimize these as well (it won't inline their method calls), which
+// reduces performance.
+package sleep
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+const (
+ // preparingG is stored in sleepers to indicate that they're preparing
+ // to sleep.
+ preparingG = 1
+)
+
+var (
+ // assertedSleeper is a sentinel sleeper. A pointer to it is stored in
+ // wakers that are asserted.
+ assertedSleeper Sleeper
+)
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(uintptr, *uintptr) bool, wg *uintptr, reason uint8, traceEv byte, traceskip int)
+
+//go:linkname goready runtime.goready
+func goready(g uintptr, traceskip int)
+
+// Sleeper allows a goroutine to sleep and receive wake up notifications from
+// Wakers in an efficient way.
+//
+// This is similar to edge-triggered epoll in that wakers are added to the
+// sleeper once and the sleeper can then repeatedly sleep in O(1) time while
+// waiting on all wakers.
+//
+// None of the methods in a Sleeper can be called concurrently. Wakers that have
+// been added to a sleeper A can only be added to another sleeper after A.Done()
+// returns. These restrictions allow this to be implemented lock-free.
+//
+// This struct is thread-compatible.
+type Sleeper struct {
+ // sharedList is a "stack" of asserted wakers. They atomically add
+ // themselves to the front of this list as they become asserted.
+ sharedList unsafe.Pointer
+
+ // localList is a list of asserted wakers that is only accessible to the
+ // waiter, and thus doesn't have to be accessed atomically. When
+ // fetching more wakers, the waiter will first go through this list, and
+ // only when it's empty will it atomically fetch wakers from
+ // sharedList.
+ localList *Waker
+
+ // allWakers is a list with all wakers that have been added to this
+ // sleeper. It is used during cleanup to remove associations.
+ allWakers *Waker
+
+ // waitingG holds the G that is sleeping, if any. It is used by wakers
+ // to determine which G, if any, they should wake.
+ waitingG uintptr
+}
+
+// AddWaker associates the given waker to the sleeper. id is the value to be
+// returned when the sleeper is woken by the given waker.
+func (s *Sleeper) AddWaker(w *Waker, id int) {
+ // Add the waker to the list of all wakers.
+ w.allWakersNext = s.allWakers
+ s.allWakers = w
+ w.id = id
+
+ // Try to associate the waker with the sleeper. If it's already
+ // asserted, we simply enqueue it in the "ready" list.
+ for {
+ p := (*Sleeper)(atomic.LoadPointer(&w.s))
+ if p == &assertedSleeper {
+ s.enqueueAssertedWaker(w)
+ return
+ }
+
+ if atomic.CompareAndSwapPointer(&w.s, usleeper(p), usleeper(s)) {
+ return
+ }
+ }
+}
+
+// nextWaker returns the next waker in the notification list, blocking if
+// needed.
+func (s *Sleeper) nextWaker(block bool) *Waker {
+ // Attempt to replenish the local list if it's currently empty.
+ if s.localList == nil {
+ for atomic.LoadPointer(&s.sharedList) == nil {
+ // Fail request if caller requested that we
+ // don't block.
+ if !block {
+ return nil
+ }
+
+ // Indicate to wakers that we're about to sleep,
+ // this allows them to abort the wait by setting
+ // waitingG back to zero (which we'll notice
+ // before committing the sleep).
+ atomic.StoreUintptr(&s.waitingG, preparingG)
+
+ // Check if something was queued while we were
+ // preparing to sleep. We need this interleaving
+ // to avoid missing wake ups.
+ if atomic.LoadPointer(&s.sharedList) != nil {
+ atomic.StoreUintptr(&s.waitingG, 0)
+ break
+ }
+
+ // Try to commit the sleep and report it to the
+ // tracer as a select.
+ //
+ // gopark puts the caller to sleep and calls
+ // commitSleep to decide whether to immediately
+ // wake the caller up or to leave it sleeping.
+ const traceEvGoBlockSelect = 24
+ // See:runtime2.go in the go runtime package for
+ // the values to pass as the waitReason here.
+ const waitReasonSelect = 9
+ gopark(commitSleep, &s.waitingG, waitReasonSelect, traceEvGoBlockSelect, 0)
+ }
+
+ // Pull the shared list out and reverse it in the local
+ // list. Given that wakers push themselves in reverse
+ // order, we fix things here.
+ v := (*Waker)(atomic.SwapPointer(&s.sharedList, nil))
+ for v != nil {
+ cur := v
+ v = v.next
+
+ cur.next = s.localList
+ s.localList = cur
+ }
+ }
+
+ // Remove the waker in the front of the list.
+ w := s.localList
+ s.localList = w.next
+
+ return w
+}
+
+// Fetch fetches the next wake-up notification. If a notification is immediately
+// available, it is returned right away. Otherwise, the behavior depends on the
+// value of 'block': if true, the current goroutine blocks until a notification
+// arrives, then returns it; if false, returns 'ok' as false.
+//
+// When 'ok' is true, the value of 'id' corresponds to the id associated with
+// the waker; when 'ok' is false, 'id' is undefined.
+//
+// N.B. This method is *not* thread-safe. Only one goroutine at a time is
+// allowed to call this method.
+func (s *Sleeper) Fetch(block bool) (id int, ok bool) {
+ for {
+ w := s.nextWaker(block)
+ if w == nil {
+ return -1, false
+ }
+
+ // Reassociate the waker with the sleeper. If the waker was
+ // still asserted we can return it, otherwise try the next one.
+ old := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(s)))
+ if old == &assertedSleeper {
+ return w.id, true
+ }
+ }
+}
+
+// Done is used to indicate that the caller won't use this Sleeper anymore. It
+// removes the association with all wakers so that they can be safely reused
+// by another sleeper after Done() returns.
+func (s *Sleeper) Done() {
+ // Remove all associations that we can, and build a list of the ones
+ // we could not. An association can be removed right away from waker w
+ // if w.s has a pointer to the sleeper, that is, the waker is not
+ // asserted yet. By atomically switching w.s to nil, we guarantee that
+ // subsequent calls to Assert() on the waker will not result in it being
+ // queued to this sleeper.
+ var pending *Waker
+ w := s.allWakers
+ for w != nil {
+ next := w.allWakersNext
+ for {
+ t := atomic.LoadPointer(&w.s)
+ if t != usleeper(s) {
+ w.allWakersNext = pending
+ pending = w
+ break
+ }
+
+ if atomic.CompareAndSwapPointer(&w.s, t, nil) {
+ break
+ }
+ }
+ w = next
+ }
+
+ // The associations that we could not remove are either asserted, or in
+ // the process of being asserted, or have been asserted and cleared
+ // before being pulled from the sleeper lists. We must wait for them all
+ // to make it to the sleeper lists, so that we know that the wakers
+ // won't do any more work towards waking this sleeper up.
+ for pending != nil {
+ pulled := s.nextWaker(true)
+
+ // Remove the waker we just pulled from the list of associated
+ // wakers.
+ prev := &pending
+ for w := *prev; w != nil; w = *prev {
+ if pulled == w {
+ *prev = w.allWakersNext
+ break
+ }
+ prev = &w.allWakersNext
+ }
+ }
+ s.allWakers = nil
+}
+
+// enqueueAssertedWaker enqueues an asserted waker to the "ready" circular list
+// of wakers that want to notify the sleeper.
+func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
+ // Add the new waker to the front of the list.
+ for {
+ v := (*Waker)(atomic.LoadPointer(&s.sharedList))
+ w.next = v
+ if atomic.CompareAndSwapPointer(&s.sharedList, uwaker(v), uwaker(w)) {
+ break
+ }
+ }
+
+ for {
+ // Nothing to do if there isn't a G waiting.
+ g := atomic.LoadUintptr(&s.waitingG)
+ if g == 0 {
+ return
+ }
+
+ // Signal to the sleeper that a waker has been asserted.
+ if atomic.CompareAndSwapUintptr(&s.waitingG, g, 0) {
+ if g != preparingG {
+ // We managed to get a G. Wake it up.
+ goready(g, 0)
+ }
+ }
+ }
+}
+
+// Waker represents a source of wake-up notifications to be sent to sleepers. A
+// waker can be associated with at most one sleeper at a time, and at any given
+// time is either in asserted or non-asserted state.
+//
+// Once asserted, the waker remains so until it is manually cleared or a sleeper
+// consumes its assertion (i.e., a sleeper wakes up or is prevented from going
+// to sleep due to the waker).
+//
+// This struct is thread-safe, that is, its methods can be called concurrently
+// by multiple goroutines.
+type Waker struct {
+ // s is the sleeper that this waker can wake up. Only one sleeper at a
+ // time is allowed. This field can have three classes of values:
+ // nil -- the waker is not asserted: it either is not associated with
+ // a sleeper, or is queued to a sleeper due to being previously
+ // asserted. This is the zero value.
+ // &assertedSleeper -- the waker is asserted.
+ // otherwise -- the waker is not asserted, and is associated with the
+ // given sleeper. Once it transitions to asserted state, the
+ // associated sleeper will be woken.
+ s unsafe.Pointer
+
+ // next is used to form a linked list of asserted wakers in a sleeper.
+ next *Waker
+
+ // allWakersNext is used to form a linked list of all wakers associated
+ // to a given sleeper.
+ allWakersNext *Waker
+
+ // id is the value to be returned to sleepers when they wake up due to
+ // this waker being asserted.
+ id int
+}
+
+// Assert moves the waker to an asserted state, if it isn't asserted yet. When
+// asserted, the waker will cause its matching sleeper to wake up.
+func (w *Waker) Assert() {
+ // Nothing to do if the waker is already asserted. This check allows us
+ // to complete this case (already asserted) without any interlocked
+ // operations on x86.
+ if atomic.LoadPointer(&w.s) == usleeper(&assertedSleeper) {
+ return
+ }
+
+ // Mark the waker as asserted, and wake up a sleeper if there is one.
+ switch s := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(&assertedSleeper))); s {
+ case nil:
+ case &assertedSleeper:
+ default:
+ s.enqueueAssertedWaker(w)
+ }
+}
+
+// Clear moves the waker to then non-asserted state and returns whether it was
+// asserted before being cleared.
+//
+// N.B. The waker isn't removed from the "ready" list of a sleeper (if it
+// happens to be in one), but the sleeper will notice that it is not asserted
+// anymore and won't return it to the caller.
+func (w *Waker) Clear() bool {
+ // Nothing to do if the waker is not asserted. This check allows us to
+ // complete this case (already not asserted) without any interlocked
+ // operations on x86.
+ if atomic.LoadPointer(&w.s) != usleeper(&assertedSleeper) {
+ return false
+ }
+
+ // Try to store nil in the sleeper, which indicates that the waker is
+ // not asserted.
+ return atomic.CompareAndSwapPointer(&w.s, usleeper(&assertedSleeper), nil)
+}
+
+// IsAsserted returns whether the waker is currently asserted (i.e., if it's
+// currently in a state that would cause its matching sleeper to wake up).
+func (w *Waker) IsAsserted() bool {
+ return (*Sleeper)(atomic.LoadPointer(&w.s)) == &assertedSleeper
+}
+
+func usleeper(s *Sleeper) unsafe.Pointer {
+ return unsafe.Pointer(s)
+}
+
+func uwaker(w *Waker) unsafe.Pointer {
+ return unsafe.Pointer(w)
+}
diff --git a/pkg/state/addr_range.go b/pkg/state/addr_range.go
new file mode 100755
index 000000000..45720c643
--- /dev/null
+++ b/pkg/state/addr_range.go
@@ -0,0 +1,62 @@
+package state
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type addrRange struct {
+ // Start is the inclusive start of the range.
+ Start uintptr
+
+ // End is the exclusive end of the range.
+ End uintptr
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r addrRange) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r addrRange) Length() uintptr {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r addrRange) Contains(x uintptr) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r addrRange) Overlaps(r2 addrRange) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r addrRange) IsSupersetOf(r2 addrRange) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+func (r addrRange) Intersect(r2 addrRange) addrRange {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+func (r addrRange) CanSplitAt(x uintptr) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/state/addr_set.go b/pkg/state/addr_set.go
new file mode 100755
index 000000000..bce7da87d
--- /dev/null
+++ b/pkg/state/addr_set.go
@@ -0,0 +1,1274 @@
+package state
+
+import (
+ __generics_imported0 "reflect"
+)
+
+import (
+ "bytes"
+ "fmt"
+)
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ addrminDegree = 10
+
+ addrmaxDegree = 2 * addrminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type addrSet struct {
+ root addrnode `state:".(*addrSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *addrSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *addrSet) IsEmptyRange(r addrRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *addrSet) Span() uintptr {
+ var sz uintptr
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *addrSet) SpanRange(r addrRange) uintptr {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uintptr
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *addrSet) FirstSegment() addrIterator {
+ if s.root.nrSegments == 0 {
+ return addrIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *addrSet) LastSegment() addrIterator {
+ if s.root.nrSegments == 0 {
+ return addrIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *addrSet) FirstGap() addrGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return addrGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *addrSet) LastGap() addrGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return addrGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *addrSet) Find(key uintptr) (addrIterator, addrGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return addrIterator{n, i}, addrGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return addrIterator{}, addrGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *addrSet) FindSegment(key uintptr) addrIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *addrSet) LowerBoundSegment(min uintptr) addrIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *addrSet) UpperBoundSegment(max uintptr) addrIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *addrSet) FindGap(key uintptr) addrGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *addrSet) LowerBoundGap(min uintptr) addrGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *addrSet) UpperBoundGap(max uintptr) addrGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *addrSet) Add(r addrRange, val __generics_imported0.Value) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *addrSet) AddWithoutMerging(r addrRange, val __generics_imported0.Value) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *addrSet) Insert(gap addrGapIterator, r addrRange, val __generics_imported0.Value) addrIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (addrSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ return next
+ }
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *addrSet) InsertWithoutMerging(gap addrGapIterator, r addrRange, val __generics_imported0.Value) addrIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
+func (s *addrSet) InsertWithoutMergingUnchecked(gap addrGapIterator, r addrRange, val __generics_imported0.Value) addrIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ return addrIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *addrSet) Remove(seg addrIterator) addrGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ addrSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ return seg.node.rebalanceAfterRemove(addrGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *addrSet) RemoveAll() {
+ s.root = addrnode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *addrSet) RemoveRange(r addrRange) addrGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *addrSet) Merge(first, second addrIterator) addrIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *addrSet) MergeUnchecked(first, second addrIterator) addrIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (addrSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return addrIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *addrSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *addrSet) MergeRange(r addrRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *addrSet) MergeAdjacent(r addrRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *addrSet) Split(seg addrIterator, split uintptr) (addrIterator, addrIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *addrSet) SplitUnchecked(seg addrIterator, split uintptr) (addrIterator, addrIterator) {
+ val1, val2 := (addrSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), addrRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *addrSet) SplitAt(split uintptr) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *addrSet) Isolate(seg addrIterator, r addrRange) addrIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *addrSet) ApplyContiguous(r addrRange, fn func(seg addrIterator)) addrGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return addrGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return addrGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type addrnode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *addrnode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [addrmaxDegree - 1]addrRange
+ values [addrmaxDegree - 1]__generics_imported0.Value
+ children [addrmaxDegree]*addrnode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *addrnode) firstSegment() addrIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return addrIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *addrnode) lastSegment() addrIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return addrIterator{n, n.nrSegments - 1}
+}
+
+func (n *addrnode) prevSibling() *addrnode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *addrnode) nextSibling() *addrnode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *addrnode) rebalanceBeforeInsert(gap addrGapIterator) addrGapIterator {
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.nrSegments < addrmaxDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ left := &addrnode{
+ nrSegments: addrminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &addrnode{
+ nrSegments: addrminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:addrminDegree-1], n.keys[:addrminDegree-1])
+ copy(left.values[:addrminDegree-1], n.values[:addrminDegree-1])
+ copy(right.keys[:addrminDegree-1], n.keys[addrminDegree:])
+ copy(right.values[:addrminDegree-1], n.values[addrminDegree:])
+ n.keys[0], n.values[0] = n.keys[addrminDegree-1], n.values[addrminDegree-1]
+ addrzeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:addrminDegree], n.children[:addrminDegree])
+ copy(right.children[:addrminDegree], n.children[addrminDegree:])
+ addrzeroNodeSlice(n.children[2:])
+ for i := 0; i < addrminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < addrminDegree {
+ return addrGapIterator{left, gap.index}
+ }
+ return addrGapIterator{right, gap.index - addrminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[addrminDegree-1], n.values[addrminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &addrnode{
+ nrSegments: addrminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:addrminDegree-1], n.keys[addrminDegree:])
+ copy(sibling.values[:addrminDegree-1], n.values[addrminDegree:])
+ addrzeroValueSlice(n.values[addrminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:addrminDegree], n.children[addrminDegree:])
+ addrzeroNodeSlice(n.children[addrminDegree:])
+ for i := 0; i < addrminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = addrminDegree - 1
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < addrminDegree {
+ return gap
+ }
+ return addrGapIterator{sibling, gap.index - addrminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *addrnode) rebalanceAfterRemove(gap addrGapIterator) addrGapIterator {
+ for {
+ if n.nrSegments >= addrminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= addrminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return addrGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return addrGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= addrminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return addrGapIterator{n, n.nrSegments}
+ }
+ return addrGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+ if gap.node == left {
+ return addrGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return addrGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *addrnode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = addrGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ addrSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ n = p
+ }
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type addrIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *addrnode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg addrIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg addrIterator) Range() addrRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg addrIterator) Start() uintptr {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg addrIterator) End() uintptr {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+//
+// - r.Length() > 0.
+//
+// - The new range must not overlap an existing one: If seg.NextSegment().Ok(),
+// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then
+// r.start >= seg.PrevSegment().End().
+func (seg addrIterator) SetRangeUnchecked(r addrRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg addrIterator) SetRange(r addrRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid: start < seg.End(); if
+// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg addrIterator) SetStartUnchecked(start uintptr) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg addrIterator) SetStart(start uintptr) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid: end > seg.Start(); if
+// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg addrIterator) SetEndUnchecked(end uintptr) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg addrIterator) SetEnd(end uintptr) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg addrIterator) Value() __generics_imported0.Value {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg addrIterator) ValuePtr() *__generics_imported0.Value {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg addrIterator) SetValue(val __generics_imported0.Value) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg addrIterator) PrevSegment() addrIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return addrIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return addrIterator{}
+ }
+ return addrsegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg addrIterator) NextSegment() addrIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return addrIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return addrIterator{}
+ }
+ return addrsegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg addrIterator) PrevGap() addrGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return addrGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg addrIterator) NextGap() addrGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return addrGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg addrIterator) PrevNonEmpty() (addrIterator, addrGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return addrIterator{}, gap
+ }
+ return gap.PrevSegment(), addrGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg addrIterator) NextNonEmpty() (addrIterator, addrGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return addrIterator{}, gap
+ }
+ return gap.NextSegment(), addrGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type addrGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *addrnode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap addrGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap addrGapIterator) Range() addrRange {
+ return addrRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap addrGapIterator) Start() uintptr {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return addrSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap addrGapIterator) End() uintptr {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return addrSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap addrGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap addrGapIterator) PrevSegment() addrIterator {
+ return addrsegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap addrGapIterator) NextSegment() addrIterator {
+ return addrsegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap addrGapIterator) PrevGap() addrGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return addrGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap addrGapIterator) NextGap() addrGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return addrGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func addrsegmentBeforePosition(n *addrnode, i int) addrIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return addrIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return addrIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func addrsegmentAfterPosition(n *addrnode, i int) addrIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return addrIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return addrIterator{n, i}
+}
+
+func addrzeroValueSlice(slice []__generics_imported0.Value) {
+
+ for i := range slice {
+ addrSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func addrzeroNodeSlice(slice []*addrnode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *addrSet) String() string {
+ return s.root.String()
+}
+
+// String stringifes a node (and all of its children) for debugging.
+func (n *addrnode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *addrnode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type addrSegmentDataSlices struct {
+ Start []uintptr
+ End []uintptr
+ Values []__generics_imported0.Value
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *addrSet) ExportSortedSlices() *addrSegmentDataSlices {
+ var sds addrSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions: s must be empty. sds must represent a valid set (the segments
+// in sds must have valid lengths that do not overlap). The segments in sds
+// must be sorted in ascending key order.
+func (s *addrSet) ImportSortedSlices(sds *addrSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := addrRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+func (s *addrSet) saveRoot() *addrSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *addrSet) loadRoot(sds *addrSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
new file mode 100644
index 000000000..73a59f871
--- /dev/null
+++ b/pkg/state/decode.go
@@ -0,0 +1,605 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+ "sort"
+
+ "github.com/golang/protobuf/proto"
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// objectState represents an object that may be in the process of being
+// decoded. Specifically, it represents either a decoded object, or an an
+// interest in a future object that will be decoded. When that interest is
+// registered (via register), the storage for the object will be created, but
+// it will not be decoded until the object is encountered in the stream.
+type objectState struct {
+ // id is the id for this object.
+ //
+ // If this field is zero, then this is an anonymous (unregistered,
+ // non-reference primitive) object. This is immutable.
+ id uint64
+
+ // obj is the object. This may or may not be valid yet, depending on
+ // whether complete returns true. However, regardless of whether the
+ // object is valid, obj contains a final storage location for the
+ // object. This is immutable.
+ //
+ // Note that this must be addressable (obj.Addr() must not panic).
+ //
+ // The obj passed to the decode methods below will equal this obj only
+ // in the case of decoding the top-level object. However, the passed
+ // obj may represent individual fields, elements of a slice, etc. that
+ // are effectively embedded within the reflect.Value below but with
+ // distinct types.
+ obj reflect.Value
+
+ // blockedBy is the number of dependencies this object has.
+ blockedBy int
+
+ // blocking is a list of the objects blocked by this one.
+ blocking []*objectState
+
+ // callbacks is a set of callbacks to execute on load.
+ callbacks []func()
+
+ // path is the decoding path to the object.
+ path recoverable
+}
+
+// complete indicates the object is complete.
+func (os *objectState) complete() bool {
+ return os.blockedBy == 0 && len(os.callbacks) == 0
+}
+
+// checkComplete checks for completion. If the object is complete, pending
+// callbacks will be executed and checkComplete will be called on downstream
+// objects (those depending on this one).
+func (os *objectState) checkComplete(stats *Stats) {
+ if os.blockedBy > 0 {
+ return
+ }
+ stats.Start(os.obj)
+
+ // Fire all callbacks.
+ for _, fn := range os.callbacks {
+ fn()
+ }
+ os.callbacks = nil
+
+ // Clear all blocked objects.
+ for _, other := range os.blocking {
+ other.blockedBy--
+ other.checkComplete(stats)
+ }
+ os.blocking = nil
+ stats.Done()
+}
+
+// waitFor queues a dependency on the given object.
+func (os *objectState) waitFor(other *objectState, callback func()) {
+ os.blockedBy++
+ other.blocking = append(other.blocking, os)
+ if callback != nil {
+ other.callbacks = append(other.callbacks, callback)
+ }
+}
+
+// findCycleFor returns when the given object is found in the blocking set.
+func (os *objectState) findCycleFor(target *objectState) []*objectState {
+ for _, other := range os.blocking {
+ if other == target {
+ return []*objectState{target}
+ } else if childList := other.findCycleFor(target); childList != nil {
+ return append(childList, other)
+ }
+ }
+ return nil
+}
+
+// findCycle finds a dependency cycle.
+func (os *objectState) findCycle() []*objectState {
+ return append(os.findCycleFor(os), os)
+}
+
+// decodeState is a graph of objects in the process of being decoded.
+//
+// The decode process involves loading the breadth-first graph generated by
+// encode. This graph is read in it's entirety, ensuring that all object
+// storage is complete.
+//
+// As the graph is being serialized, a set of completion callbacks are
+// executed. These completion callbacks should form a set of acyclic subgraphs
+// over the original one. After decoding is complete, the objects are scanned
+// to ensure that all callbacks are executed, otherwise the callback graph was
+// not acyclic.
+type decodeState struct {
+ // objectByID is the set of objects in progress.
+ objectsByID map[uint64]*objectState
+
+ // deferred are objects that have been read, by no interest has been
+ // registered yet. These will be decoded once interest in registered.
+ deferred map[uint64]*pb.Object
+
+ // outstanding is the number of outstanding objects.
+ outstanding uint32
+
+ // r is the input stream.
+ r io.Reader
+
+ // stats is the passed stats object.
+ stats *Stats
+
+ // recoverable is the panic recover facility.
+ recoverable
+}
+
+// lookup looks up an object in decodeState or returns nil if no such object
+// has been previously registered.
+func (ds *decodeState) lookup(id uint64) *objectState {
+ return ds.objectsByID[id]
+}
+
+// wait registers a dependency on an object.
+//
+// As a special case, we always allow _useable_ references back to the first
+// decoding object because it may have fields that are already decoded. We also
+// allow trivial self reference, since they can be handled internally.
+func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
+ switch id {
+ case 0:
+ // Nil pointer; nothing to wait for.
+ fallthrough
+ case waiter.id:
+ // Trivial self reference.
+ fallthrough
+ case 1:
+ // Root object; see above.
+ if callback != nil {
+ callback()
+ }
+ return
+ }
+
+ // No nil can be returned here.
+ waiter.waitFor(ds.lookup(id), callback)
+}
+
+// waitObject notes a blocking relationship.
+func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
+ if rv, ok := p.Value.(*pb.Object_RefValue); ok {
+ // Refs can encode pointers and maps.
+ ds.wait(os, rv.RefValue, callback)
+ } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok {
+ // See decodeObject; we need to wait for the array (if non-nil).
+ ds.wait(os, sv.SliceValue.RefValue, callback)
+ } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok {
+ // It's an interface (wait recurisvely).
+ ds.waitObject(os, iv.InterfaceValue.Value, callback)
+ } else if callback != nil {
+ // Nothing to wait for: execute the callback immediately.
+ callback()
+ }
+}
+
+// register registers a decode with a type.
+//
+// This type is only used to instantiate a new object if it has not been
+// registered previously.
+func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
+ os, ok := ds.objectsByID[id]
+ if ok {
+ return os
+ }
+
+ // Record in the object index.
+ if typ.Kind() == reflect.Map {
+ os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()}
+ } else {
+ os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()}
+ }
+ ds.objectsByID[id] = os
+
+ if o, ok := ds.deferred[id]; ok {
+ // There is a deferred object.
+ delete(ds.deferred, id) // Free memory.
+ ds.decodeObject(os, os.obj, o, "", nil)
+ } else {
+ // There is no deferred object.
+ ds.outstanding++
+ }
+
+ return os
+}
+
+// decodeStruct decodes a struct value.
+func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
+ // Set the fields.
+ m := Map{newInternalMap(nil, ds, os)}
+ defer internalMapPool.Put(m.internalMap)
+ for _, field := range s.Fields {
+ m.data = append(m.data, entry{
+ name: field.Name,
+ object: field.Value,
+ })
+ }
+
+ // Sort the fields for efficient searching.
+ //
+ // Technically, these should already appear in sorted order in the
+ // state ordering, so this cost is effectively a single scan to ensure
+ // that the order is correct.
+ if len(m.data) > 1 {
+ sort.Slice(m.data, func(i, j int) bool {
+ return m.data[i].name < m.data[j].name
+ })
+ }
+
+ // Invoke the load; this will recursively decode other objects.
+ fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
+ if ok {
+ // Invoke the loader.
+ fns.invokeLoad(obj.Addr(), m)
+ } else if obj.NumField() == 0 {
+ // Allow anonymous empty structs.
+ return
+ } else {
+ // Propagate an error.
+ panic(fmt.Errorf("unregistered type %s", obj.Type()))
+ }
+}
+
+// decodeMap decodes a map value.
+func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
+ if obj.IsNil() {
+ obj.Set(reflect.MakeMap(obj.Type()))
+ }
+ for i := 0; i < len(m.Keys); i++ {
+ // Decode the objects.
+ kv := reflect.New(obj.Type().Key()).Elem()
+ vv := reflect.New(obj.Type().Elem()).Elem()
+ ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i)
+ ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface())
+ ds.waitObject(os, m.Keys[i], nil)
+ ds.waitObject(os, m.Values[i], nil)
+
+ // Set in the map.
+ obj.SetMapIndex(kv, vv)
+ }
+}
+
+// decodeArray decodes an array value.
+func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
+ if len(a.Contents) != obj.Len() {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents)))
+ }
+ // Decode the contents into the array.
+ for i := 0; i < len(a.Contents); i++ {
+ ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i)
+ ds.waitObject(os, a.Contents[i], nil)
+ }
+}
+
+// decodeInterface decodes an interface value.
+func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
+ // Is this a nil value?
+ if i.Type == "" {
+ return // Just leave obj alone.
+ }
+
+ // Get the dispatchable type. This may not be used if the given
+ // reference has already been resolved, but if not we need to know the
+ // type to create.
+ t, ok := registeredTypes.lookupType(i.Type)
+ if !ok {
+ panic(fmt.Errorf("no valid type for %q", i.Type))
+ }
+
+ if obj.Kind() != reflect.Map {
+ // Set the obj to be the given typed value; this actually sets
+ // obj to be a non-zero value -- namely, it inserts type
+ // information. There's no need to do this for maps.
+ obj.Set(reflect.Zero(t))
+ }
+
+ // Decode the dereferenced element; there is no need to wait here, as
+ // the interface object shares the current object state.
+ ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type)
+}
+
+// decodeObject decodes a object value.
+func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
+ ds.push(false, format, param)
+ ds.stats.Add(obj)
+ ds.stats.Start(obj)
+
+ switch x := object.GetValue().(type) {
+ case *pb.Object_BoolValue:
+ obj.SetBool(x.BoolValue)
+ case *pb.Object_StringValue:
+ obj.SetString(string(x.StringValue))
+ case *pb.Object_Int64Value:
+ obj.SetInt(x.Int64Value)
+ if obj.Int() != x.Int64Value {
+ panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
+ }
+ case *pb.Object_Uint64Value:
+ obj.SetUint(x.Uint64Value)
+ if obj.Uint() != x.Uint64Value {
+ panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
+ }
+ case *pb.Object_DoubleValue:
+ obj.SetFloat(x.DoubleValue)
+ if obj.Float() != x.DoubleValue {
+ panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
+ }
+ case *pb.Object_RefValue:
+ // Resolve the pointer itself, even though the object may not
+ // be decoded yet. You need to use wait() in order to ensure
+ // that is the case. See wait above, and Map.Barrier.
+ if id := x.RefValue; id != 0 {
+ // Decoding the interface should have imparted type
+ // information, so from this point it's safe to resolve
+ // and use this dynamic information for actually
+ // creating the object in register.
+ //
+ // (For non-interfaces this is a no-op).
+ dyntyp := reflect.TypeOf(obj.Interface())
+ if dyntyp.Kind() == reflect.Map {
+ // Remove the map object count here to avoid
+ // double counting, as this object will be
+ // counted again when it gets processed later.
+ // We do not add a reference count as the
+ // reference is artificial.
+ ds.stats.Remove(obj)
+ obj.Set(ds.register(id, dyntyp).obj)
+ } else if dyntyp.Kind() == reflect.Ptr {
+ ds.push(true /* dereference */, "", nil)
+ obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+ ds.pop()
+ } else {
+ obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+ }
+ } else {
+ // We leave obj alone here. That's because if obj
+ // represents an interface, it may have been embued
+ // with type information in decodeInterface, and we
+ // don't want to destroy that information.
+ }
+ case *pb.Object_SliceValue:
+ // It's okay to slice the array here, since the contents will
+ // still be provided later on. These semantics are a bit
+ // strange but they are handled in the Map.Barrier properly.
+ //
+ // The special semantics of zero ref apply here too.
+ if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
+ v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
+ obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
+ }
+ case *pb.Object_ArrayValue:
+ ds.decodeArray(os, obj, x.ArrayValue)
+ case *pb.Object_StructValue:
+ ds.decodeStruct(os, obj, x.StructValue)
+ case *pb.Object_MapValue:
+ ds.decodeMap(os, obj, x.MapValue)
+ case *pb.Object_InterfaceValue:
+ ds.decodeInterface(os, obj, x.InterfaceValue)
+ case *pb.Object_ByteArrayValue:
+ copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
+ case *pb.Object_Uint16ArrayValue:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := x.Uint16ArrayValue.Values
+ t := obj.Slice(0, obj.Len()).Interface().([]uint16)
+ if len(t) != len(s) {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ }
+ for i := range s {
+ t[i] = uint16(s[i])
+ }
+ case *pb.Object_Uint32ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
+ case *pb.Object_Uint64ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
+ case *pb.Object_UintptrArrayValue:
+ copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
+ case *pb.Object_Int8ArrayValue:
+ copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
+ case *pb.Object_Int16ArrayValue:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := x.Int16ArrayValue.Values
+ t := obj.Slice(0, obj.Len()).Interface().([]int16)
+ if len(t) != len(s) {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ }
+ for i := range s {
+ t[i] = int16(s[i])
+ }
+ case *pb.Object_Int32ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
+ case *pb.Object_Int64ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
+ case *pb.Object_BoolArrayValue:
+ copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
+ case *pb.Object_Float64ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
+ case *pb.Object_Float32ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
+ default:
+ // Shoud not happen, not propagated as an error.
+ panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
+ }
+
+ ds.stats.Done()
+ ds.pop()
+}
+
+func copyArray(dest reflect.Value, src reflect.Value) {
+ if dest.Len() != src.Len() {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len()))
+ }
+ reflect.Copy(dest, castSlice(src, dest.Type().Elem()))
+}
+
+// Deserialize deserializes the object state.
+//
+// This function may panic and should be run in safely().
+func (ds *decodeState) Deserialize(obj reflect.Value) {
+ ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
+ ds.outstanding = 1 // The root object.
+
+ // Decode all objects in the stream.
+ //
+ // See above, we never process objects while we have no outstanding
+ // interests (other than the very first object).
+ for id := uint64(1); ds.outstanding > 0; id++ {
+ os := ds.lookup(id)
+ ds.stats.Start(os.obj)
+
+ o, err := ds.readObject()
+ if err != nil {
+ panic(err)
+ }
+
+ if os != nil {
+ // Decode the object.
+ ds.from = &os.path
+ ds.decodeObject(os, os.obj, o, "", nil)
+ ds.outstanding--
+ } else {
+ // If an object hasn't had interest registered
+ // previously, we deferred decoding until interest is
+ // registered.
+ ds.deferred[id] = o
+ }
+
+ ds.stats.Done()
+ }
+
+ // Check the zero-length header at the end.
+ length, object, err := ReadHeader(ds.r)
+ if err != nil {
+ panic(err)
+ }
+ if length != 0 {
+ panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
+ }
+ if object {
+ panic("expected non-object terminal")
+ }
+
+ // Check if we have any deferred objects.
+ if count := len(ds.deferred); count > 0 {
+ // Shoud not happen, not propagated as an error.
+ panic(fmt.Sprintf("still have %d deferred objects", count))
+ }
+
+ // Scan and fire all callbacks.
+ for _, os := range ds.objectsByID {
+ os.checkComplete(ds.stats)
+ }
+
+ // Check if we have any remaining dependency cycles.
+ for _, os := range ds.objectsByID {
+ if !os.complete() {
+ // This must be the result of a dependency cycle.
+ cycle := os.findCycle()
+ var buf bytes.Buffer
+ buf.WriteString("dependency cycle: {")
+ for i, cycleOS := range cycle {
+ if i > 0 {
+ buf.WriteString(" => ")
+ }
+ buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type()))
+ }
+ buf.WriteString("}")
+ // Panic as an error; propagate to the caller.
+ panic(errors.New(string(buf.Bytes())))
+ }
+ }
+}
+
+type byteReader struct {
+ io.Reader
+}
+
+// ReadByte implements io.ByteReader.
+func (br byteReader) ReadByte() (byte, error) {
+ var b [1]byte
+ n, err := br.Reader.Read(b[:])
+ if n > 0 {
+ return b[0], nil
+ } else if err != nil {
+ return 0, err
+ } else {
+ return 0, io.ErrUnexpectedEOF
+ }
+}
+
+// ReadHeader reads an object header.
+//
+// Each object written to the statefile is prefixed with a header. See
+// WriteHeader for more information; these functions are exported to allow
+// non-state writes to the file to play nice with debugging tools.
+func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
+ // Read the header.
+ length, err = binary.ReadUvarint(byteReader{r})
+ if err != nil {
+ return
+ }
+
+ // Decode whether the object is valid.
+ object = length&0x1 != 0
+ length = length >> 1
+ return
+}
+
+// readObject reads an object from the stream.
+func (ds *decodeState) readObject() (*pb.Object, error) {
+ // Read the header.
+ length, object, err := ReadHeader(ds.r)
+ if err != nil {
+ return nil, err
+ }
+ if !object {
+ return nil, fmt.Errorf("invalid object header")
+ }
+
+ // Read the object.
+ buf := make([]byte, length)
+ for done := 0; done < len(buf); {
+ n, err := ds.r.Read(buf[done:])
+ done += n
+ if n == 0 && err != nil {
+ return nil, err
+ }
+ }
+
+ // Unmarshal.
+ obj := new(pb.Object)
+ if err := proto.Unmarshal(buf, obj); err != nil {
+ return nil, err
+ }
+
+ return obj, nil
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
new file mode 100644
index 000000000..b0714170b
--- /dev/null
+++ b/pkg/state/encode.go
@@ -0,0 +1,466 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "container/list"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "reflect"
+ "sort"
+
+ "github.com/golang/protobuf/proto"
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// queuedObject is an object queued for encoding.
+type queuedObject struct {
+ id uint64
+ obj reflect.Value
+ path recoverable
+}
+
+// encodeState is state used for encoding.
+//
+// The encoding process is a breadth-first traversal of the object graph. The
+// inherent races and dependencies are much simpler than the decode case.
+type encodeState struct {
+ // lastID is the last object ID.
+ //
+ // See idsByObject for context. Because of the special zero encoding
+ // used for reference values, the first ID must be 1.
+ lastID uint64
+
+ // idsByObject is a set of objects, indexed via:
+ //
+ // reflect.ValueOf(x).UnsafeAddr
+ //
+ // This provides IDs for objects.
+ idsByObject map[uintptr]uint64
+
+ // values stores values that span the addresses.
+ //
+ // addrSet is a a generated type which efficiently stores ranges of
+ // addresses. When encoding pointers, these ranges are filled in and
+ // used to check for overlapping or conflicting pointers. This would
+ // indicate a pointer to an field, or a non-type safe value, neither of
+ // which are currently decodable.
+ //
+ // See the usage of values below for more context.
+ values addrSet
+
+ // w is the output stream.
+ w io.Writer
+
+ // pending is the list of objects to be serialized.
+ //
+ // This is a set of queuedObjects.
+ pending list.List
+
+ // done is the a list of finished objects.
+ //
+ // This is kept to prevent garbage collection and address reuse.
+ done list.List
+
+ // stats is the passed stats object.
+ stats *Stats
+
+ // recoverable is the panic recover facility.
+ recoverable
+}
+
+// register looks up an ID, registering if necessary.
+//
+// If the object was not previosly registered, it is enqueued to be serialized.
+// See the documentation for idsByObject for more information.
+func (es *encodeState) register(obj reflect.Value) uint64 {
+ // It is not legal to call register for any non-pointer objects (see
+ // below), so we panic with a recoverable error if this is a mismatch.
+ if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map {
+ panic(fmt.Errorf("non-pointer %#v registered", obj.Interface()))
+ }
+
+ addr := obj.Pointer()
+ if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 {
+ // For zero-sized objects, we always provide a unique ID.
+ // That's because the runtime internally multiplexes pointers
+ // to the same address. We can't be certain what the intent is
+ // with pointers to zero-sized objects, so we just give them
+ // all unique identities.
+ } else if id, ok := es.idsByObject[addr]; ok {
+ // Already registered.
+ return id
+ }
+
+ // Ensure that the first ID given out is one. See note on lastID. The
+ // ID zero is used to indicate nil values.
+ es.lastID++
+ id := es.lastID
+ es.idsByObject[addr] = id
+ if obj.Kind() == reflect.Ptr {
+ // Dereference and treat as a pointer.
+ es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()})
+
+ // Register this object at all addresses.
+ typ := obj.Elem().Type()
+ if size := typ.Size(); size > 0 {
+ r := addrRange{addr, addr + size}
+ if !es.values.IsEmptyRange(r) {
+ old := es.values.LowerBoundSegment(addr).Value().Interface().(recoverable)
+ panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object path] %s", obj.Interface(), old.path()))
+ }
+ es.values.Add(r, reflect.ValueOf(es.recoverable.copy()))
+ }
+ } else {
+ // Push back the map itself; when maps are encoded from the
+ // top-level, forceMap will be equal to true.
+ es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()})
+ }
+
+ return id
+}
+
+// encodeMap encodes a map.
+func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map {
+ var (
+ keys []*pb.Object
+ values []*pb.Object
+ )
+ for i, k := range obj.MapKeys() {
+ v := obj.MapIndex(k)
+ kp := es.encodeObject(k, false, ".(key %d)", i)
+ vp := es.encodeObject(v, false, "[%#v]", k.Interface())
+ keys = append(keys, kp)
+ values = append(values, vp)
+ }
+ return &pb.Map{Keys: keys, Values: values}
+}
+
+// encodeStruct encodes a composite object.
+func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct {
+ // Invoke the save.
+ m := Map{newInternalMap(es, nil, nil)}
+ defer internalMapPool.Put(m.internalMap)
+ if !obj.CanAddr() {
+ // Force it to a * type of the above; this involves a copy.
+ localObj := reflect.New(obj.Type())
+ localObj.Elem().Set(obj)
+ obj = localObj.Elem()
+ }
+ fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
+ if ok {
+ // Invoke the provided saver.
+ fns.invokeSave(obj.Addr(), m)
+ } else if obj.NumField() == 0 {
+ // Allow unregistered anonymous, empty structs.
+ return &pb.Struct{}
+ } else {
+ // Propagate an error.
+ panic(fmt.Errorf("unregistered type %T", obj.Interface()))
+ }
+
+ // Sort the underlying slice, and check for duplicates. This is done
+ // once instead of on each add, because performing this sort once is
+ // far more efficient.
+ if len(m.data) > 1 {
+ sort.Slice(m.data, func(i, j int) bool {
+ return m.data[i].name < m.data[j].name
+ })
+ for i := range m.data {
+ if i > 0 && m.data[i-1].name == m.data[i].name {
+ panic(fmt.Errorf("duplicate name %s", m.data[i].name))
+ }
+ }
+ }
+
+ // Encode the resulting fields.
+ fields := make([]*pb.Field, 0, len(m.data))
+ for _, e := range m.data {
+ fields = append(fields, &pb.Field{
+ Name: e.name,
+ Value: e.object,
+ })
+ }
+
+ // Return the encoded object.
+ return &pb.Struct{Fields: fields}
+}
+
+// encodeArray encodes an array.
+func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array {
+ var (
+ contents []*pb.Object
+ )
+ for i := 0; i < obj.Len(); i++ {
+ entry := es.encodeObject(obj.Index(i), false, "[%d]", i)
+ contents = append(contents, entry)
+ }
+ return &pb.Array{Contents: contents}
+}
+
+// encodeInterface encodes an interface.
+//
+// Precondition: the value is not nil.
+func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface {
+ // Check for the nil interface.
+ obj = reflect.ValueOf(obj.Interface())
+ if !obj.IsValid() {
+ return &pb.Interface{
+ Type: "", // left alone in decode.
+ Value: &pb.Object{Value: &pb.Object_RefValue{0}},
+ }
+ }
+ // We have an interface value here. How do we save that? We
+ // resolve the underlying type and save it as a dispatchable.
+ typName, ok := registeredTypes.lookupName(obj.Type())
+ if !ok {
+ panic(fmt.Errorf("type %s is not registered", obj.Type()))
+ }
+
+ // Encode the object again.
+ return &pb.Interface{
+ Type: typName,
+ Value: es.encodeObject(obj, false, ".(%s)", typName),
+ }
+}
+
+// encodeObject encodes an object.
+//
+// If mapAsValue is true, then a map will be encoded directly.
+func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) {
+ es.push(false, format, param)
+ es.stats.Add(obj)
+ es.stats.Start(obj)
+
+ switch obj.Kind() {
+ case reflect.Bool:
+ object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}}
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}}
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}}
+ case reflect.Float32, reflect.Float64:
+ object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}}
+ case reflect.Array:
+ switch obj.Type().Elem().Kind() {
+ case reflect.Uint8:
+ object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}}
+ case reflect.Uint16:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := pbSlice(obj).Interface().([]uint16)
+ t := make([]uint32, len(s))
+ for i := range s {
+ t[i] = uint32(s[i])
+ }
+ object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}}
+ case reflect.Uint32:
+ object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}}
+ case reflect.Uint64:
+ object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}}
+ case reflect.Uintptr:
+ object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}}
+ case reflect.Int8:
+ object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}}
+ case reflect.Int16:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := pbSlice(obj).Interface().([]int16)
+ t := make([]int32, len(s))
+ for i := range s {
+ t[i] = int32(s[i])
+ }
+ object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}}
+ case reflect.Int32:
+ object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}}
+ case reflect.Int64:
+ object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}}
+ case reflect.Bool:
+ object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}}
+ case reflect.Float32:
+ object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}}
+ case reflect.Float64:
+ object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}}
+ default:
+ object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}}
+ }
+ case reflect.Slice:
+ if obj.IsNil() || obj.Cap() == 0 {
+ // Handled specially in decode; store as nil value.
+ object = &pb.Object{Value: &pb.Object_RefValue{0}}
+ } else {
+ // Serialize a slice as the array plus length and capacity.
+ object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{
+ Capacity: uint32(obj.Cap()),
+ Length: uint32(obj.Len()),
+ RefValue: es.register(arrayFromSlice(obj)),
+ }}}
+ }
+ case reflect.String:
+ object = &pb.Object{Value: &pb.Object_StringValue{[]byte(obj.String())}}
+ case reflect.Ptr:
+ if obj.IsNil() {
+ // Handled specially in decode; store as a nil value.
+ object = &pb.Object{Value: &pb.Object_RefValue{0}}
+ } else {
+ es.push(true /* dereference */, "", nil)
+ object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
+ es.pop()
+ }
+ case reflect.Interface:
+ // We don't check for IsNil here, as we want to encode type
+ // information. The case of the empty interface (no type, no
+ // value) is handled by encodeInteface.
+ object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}}
+ case reflect.Struct:
+ object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}}
+ case reflect.Map:
+ if obj.IsNil() {
+ // Handled specially in decode; store as a nil value.
+ object = &pb.Object{Value: &pb.Object_RefValue{0}}
+ } else if mapAsValue {
+ // Encode the map directly.
+ object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}}
+ } else {
+ // Encode a reference to the map.
+ //
+ // Remove the map object count here to avoid double
+ // counting, as this object will be counted again when
+ // it gets processed later. We do not add a reference
+ // count as the reference is artificial.
+ es.stats.Remove(obj)
+ object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
+ }
+ default:
+ panic(fmt.Errorf("unknown primitive %#v", obj.Interface()))
+ }
+
+ es.stats.Done()
+ es.pop()
+ return
+}
+
+// Serialize serializes the object state.
+//
+// This function may panic and should be run in safely().
+func (es *encodeState) Serialize(obj reflect.Value) {
+ es.register(obj.Addr())
+
+ // Pop off the list until we're done.
+ for es.pending.Len() > 0 {
+ e := es.pending.Front()
+
+ // Extract the queued object.
+ qo := e.Value.(queuedObject)
+ es.stats.Start(qo.obj)
+
+ es.pending.Remove(e)
+
+ es.from = &qo.path
+ o := es.encodeObject(qo.obj, true, "", nil)
+
+ // Emit to our output stream.
+ if err := es.writeObject(qo.id, o); err != nil {
+ panic(err)
+ }
+
+ // Mark as done.
+ es.done.PushBack(e)
+ es.stats.Done()
+ }
+
+ // Write a zero-length terminal at the end; this is a sanity check
+ // applied at decode time as well (see decode.go).
+ if err := WriteHeader(es.w, 0, false); err != nil {
+ panic(err)
+ }
+}
+
+// WriteHeader writes a header.
+//
+// Each object written to the statefile should be prefixed with a header. In
+// order to generate statefiles that play nicely with debugging tools, raw
+// writes should be prefixed with a header with object set to false and the
+// appropriate length. This will allow tools to skip these regions.
+func WriteHeader(w io.Writer, length uint64, object bool) error {
+ // The lowest-order bit encodes whether this is a valid object. This is
+ // a purely internal convention, but allows the object flag to be
+ // returned from ReadHeader.
+ length = length << 1
+ if object {
+ length |= 0x1
+ }
+
+ // Write a header.
+ var hdr [32]byte
+ encodedLen := binary.PutUvarint(hdr[:], length)
+ for done := 0; done < encodedLen; {
+ n, err := w.Write(hdr[done:encodedLen])
+ done += n
+ if n == 0 && err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// writeObject writes an object to the stream.
+func (es *encodeState) writeObject(id uint64, obj *pb.Object) error {
+ // Marshal the proto.
+ buf, err := proto.Marshal(obj)
+ if err != nil {
+ return err
+ }
+
+ // Write the object header.
+ if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil {
+ return err
+ }
+
+ // Write the object.
+ for done := 0; done < len(buf); {
+ n, err := es.w.Write(buf[done:])
+ done += n
+ if n == 0 && err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// addrSetFunctions is used by addrSet.
+type addrSetFunctions struct{}
+
+func (addrSetFunctions) MinKey() uintptr {
+ return 0
+}
+
+func (addrSetFunctions) MaxKey() uintptr {
+ return ^uintptr(0)
+}
+
+func (addrSetFunctions) ClearValue(val *reflect.Value) {
+}
+
+func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) {
+ return val1, val1 == val2
+}
+
+func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) {
+ return val, val
+}
diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go
new file mode 100644
index 000000000..457e6dbb7
--- /dev/null
+++ b/pkg/state/encode_unsafe.go
@@ -0,0 +1,81 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// arrayFromSlice constructs a new pointer to the slice data.
+//
+// It would be similar to the following:
+//
+// x := make([]Foo, l, c)
+// a := ([l]Foo*)(unsafe.Pointer(x[0]))
+//
+func arrayFromSlice(obj reflect.Value) reflect.Value {
+ return reflect.NewAt(
+ reflect.ArrayOf(obj.Cap(), obj.Type().Elem()),
+ unsafe.Pointer(obj.Pointer()))
+}
+
+// pbSlice returns a protobuf-supported slice of the array and erase the
+// original element type (which could be a defined type or non-supported type).
+func pbSlice(obj reflect.Value) reflect.Value {
+ var typ reflect.Type
+ switch obj.Type().Elem().Kind() {
+ case reflect.Uint8:
+ typ = reflect.TypeOf(byte(0))
+ case reflect.Uint16:
+ typ = reflect.TypeOf(uint16(0))
+ case reflect.Uint32:
+ typ = reflect.TypeOf(uint32(0))
+ case reflect.Uint64:
+ typ = reflect.TypeOf(uint64(0))
+ case reflect.Uintptr:
+ typ = reflect.TypeOf(uint64(0))
+ case reflect.Int8:
+ typ = reflect.TypeOf(byte(0))
+ case reflect.Int16:
+ typ = reflect.TypeOf(int16(0))
+ case reflect.Int32:
+ typ = reflect.TypeOf(int32(0))
+ case reflect.Int64:
+ typ = reflect.TypeOf(int64(0))
+ case reflect.Bool:
+ typ = reflect.TypeOf(bool(false))
+ case reflect.Float32:
+ typ = reflect.TypeOf(float32(0))
+ case reflect.Float64:
+ typ = reflect.TypeOf(float64(0))
+ default:
+ panic("slice element is not of basic value type")
+ }
+ return reflect.NewAt(
+ reflect.ArrayOf(obj.Len(), typ),
+ unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()),
+ ).Elem().Slice(0, obj.Len())
+}
+
+func castSlice(obj reflect.Value, elemTyp reflect.Type) reflect.Value {
+ if obj.Type().Elem().Size() != elemTyp.Size() {
+ panic("cannot cast slice into other element type of different size")
+ }
+ return reflect.NewAt(
+ reflect.ArrayOf(obj.Len(), elemTyp),
+ unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()),
+ ).Elem()
+}
diff --git a/pkg/state/map.go b/pkg/state/map.go
new file mode 100644
index 000000000..1fb9b47b8
--- /dev/null
+++ b/pkg/state/map.go
@@ -0,0 +1,221 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "fmt"
+ "reflect"
+ "sort"
+ "sync"
+
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// entry is a single map entry.
+type entry struct {
+ name string
+ object *pb.Object
+}
+
+// internalMap is the internal Map state.
+//
+// These are recycled via a pool to avoid churn.
+type internalMap struct {
+ // es is encodeState.
+ es *encodeState
+
+ // ds is decodeState.
+ ds *decodeState
+
+ // os is current object being decoded.
+ //
+ // This will always be nil during encode.
+ os *objectState
+
+ // data stores the encoded values.
+ data []entry
+}
+
+var internalMapPool = sync.Pool{
+ New: func() interface{} {
+ return new(internalMap)
+ },
+}
+
+// newInternalMap returns a cached map.
+func newInternalMap(es *encodeState, ds *decodeState, os *objectState) *internalMap {
+ m := internalMapPool.Get().(*internalMap)
+ m.es = es
+ m.ds = ds
+ m.os = os
+ if m.data != nil {
+ m.data = m.data[:0]
+ }
+ return m
+}
+
+// Map is a generic state container.
+//
+// This is the object passed to Save and Load in order to store their state.
+//
+// Detailed documentation is available in individual methods.
+type Map struct {
+ *internalMap
+}
+
+// Save adds the given object to the map.
+//
+// You should pass always pointers to the object you are saving. For example:
+//
+// type X struct {
+// A int
+// B *int
+// }
+//
+// func (x *X) Save(m Map) {
+// m.Save("A", &x.A)
+// m.Save("B", &x.B)
+// }
+//
+// func (x *X) Load(m Map) {
+// m.Load("A", &x.A)
+// m.Load("B", &x.B)
+// }
+func (m Map) Save(name string, objPtr interface{}) {
+ m.save(name, reflect.ValueOf(objPtr).Elem(), ".%s")
+}
+
+// SaveValue adds the given object value to the map.
+//
+// This should be used for values where pointers are not available, or casts
+// are required during Save/Load.
+//
+// For example, if we want to cast external package type P.Foo to int64:
+//
+// type X struct {
+// A P.Foo
+// }
+//
+// func (x *X) Save(m Map) {
+// m.SaveValue("A", int64(x.A))
+// }
+//
+// func (x *X) Load(m Map) {
+// m.LoadValue("A", new(int64), func(x interface{}) {
+// x.A = P.Foo(x.(int64))
+// })
+// }
+func (m Map) SaveValue(name string, obj interface{}) {
+ m.save(name, reflect.ValueOf(obj), ".(value %s)")
+}
+
+// save is helper for the above. It takes the name of value to save the field
+// to, the field object (obj), and a format string that specifies how the
+// field's saving logic is dispatched from the struct (normal, value, etc.). The
+// format string should expect one string parameter, which is the name of the
+// field.
+func (m Map) save(name string, obj reflect.Value, format string) {
+ if m.es == nil {
+ // Not currently encoding.
+ m.Failf("no encode state for %q", name)
+ }
+
+ // Attempt the encode.
+ //
+ // These are sorted at the end, after all objects are added and will be
+ // sorted and checked for duplicates (see encodeStruct).
+ m.data = append(m.data, entry{
+ name: name,
+ object: m.es.encodeObject(obj, false, format, name),
+ })
+}
+
+// Load loads the given object from the map.
+//
+// See Save for an example.
+func (m Map) Load(name string, objPtr interface{}) {
+ m.load(name, reflect.ValueOf(objPtr), false, nil, ".%s")
+}
+
+// LoadWait loads the given objects from the map, and marks it as requiring all
+// AfterLoad executions to complete prior to running this object's AfterLoad.
+//
+// See Save for an example.
+func (m Map) LoadWait(name string, objPtr interface{}) {
+ m.load(name, reflect.ValueOf(objPtr), true, nil, ".(wait %s)")
+}
+
+// LoadValue loads the given object value from the map.
+//
+// See SaveValue for an example.
+func (m Map) LoadValue(name string, objPtr interface{}, fn func(interface{})) {
+ o := reflect.ValueOf(objPtr)
+ m.load(name, o, true, func() { fn(o.Elem().Interface()) }, ".(value %s)")
+}
+
+// load is helper for the above. It takes the name of value to load the field
+// from, the target field pointer (objPtr), whether load completion of the
+// struct depends on the field's load completion (wait), the load completion
+// logic (fn), and a format string that specifies how the field's loading logic
+// is dispatched from the struct (normal, wait, value, etc.). The format string
+// should expect one string parameter, which is the name of the field.
+func (m Map) load(name string, objPtr reflect.Value, wait bool, fn func(), format string) {
+ if m.ds == nil {
+ // Not currently decoding.
+ m.Failf("no decode state for %q", name)
+ }
+
+ // Find the object.
+ //
+ // These are sorted up front (and should appear in the state file
+ // sorted as well), so we can do a binary search here to ensure that
+ // large structs don't behave badly.
+ i := sort.Search(len(m.data), func(i int) bool {
+ return m.data[i].name >= name
+ })
+ if i >= len(m.data) || m.data[i].name != name {
+ // There is no data for this name?
+ m.Failf("no data found for %q", name)
+ }
+
+ // Perform the decode.
+ m.ds.decodeObject(m.os, objPtr.Elem(), m.data[i].object, format, name)
+ if wait {
+ // Mark this individual object a blocker.
+ m.ds.waitObject(m.os, m.data[i].object, fn)
+ }
+}
+
+// Failf fails the save or restore with the provided message. Processing will
+// stop after calling Failf, as the state package uses a panic & recover
+// mechanism for state errors. You should defer any cleanup required.
+func (m Map) Failf(format string, args ...interface{}) {
+ panic(fmt.Errorf(format, args...))
+}
+
+// AfterLoad schedules a function execution when all objects have been allocated
+// and their automated loading and customized load logic have been executed. fn
+// will not be executed until all of current object's dependencies' AfterLoad()
+// logic, if exist, have been executed.
+func (m Map) AfterLoad(fn func()) {
+ if m.ds == nil {
+ // Not currently decoding.
+ m.Failf("not decoding")
+ }
+
+ // Queue the local callback; this will execute when all of the above
+ // data dependencies have been cleared.
+ m.os.callbacks = append(m.os.callbacks, fn)
+}
diff --git a/pkg/state/object_go_proto/object.pb.go b/pkg/state/object_go_proto/object.pb.go
new file mode 100755
index 000000000..dc5127149
--- /dev/null
+++ b/pkg/state/object_go_proto/object.pb.go
@@ -0,0 +1,1195 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: pkg/state/object.proto
+
+package gvisor_state_statefile
+
+import (
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type Slice struct {
+ Length uint32 `protobuf:"varint,1,opt,name=length,proto3" json:"length,omitempty"`
+ Capacity uint32 `protobuf:"varint,2,opt,name=capacity,proto3" json:"capacity,omitempty"`
+ RefValue uint64 `protobuf:"varint,3,opt,name=ref_value,json=refValue,proto3" json:"ref_value,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Slice) Reset() { *m = Slice{} }
+func (m *Slice) String() string { return proto.CompactTextString(m) }
+func (*Slice) ProtoMessage() {}
+func (*Slice) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{0}
+}
+
+func (m *Slice) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Slice.Unmarshal(m, b)
+}
+func (m *Slice) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Slice.Marshal(b, m, deterministic)
+}
+func (m *Slice) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Slice.Merge(m, src)
+}
+func (m *Slice) XXX_Size() int {
+ return xxx_messageInfo_Slice.Size(m)
+}
+func (m *Slice) XXX_DiscardUnknown() {
+ xxx_messageInfo_Slice.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Slice proto.InternalMessageInfo
+
+func (m *Slice) GetLength() uint32 {
+ if m != nil {
+ return m.Length
+ }
+ return 0
+}
+
+func (m *Slice) GetCapacity() uint32 {
+ if m != nil {
+ return m.Capacity
+ }
+ return 0
+}
+
+func (m *Slice) GetRefValue() uint64 {
+ if m != nil {
+ return m.RefValue
+ }
+ return 0
+}
+
+type Array struct {
+ Contents []*Object `protobuf:"bytes,1,rep,name=contents,proto3" json:"contents,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Array) Reset() { *m = Array{} }
+func (m *Array) String() string { return proto.CompactTextString(m) }
+func (*Array) ProtoMessage() {}
+func (*Array) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{1}
+}
+
+func (m *Array) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Array.Unmarshal(m, b)
+}
+func (m *Array) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Array.Marshal(b, m, deterministic)
+}
+func (m *Array) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Array.Merge(m, src)
+}
+func (m *Array) XXX_Size() int {
+ return xxx_messageInfo_Array.Size(m)
+}
+func (m *Array) XXX_DiscardUnknown() {
+ xxx_messageInfo_Array.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Array proto.InternalMessageInfo
+
+func (m *Array) GetContents() []*Object {
+ if m != nil {
+ return m.Contents
+ }
+ return nil
+}
+
+type Map struct {
+ Keys []*Object `protobuf:"bytes,1,rep,name=keys,proto3" json:"keys,omitempty"`
+ Values []*Object `protobuf:"bytes,2,rep,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Map) Reset() { *m = Map{} }
+func (m *Map) String() string { return proto.CompactTextString(m) }
+func (*Map) ProtoMessage() {}
+func (*Map) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{2}
+}
+
+func (m *Map) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Map.Unmarshal(m, b)
+}
+func (m *Map) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Map.Marshal(b, m, deterministic)
+}
+func (m *Map) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Map.Merge(m, src)
+}
+func (m *Map) XXX_Size() int {
+ return xxx_messageInfo_Map.Size(m)
+}
+func (m *Map) XXX_DiscardUnknown() {
+ xxx_messageInfo_Map.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Map proto.InternalMessageInfo
+
+func (m *Map) GetKeys() []*Object {
+ if m != nil {
+ return m.Keys
+ }
+ return nil
+}
+
+func (m *Map) GetValues() []*Object {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Interface struct {
+ Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"`
+ Value *Object `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Interface) Reset() { *m = Interface{} }
+func (m *Interface) String() string { return proto.CompactTextString(m) }
+func (*Interface) ProtoMessage() {}
+func (*Interface) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{3}
+}
+
+func (m *Interface) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Interface.Unmarshal(m, b)
+}
+func (m *Interface) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Interface.Marshal(b, m, deterministic)
+}
+func (m *Interface) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Interface.Merge(m, src)
+}
+func (m *Interface) XXX_Size() int {
+ return xxx_messageInfo_Interface.Size(m)
+}
+func (m *Interface) XXX_DiscardUnknown() {
+ xxx_messageInfo_Interface.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Interface proto.InternalMessageInfo
+
+func (m *Interface) GetType() string {
+ if m != nil {
+ return m.Type
+ }
+ return ""
+}
+
+func (m *Interface) GetValue() *Object {
+ if m != nil {
+ return m.Value
+ }
+ return nil
+}
+
+type Struct struct {
+ Fields []*Field `protobuf:"bytes,1,rep,name=fields,proto3" json:"fields,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Struct) Reset() { *m = Struct{} }
+func (m *Struct) String() string { return proto.CompactTextString(m) }
+func (*Struct) ProtoMessage() {}
+func (*Struct) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{4}
+}
+
+func (m *Struct) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Struct.Unmarshal(m, b)
+}
+func (m *Struct) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Struct.Marshal(b, m, deterministic)
+}
+func (m *Struct) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Struct.Merge(m, src)
+}
+func (m *Struct) XXX_Size() int {
+ return xxx_messageInfo_Struct.Size(m)
+}
+func (m *Struct) XXX_DiscardUnknown() {
+ xxx_messageInfo_Struct.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Struct proto.InternalMessageInfo
+
+func (m *Struct) GetFields() []*Field {
+ if m != nil {
+ return m.Fields
+ }
+ return nil
+}
+
+type Field struct {
+ Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
+ Value *Object `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Field) Reset() { *m = Field{} }
+func (m *Field) String() string { return proto.CompactTextString(m) }
+func (*Field) ProtoMessage() {}
+func (*Field) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{5}
+}
+
+func (m *Field) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Field.Unmarshal(m, b)
+}
+func (m *Field) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Field.Marshal(b, m, deterministic)
+}
+func (m *Field) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Field.Merge(m, src)
+}
+func (m *Field) XXX_Size() int {
+ return xxx_messageInfo_Field.Size(m)
+}
+func (m *Field) XXX_DiscardUnknown() {
+ xxx_messageInfo_Field.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Field proto.InternalMessageInfo
+
+func (m *Field) GetName() string {
+ if m != nil {
+ return m.Name
+ }
+ return ""
+}
+
+func (m *Field) GetValue() *Object {
+ if m != nil {
+ return m.Value
+ }
+ return nil
+}
+
+type Uint16S struct {
+ Values []uint32 `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Uint16S) Reset() { *m = Uint16S{} }
+func (m *Uint16S) String() string { return proto.CompactTextString(m) }
+func (*Uint16S) ProtoMessage() {}
+func (*Uint16S) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{6}
+}
+
+func (m *Uint16S) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Uint16S.Unmarshal(m, b)
+}
+func (m *Uint16S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Uint16S.Marshal(b, m, deterministic)
+}
+func (m *Uint16S) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Uint16S.Merge(m, src)
+}
+func (m *Uint16S) XXX_Size() int {
+ return xxx_messageInfo_Uint16S.Size(m)
+}
+func (m *Uint16S) XXX_DiscardUnknown() {
+ xxx_messageInfo_Uint16S.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Uint16S proto.InternalMessageInfo
+
+func (m *Uint16S) GetValues() []uint32 {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Uint32S struct {
+ Values []uint32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Uint32S) Reset() { *m = Uint32S{} }
+func (m *Uint32S) String() string { return proto.CompactTextString(m) }
+func (*Uint32S) ProtoMessage() {}
+func (*Uint32S) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{7}
+}
+
+func (m *Uint32S) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Uint32S.Unmarshal(m, b)
+}
+func (m *Uint32S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Uint32S.Marshal(b, m, deterministic)
+}
+func (m *Uint32S) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Uint32S.Merge(m, src)
+}
+func (m *Uint32S) XXX_Size() int {
+ return xxx_messageInfo_Uint32S.Size(m)
+}
+func (m *Uint32S) XXX_DiscardUnknown() {
+ xxx_messageInfo_Uint32S.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Uint32S proto.InternalMessageInfo
+
+func (m *Uint32S) GetValues() []uint32 {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Uint64S struct {
+ Values []uint64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Uint64S) Reset() { *m = Uint64S{} }
+func (m *Uint64S) String() string { return proto.CompactTextString(m) }
+func (*Uint64S) ProtoMessage() {}
+func (*Uint64S) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{8}
+}
+
+func (m *Uint64S) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Uint64S.Unmarshal(m, b)
+}
+func (m *Uint64S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Uint64S.Marshal(b, m, deterministic)
+}
+func (m *Uint64S) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Uint64S.Merge(m, src)
+}
+func (m *Uint64S) XXX_Size() int {
+ return xxx_messageInfo_Uint64S.Size(m)
+}
+func (m *Uint64S) XXX_DiscardUnknown() {
+ xxx_messageInfo_Uint64S.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Uint64S proto.InternalMessageInfo
+
+func (m *Uint64S) GetValues() []uint64 {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Uintptrs struct {
+ Values []uint64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Uintptrs) Reset() { *m = Uintptrs{} }
+func (m *Uintptrs) String() string { return proto.CompactTextString(m) }
+func (*Uintptrs) ProtoMessage() {}
+func (*Uintptrs) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{9}
+}
+
+func (m *Uintptrs) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Uintptrs.Unmarshal(m, b)
+}
+func (m *Uintptrs) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Uintptrs.Marshal(b, m, deterministic)
+}
+func (m *Uintptrs) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Uintptrs.Merge(m, src)
+}
+func (m *Uintptrs) XXX_Size() int {
+ return xxx_messageInfo_Uintptrs.Size(m)
+}
+func (m *Uintptrs) XXX_DiscardUnknown() {
+ xxx_messageInfo_Uintptrs.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Uintptrs proto.InternalMessageInfo
+
+func (m *Uintptrs) GetValues() []uint64 {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Int8S struct {
+ Values []byte `protobuf:"bytes,1,opt,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Int8S) Reset() { *m = Int8S{} }
+func (m *Int8S) String() string { return proto.CompactTextString(m) }
+func (*Int8S) ProtoMessage() {}
+func (*Int8S) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{10}
+}
+
+func (m *Int8S) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Int8S.Unmarshal(m, b)
+}
+func (m *Int8S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Int8S.Marshal(b, m, deterministic)
+}
+func (m *Int8S) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Int8S.Merge(m, src)
+}
+func (m *Int8S) XXX_Size() int {
+ return xxx_messageInfo_Int8S.Size(m)
+}
+func (m *Int8S) XXX_DiscardUnknown() {
+ xxx_messageInfo_Int8S.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Int8S proto.InternalMessageInfo
+
+func (m *Int8S) GetValues() []byte {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Int16S struct {
+ Values []int32 `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Int16S) Reset() { *m = Int16S{} }
+func (m *Int16S) String() string { return proto.CompactTextString(m) }
+func (*Int16S) ProtoMessage() {}
+func (*Int16S) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{11}
+}
+
+func (m *Int16S) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Int16S.Unmarshal(m, b)
+}
+func (m *Int16S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Int16S.Marshal(b, m, deterministic)
+}
+func (m *Int16S) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Int16S.Merge(m, src)
+}
+func (m *Int16S) XXX_Size() int {
+ return xxx_messageInfo_Int16S.Size(m)
+}
+func (m *Int16S) XXX_DiscardUnknown() {
+ xxx_messageInfo_Int16S.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Int16S proto.InternalMessageInfo
+
+func (m *Int16S) GetValues() []int32 {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Int32S struct {
+ Values []int32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Int32S) Reset() { *m = Int32S{} }
+func (m *Int32S) String() string { return proto.CompactTextString(m) }
+func (*Int32S) ProtoMessage() {}
+func (*Int32S) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{12}
+}
+
+func (m *Int32S) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Int32S.Unmarshal(m, b)
+}
+func (m *Int32S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Int32S.Marshal(b, m, deterministic)
+}
+func (m *Int32S) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Int32S.Merge(m, src)
+}
+func (m *Int32S) XXX_Size() int {
+ return xxx_messageInfo_Int32S.Size(m)
+}
+func (m *Int32S) XXX_DiscardUnknown() {
+ xxx_messageInfo_Int32S.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Int32S proto.InternalMessageInfo
+
+func (m *Int32S) GetValues() []int32 {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Int64S struct {
+ Values []int64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Int64S) Reset() { *m = Int64S{} }
+func (m *Int64S) String() string { return proto.CompactTextString(m) }
+func (*Int64S) ProtoMessage() {}
+func (*Int64S) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{13}
+}
+
+func (m *Int64S) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Int64S.Unmarshal(m, b)
+}
+func (m *Int64S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Int64S.Marshal(b, m, deterministic)
+}
+func (m *Int64S) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Int64S.Merge(m, src)
+}
+func (m *Int64S) XXX_Size() int {
+ return xxx_messageInfo_Int64S.Size(m)
+}
+func (m *Int64S) XXX_DiscardUnknown() {
+ xxx_messageInfo_Int64S.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Int64S proto.InternalMessageInfo
+
+func (m *Int64S) GetValues() []int64 {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Bools struct {
+ Values []bool `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Bools) Reset() { *m = Bools{} }
+func (m *Bools) String() string { return proto.CompactTextString(m) }
+func (*Bools) ProtoMessage() {}
+func (*Bools) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{14}
+}
+
+func (m *Bools) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Bools.Unmarshal(m, b)
+}
+func (m *Bools) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Bools.Marshal(b, m, deterministic)
+}
+func (m *Bools) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Bools.Merge(m, src)
+}
+func (m *Bools) XXX_Size() int {
+ return xxx_messageInfo_Bools.Size(m)
+}
+func (m *Bools) XXX_DiscardUnknown() {
+ xxx_messageInfo_Bools.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Bools proto.InternalMessageInfo
+
+func (m *Bools) GetValues() []bool {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Float64S struct {
+ Values []float64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Float64S) Reset() { *m = Float64S{} }
+func (m *Float64S) String() string { return proto.CompactTextString(m) }
+func (*Float64S) ProtoMessage() {}
+func (*Float64S) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{15}
+}
+
+func (m *Float64S) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Float64S.Unmarshal(m, b)
+}
+func (m *Float64S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Float64S.Marshal(b, m, deterministic)
+}
+func (m *Float64S) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Float64S.Merge(m, src)
+}
+func (m *Float64S) XXX_Size() int {
+ return xxx_messageInfo_Float64S.Size(m)
+}
+func (m *Float64S) XXX_DiscardUnknown() {
+ xxx_messageInfo_Float64S.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Float64S proto.InternalMessageInfo
+
+func (m *Float64S) GetValues() []float64 {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Float32S struct {
+ Values []float32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Float32S) Reset() { *m = Float32S{} }
+func (m *Float32S) String() string { return proto.CompactTextString(m) }
+func (*Float32S) ProtoMessage() {}
+func (*Float32S) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{16}
+}
+
+func (m *Float32S) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Float32S.Unmarshal(m, b)
+}
+func (m *Float32S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Float32S.Marshal(b, m, deterministic)
+}
+func (m *Float32S) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Float32S.Merge(m, src)
+}
+func (m *Float32S) XXX_Size() int {
+ return xxx_messageInfo_Float32S.Size(m)
+}
+func (m *Float32S) XXX_DiscardUnknown() {
+ xxx_messageInfo_Float32S.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Float32S proto.InternalMessageInfo
+
+func (m *Float32S) GetValues() []float32 {
+ if m != nil {
+ return m.Values
+ }
+ return nil
+}
+
+type Object struct {
+ // Types that are valid to be assigned to Value:
+ // *Object_BoolValue
+ // *Object_StringValue
+ // *Object_Int64Value
+ // *Object_Uint64Value
+ // *Object_DoubleValue
+ // *Object_RefValue
+ // *Object_SliceValue
+ // *Object_ArrayValue
+ // *Object_InterfaceValue
+ // *Object_StructValue
+ // *Object_MapValue
+ // *Object_ByteArrayValue
+ // *Object_Uint16ArrayValue
+ // *Object_Uint32ArrayValue
+ // *Object_Uint64ArrayValue
+ // *Object_UintptrArrayValue
+ // *Object_Int8ArrayValue
+ // *Object_Int16ArrayValue
+ // *Object_Int32ArrayValue
+ // *Object_Int64ArrayValue
+ // *Object_BoolArrayValue
+ // *Object_Float64ArrayValue
+ // *Object_Float32ArrayValue
+ Value isObject_Value `protobuf_oneof:"value"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Object) Reset() { *m = Object{} }
+func (m *Object) String() string { return proto.CompactTextString(m) }
+func (*Object) ProtoMessage() {}
+func (*Object) Descriptor() ([]byte, []int) {
+ return fileDescriptor_3dee2c1912d4d62d, []int{17}
+}
+
+func (m *Object) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Object.Unmarshal(m, b)
+}
+func (m *Object) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Object.Marshal(b, m, deterministic)
+}
+func (m *Object) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Object.Merge(m, src)
+}
+func (m *Object) XXX_Size() int {
+ return xxx_messageInfo_Object.Size(m)
+}
+func (m *Object) XXX_DiscardUnknown() {
+ xxx_messageInfo_Object.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Object proto.InternalMessageInfo
+
+type isObject_Value interface {
+ isObject_Value()
+}
+
+type Object_BoolValue struct {
+ BoolValue bool `protobuf:"varint,1,opt,name=bool_value,json=boolValue,proto3,oneof"`
+}
+
+type Object_StringValue struct {
+ StringValue []byte `protobuf:"bytes,2,opt,name=string_value,json=stringValue,proto3,oneof"`
+}
+
+type Object_Int64Value struct {
+ Int64Value int64 `protobuf:"varint,3,opt,name=int64_value,json=int64Value,proto3,oneof"`
+}
+
+type Object_Uint64Value struct {
+ Uint64Value uint64 `protobuf:"varint,4,opt,name=uint64_value,json=uint64Value,proto3,oneof"`
+}
+
+type Object_DoubleValue struct {
+ DoubleValue float64 `protobuf:"fixed64,5,opt,name=double_value,json=doubleValue,proto3,oneof"`
+}
+
+type Object_RefValue struct {
+ RefValue uint64 `protobuf:"varint,6,opt,name=ref_value,json=refValue,proto3,oneof"`
+}
+
+type Object_SliceValue struct {
+ SliceValue *Slice `protobuf:"bytes,7,opt,name=slice_value,json=sliceValue,proto3,oneof"`
+}
+
+type Object_ArrayValue struct {
+ ArrayValue *Array `protobuf:"bytes,8,opt,name=array_value,json=arrayValue,proto3,oneof"`
+}
+
+type Object_InterfaceValue struct {
+ InterfaceValue *Interface `protobuf:"bytes,9,opt,name=interface_value,json=interfaceValue,proto3,oneof"`
+}
+
+type Object_StructValue struct {
+ StructValue *Struct `protobuf:"bytes,10,opt,name=struct_value,json=structValue,proto3,oneof"`
+}
+
+type Object_MapValue struct {
+ MapValue *Map `protobuf:"bytes,11,opt,name=map_value,json=mapValue,proto3,oneof"`
+}
+
+type Object_ByteArrayValue struct {
+ ByteArrayValue []byte `protobuf:"bytes,12,opt,name=byte_array_value,json=byteArrayValue,proto3,oneof"`
+}
+
+type Object_Uint16ArrayValue struct {
+ Uint16ArrayValue *Uint16S `protobuf:"bytes,13,opt,name=uint16_array_value,json=uint16ArrayValue,proto3,oneof"`
+}
+
+type Object_Uint32ArrayValue struct {
+ Uint32ArrayValue *Uint32S `protobuf:"bytes,14,opt,name=uint32_array_value,json=uint32ArrayValue,proto3,oneof"`
+}
+
+type Object_Uint64ArrayValue struct {
+ Uint64ArrayValue *Uint64S `protobuf:"bytes,15,opt,name=uint64_array_value,json=uint64ArrayValue,proto3,oneof"`
+}
+
+type Object_UintptrArrayValue struct {
+ UintptrArrayValue *Uintptrs `protobuf:"bytes,16,opt,name=uintptr_array_value,json=uintptrArrayValue,proto3,oneof"`
+}
+
+type Object_Int8ArrayValue struct {
+ Int8ArrayValue *Int8S `protobuf:"bytes,17,opt,name=int8_array_value,json=int8ArrayValue,proto3,oneof"`
+}
+
+type Object_Int16ArrayValue struct {
+ Int16ArrayValue *Int16S `protobuf:"bytes,18,opt,name=int16_array_value,json=int16ArrayValue,proto3,oneof"`
+}
+
+type Object_Int32ArrayValue struct {
+ Int32ArrayValue *Int32S `protobuf:"bytes,19,opt,name=int32_array_value,json=int32ArrayValue,proto3,oneof"`
+}
+
+type Object_Int64ArrayValue struct {
+ Int64ArrayValue *Int64S `protobuf:"bytes,20,opt,name=int64_array_value,json=int64ArrayValue,proto3,oneof"`
+}
+
+type Object_BoolArrayValue struct {
+ BoolArrayValue *Bools `protobuf:"bytes,21,opt,name=bool_array_value,json=boolArrayValue,proto3,oneof"`
+}
+
+type Object_Float64ArrayValue struct {
+ Float64ArrayValue *Float64S `protobuf:"bytes,22,opt,name=float64_array_value,json=float64ArrayValue,proto3,oneof"`
+}
+
+type Object_Float32ArrayValue struct {
+ Float32ArrayValue *Float32S `protobuf:"bytes,23,opt,name=float32_array_value,json=float32ArrayValue,proto3,oneof"`
+}
+
+func (*Object_BoolValue) isObject_Value() {}
+
+func (*Object_StringValue) isObject_Value() {}
+
+func (*Object_Int64Value) isObject_Value() {}
+
+func (*Object_Uint64Value) isObject_Value() {}
+
+func (*Object_DoubleValue) isObject_Value() {}
+
+func (*Object_RefValue) isObject_Value() {}
+
+func (*Object_SliceValue) isObject_Value() {}
+
+func (*Object_ArrayValue) isObject_Value() {}
+
+func (*Object_InterfaceValue) isObject_Value() {}
+
+func (*Object_StructValue) isObject_Value() {}
+
+func (*Object_MapValue) isObject_Value() {}
+
+func (*Object_ByteArrayValue) isObject_Value() {}
+
+func (*Object_Uint16ArrayValue) isObject_Value() {}
+
+func (*Object_Uint32ArrayValue) isObject_Value() {}
+
+func (*Object_Uint64ArrayValue) isObject_Value() {}
+
+func (*Object_UintptrArrayValue) isObject_Value() {}
+
+func (*Object_Int8ArrayValue) isObject_Value() {}
+
+func (*Object_Int16ArrayValue) isObject_Value() {}
+
+func (*Object_Int32ArrayValue) isObject_Value() {}
+
+func (*Object_Int64ArrayValue) isObject_Value() {}
+
+func (*Object_BoolArrayValue) isObject_Value() {}
+
+func (*Object_Float64ArrayValue) isObject_Value() {}
+
+func (*Object_Float32ArrayValue) isObject_Value() {}
+
+func (m *Object) GetValue() isObject_Value {
+ if m != nil {
+ return m.Value
+ }
+ return nil
+}
+
+func (m *Object) GetBoolValue() bool {
+ if x, ok := m.GetValue().(*Object_BoolValue); ok {
+ return x.BoolValue
+ }
+ return false
+}
+
+func (m *Object) GetStringValue() []byte {
+ if x, ok := m.GetValue().(*Object_StringValue); ok {
+ return x.StringValue
+ }
+ return nil
+}
+
+func (m *Object) GetInt64Value() int64 {
+ if x, ok := m.GetValue().(*Object_Int64Value); ok {
+ return x.Int64Value
+ }
+ return 0
+}
+
+func (m *Object) GetUint64Value() uint64 {
+ if x, ok := m.GetValue().(*Object_Uint64Value); ok {
+ return x.Uint64Value
+ }
+ return 0
+}
+
+func (m *Object) GetDoubleValue() float64 {
+ if x, ok := m.GetValue().(*Object_DoubleValue); ok {
+ return x.DoubleValue
+ }
+ return 0
+}
+
+func (m *Object) GetRefValue() uint64 {
+ if x, ok := m.GetValue().(*Object_RefValue); ok {
+ return x.RefValue
+ }
+ return 0
+}
+
+func (m *Object) GetSliceValue() *Slice {
+ if x, ok := m.GetValue().(*Object_SliceValue); ok {
+ return x.SliceValue
+ }
+ return nil
+}
+
+func (m *Object) GetArrayValue() *Array {
+ if x, ok := m.GetValue().(*Object_ArrayValue); ok {
+ return x.ArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetInterfaceValue() *Interface {
+ if x, ok := m.GetValue().(*Object_InterfaceValue); ok {
+ return x.InterfaceValue
+ }
+ return nil
+}
+
+func (m *Object) GetStructValue() *Struct {
+ if x, ok := m.GetValue().(*Object_StructValue); ok {
+ return x.StructValue
+ }
+ return nil
+}
+
+func (m *Object) GetMapValue() *Map {
+ if x, ok := m.GetValue().(*Object_MapValue); ok {
+ return x.MapValue
+ }
+ return nil
+}
+
+func (m *Object) GetByteArrayValue() []byte {
+ if x, ok := m.GetValue().(*Object_ByteArrayValue); ok {
+ return x.ByteArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetUint16ArrayValue() *Uint16S {
+ if x, ok := m.GetValue().(*Object_Uint16ArrayValue); ok {
+ return x.Uint16ArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetUint32ArrayValue() *Uint32S {
+ if x, ok := m.GetValue().(*Object_Uint32ArrayValue); ok {
+ return x.Uint32ArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetUint64ArrayValue() *Uint64S {
+ if x, ok := m.GetValue().(*Object_Uint64ArrayValue); ok {
+ return x.Uint64ArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetUintptrArrayValue() *Uintptrs {
+ if x, ok := m.GetValue().(*Object_UintptrArrayValue); ok {
+ return x.UintptrArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetInt8ArrayValue() *Int8S {
+ if x, ok := m.GetValue().(*Object_Int8ArrayValue); ok {
+ return x.Int8ArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetInt16ArrayValue() *Int16S {
+ if x, ok := m.GetValue().(*Object_Int16ArrayValue); ok {
+ return x.Int16ArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetInt32ArrayValue() *Int32S {
+ if x, ok := m.GetValue().(*Object_Int32ArrayValue); ok {
+ return x.Int32ArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetInt64ArrayValue() *Int64S {
+ if x, ok := m.GetValue().(*Object_Int64ArrayValue); ok {
+ return x.Int64ArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetBoolArrayValue() *Bools {
+ if x, ok := m.GetValue().(*Object_BoolArrayValue); ok {
+ return x.BoolArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetFloat64ArrayValue() *Float64S {
+ if x, ok := m.GetValue().(*Object_Float64ArrayValue); ok {
+ return x.Float64ArrayValue
+ }
+ return nil
+}
+
+func (m *Object) GetFloat32ArrayValue() *Float32S {
+ if x, ok := m.GetValue().(*Object_Float32ArrayValue); ok {
+ return x.Float32ArrayValue
+ }
+ return nil
+}
+
+// XXX_OneofWrappers is for the internal use of the proto package.
+func (*Object) XXX_OneofWrappers() []interface{} {
+ return []interface{}{
+ (*Object_BoolValue)(nil),
+ (*Object_StringValue)(nil),
+ (*Object_Int64Value)(nil),
+ (*Object_Uint64Value)(nil),
+ (*Object_DoubleValue)(nil),
+ (*Object_RefValue)(nil),
+ (*Object_SliceValue)(nil),
+ (*Object_ArrayValue)(nil),
+ (*Object_InterfaceValue)(nil),
+ (*Object_StructValue)(nil),
+ (*Object_MapValue)(nil),
+ (*Object_ByteArrayValue)(nil),
+ (*Object_Uint16ArrayValue)(nil),
+ (*Object_Uint32ArrayValue)(nil),
+ (*Object_Uint64ArrayValue)(nil),
+ (*Object_UintptrArrayValue)(nil),
+ (*Object_Int8ArrayValue)(nil),
+ (*Object_Int16ArrayValue)(nil),
+ (*Object_Int32ArrayValue)(nil),
+ (*Object_Int64ArrayValue)(nil),
+ (*Object_BoolArrayValue)(nil),
+ (*Object_Float64ArrayValue)(nil),
+ (*Object_Float32ArrayValue)(nil),
+ }
+}
+
+func init() {
+ proto.RegisterType((*Slice)(nil), "gvisor.state.statefile.Slice")
+ proto.RegisterType((*Array)(nil), "gvisor.state.statefile.Array")
+ proto.RegisterType((*Map)(nil), "gvisor.state.statefile.Map")
+ proto.RegisterType((*Interface)(nil), "gvisor.state.statefile.Interface")
+ proto.RegisterType((*Struct)(nil), "gvisor.state.statefile.Struct")
+ proto.RegisterType((*Field)(nil), "gvisor.state.statefile.Field")
+ proto.RegisterType((*Uint16S)(nil), "gvisor.state.statefile.Uint16s")
+ proto.RegisterType((*Uint32S)(nil), "gvisor.state.statefile.Uint32s")
+ proto.RegisterType((*Uint64S)(nil), "gvisor.state.statefile.Uint64s")
+ proto.RegisterType((*Uintptrs)(nil), "gvisor.state.statefile.Uintptrs")
+ proto.RegisterType((*Int8S)(nil), "gvisor.state.statefile.Int8s")
+ proto.RegisterType((*Int16S)(nil), "gvisor.state.statefile.Int16s")
+ proto.RegisterType((*Int32S)(nil), "gvisor.state.statefile.Int32s")
+ proto.RegisterType((*Int64S)(nil), "gvisor.state.statefile.Int64s")
+ proto.RegisterType((*Bools)(nil), "gvisor.state.statefile.Bools")
+ proto.RegisterType((*Float64S)(nil), "gvisor.state.statefile.Float64s")
+ proto.RegisterType((*Float32S)(nil), "gvisor.state.statefile.Float32s")
+ proto.RegisterType((*Object)(nil), "gvisor.state.statefile.Object")
+}
+
+func init() { proto.RegisterFile("pkg/state/object.proto", fileDescriptor_3dee2c1912d4d62d) }
+
+var fileDescriptor_3dee2c1912d4d62d = []byte{
+ // 781 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x96, 0x6f, 0x4f, 0xda, 0x5e,
+ 0x14, 0xc7, 0xa9, 0x40, 0x29, 0x07, 0x14, 0xb8, 0xfe, 0x7e, 0x8c, 0xcc, 0x38, 0xb1, 0x7b, 0x42,
+ 0xf6, 0x00, 0x33, 0x60, 0xc4, 0xf8, 0x64, 0x53, 0x13, 0x03, 0xc9, 0x8c, 0x59, 0x8d, 0xcb, 0x9e,
+ 0x99, 0x52, 0x2f, 0xac, 0xb3, 0xb6, 0x5d, 0x7b, 0x6b, 0xc2, 0xcb, 0xdc, 0x3b, 0x5a, 0xee, 0x1f,
+ 0xae, 0xfd, 0x03, 0xc5, 0xec, 0x89, 0xa1, 0xb7, 0xdf, 0xf3, 0xe1, 0xdc, 0xf3, 0x3d, 0xe7, 0x08,
+ 0xb4, 0xfd, 0xc7, 0xc5, 0x49, 0x48, 0x4c, 0x82, 0x4f, 0xbc, 0xd9, 0x2f, 0x6c, 0x91, 0xbe, 0x1f,
+ 0x78, 0xc4, 0x43, 0xed, 0xc5, 0xb3, 0x1d, 0x7a, 0x41, 0x9f, 0xbd, 0xe2, 0x7f, 0xe7, 0xb6, 0x83,
+ 0xf5, 0x1f, 0x50, 0xbe, 0x75, 0x6c, 0x0b, 0xa3, 0x36, 0xa8, 0x0e, 0x76, 0x17, 0xe4, 0x67, 0x47,
+ 0xe9, 0x2a, 0xbd, 0x5d, 0x43, 0x3c, 0xa1, 0xb7, 0xa0, 0x59, 0xa6, 0x6f, 0x5a, 0x36, 0x59, 0x76,
+ 0x76, 0xd8, 0x1b, 0xf9, 0x8c, 0x0e, 0xa0, 0x1a, 0xe0, 0xf9, 0xfd, 0xb3, 0xe9, 0x44, 0xb8, 0x53,
+ 0xec, 0x2a, 0xbd, 0x92, 0xa1, 0x05, 0x78, 0xfe, 0x9d, 0x3e, 0xeb, 0x97, 0x50, 0x3e, 0x0f, 0x02,
+ 0x73, 0x89, 0xce, 0x40, 0xb3, 0x3c, 0x97, 0x60, 0x97, 0x84, 0x1d, 0xa5, 0x5b, 0xec, 0xd5, 0x06,
+ 0xef, 0xfa, 0xeb, 0xb3, 0xe9, 0xdf, 0xb0, 0x94, 0x0d, 0xa9, 0xd7, 0x7f, 0x43, 0xf1, 0xda, 0xf4,
+ 0xd1, 0x00, 0x4a, 0x8f, 0x78, 0xf9, 0xda, 0x70, 0xa6, 0x45, 0x63, 0x50, 0x59, 0x62, 0x61, 0x67,
+ 0xe7, 0x55, 0x51, 0x42, 0xad, 0xdf, 0x41, 0x75, 0xea, 0x12, 0x1c, 0xcc, 0x4d, 0x0b, 0x23, 0x04,
+ 0x25, 0xb2, 0xf4, 0x31, 0xab, 0x49, 0xd5, 0x60, 0x9f, 0xd1, 0x08, 0xca, 0xfc, 0xc6, 0xb4, 0x1c,
+ 0xdb, 0xb9, 0x5c, 0xac, 0x7f, 0x06, 0xf5, 0x96, 0x04, 0x91, 0x45, 0xd0, 0x27, 0x50, 0xe7, 0x36,
+ 0x76, 0x1e, 0x56, 0xd7, 0x39, 0xdc, 0x04, 0xb8, 0xa2, 0x2a, 0x43, 0x88, 0xf5, 0x6f, 0x50, 0x66,
+ 0x07, 0x34, 0x27, 0xd7, 0x7c, 0x92, 0x39, 0xd1, 0xcf, 0xff, 0x98, 0xd3, 0x31, 0x54, 0xee, 0x6c,
+ 0x97, 0x7c, 0x1c, 0x87, 0xd4, 0x7e, 0x51, 0x2d, 0x9a, 0xd4, 0xae, 0xac, 0x86, 0x90, 0x0c, 0x07,
+ 0x69, 0x49, 0x25, 0x2d, 0x19, 0x8f, 0xd2, 0x12, 0x55, 0x4a, 0x74, 0xd0, 0xa8, 0xc4, 0x27, 0xc1,
+ 0x66, 0xcd, 0x11, 0x94, 0xa7, 0x2e, 0x39, 0x4d, 0x0a, 0x94, 0x5e, 0x5d, 0x0a, 0xba, 0xa0, 0x4e,
+ 0xd7, 0x25, 0x5b, 0x4e, 0x29, 0xb2, 0xb9, 0x36, 0x52, 0x8a, 0x6c, 0xaa, 0xcd, 0x78, 0x1a, 0x17,
+ 0x9e, 0xe7, 0xa4, 0x05, 0x5a, 0xfc, 0x2e, 0x57, 0x8e, 0x67, 0xae, 0x81, 0x28, 0x19, 0x4d, 0x36,
+ 0x95, 0x1d, 0xa9, 0xf9, 0x53, 0x03, 0x95, 0xdb, 0x81, 0x8e, 0x00, 0x66, 0x9e, 0xe7, 0x88, 0x41,
+ 0xa2, 0xb7, 0xd6, 0x26, 0x05, 0xa3, 0x4a, 0xcf, 0xd8, 0x2c, 0xa1, 0xf7, 0x50, 0x0f, 0x49, 0x60,
+ 0xbb, 0x8b, 0xfb, 0x17, 0x97, 0xeb, 0x93, 0x82, 0x51, 0xe3, 0xa7, 0x5c, 0x74, 0x0c, 0x35, 0x66,
+ 0x43, 0x6c, 0x1e, 0x8b, 0x93, 0x82, 0x01, 0xec, 0x50, 0x72, 0xa2, 0xb8, 0xa6, 0x44, 0x67, 0x96,
+ 0x72, 0xa2, 0xa4, 0xe8, 0xc1, 0x8b, 0x66, 0x0e, 0x16, 0xa2, 0x72, 0x57, 0xe9, 0x29, 0x54, 0xc4,
+ 0x4f, 0xb9, 0xe8, 0x30, 0x3e, 0xfa, 0xaa, 0xc0, 0xc8, 0xe1, 0x47, 0x5f, 0xa0, 0x16, 0xd2, 0xb5,
+ 0x22, 0x04, 0x15, 0xd6, 0x95, 0x1b, 0x1b, 0x9d, 0x6d, 0x20, 0x9a, 0x2a, 0x8b, 0x91, 0x04, 0x93,
+ 0xae, 0x0f, 0x41, 0xd0, 0xf2, 0x09, 0x6c, 0xd3, 0x50, 0x02, 0x8b, 0xe1, 0x84, 0xaf, 0xd0, 0xb0,
+ 0x57, 0x83, 0x2c, 0x28, 0x55, 0x46, 0x39, 0xde, 0x44, 0x91, 0x73, 0x3f, 0x29, 0x18, 0x7b, 0x32,
+ 0x96, 0xd3, 0x2e, 0x99, 0x05, 0x91, 0x45, 0x04, 0x0a, 0xf2, 0x07, 0x8d, 0xcf, 0xba, 0xb0, 0x28,
+ 0xb2, 0x08, 0x87, 0x9c, 0x41, 0xf5, 0xc9, 0xf4, 0x05, 0xa1, 0xc6, 0x08, 0x07, 0x9b, 0x08, 0xd7,
+ 0xa6, 0x4f, 0x4b, 0xfa, 0x64, 0xfa, 0x3c, 0xf6, 0x03, 0x34, 0x67, 0x4b, 0x82, 0xef, 0xe3, 0x55,
+ 0xa9, 0x8b, 0x3e, 0xd8, 0xa3, 0x6f, 0xce, 0x5f, 0xae, 0x7e, 0x03, 0x28, 0x62, 0x83, 0x9d, 0x50,
+ 0xef, 0xb2, 0x2f, 0x3c, 0xda, 0xf4, 0x85, 0x62, 0x15, 0x4c, 0x0a, 0x46, 0x93, 0x07, 0x67, 0x81,
+ 0xc3, 0x41, 0x02, 0xb8, 0xb7, 0x1d, 0x38, 0x1c, 0x48, 0xe0, 0x70, 0x90, 0x05, 0x8e, 0x47, 0x09,
+ 0x60, 0x63, 0x3b, 0x70, 0x3c, 0x92, 0xc0, 0xf1, 0x28, 0x06, 0x34, 0x60, 0x3f, 0xe2, 0x2b, 0x26,
+ 0x41, 0x6c, 0x32, 0x62, 0x37, 0x8f, 0x48, 0xb7, 0xd2, 0xa4, 0x60, 0xb4, 0x44, 0x78, 0x8c, 0x39,
+ 0x85, 0xa6, 0xed, 0x92, 0xd3, 0x04, 0xb0, 0x95, 0xdf, 0x88, 0x6c, 0x85, 0x89, 0xf6, 0x39, 0x3d,
+ 0x8f, 0x37, 0x63, 0x2b, 0x6b, 0x08, 0xca, 0xef, 0xa1, 0xe9, 0xca, 0x8f, 0x46, 0xda, 0x0e, 0x4e,
+ 0x4b, 0xb9, 0xb1, 0xbf, 0x95, 0xc6, 0xcd, 0x68, 0xa4, 0xbd, 0xe0, 0xb4, 0x94, 0x15, 0xff, 0x6d,
+ 0xa5, 0x71, 0x27, 0x1a, 0x69, 0x23, 0xa6, 0xd0, 0x64, 0xcb, 0x2c, 0x0e, 0xfb, 0x3f, 0xbf, 0x68,
+ 0x6c, 0xe1, 0xb2, 0x36, 0xf6, 0x3c, 0x27, 0xe9, 0xe9, 0x9c, 0xaf, 0xda, 0x04, 0xad, 0x9d, 0xef,
+ 0xe9, 0x6a, 0x3b, 0x53, 0x4f, 0x45, 0xf8, 0x1a, 0x66, 0xaa, 0x78, 0x6f, 0x5e, 0xc1, 0xe4, 0xe5,
+ 0x6b, 0x89, 0xf0, 0x17, 0xe6, 0x45, 0x45, 0xfc, 0xf7, 0x9d, 0xa9, 0xec, 0xc7, 0xd6, 0xf0, 0x6f,
+ 0x00, 0x00, 0x00, 0xff, 0xff, 0x84, 0x69, 0xc9, 0x45, 0x86, 0x09, 0x00, 0x00,
+}
diff --git a/pkg/state/printer.go b/pkg/state/printer.go
new file mode 100644
index 000000000..5174c3ba3
--- /dev/null
+++ b/pkg/state/printer.go
@@ -0,0 +1,251 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "reflect"
+ "strings"
+
+ "github.com/golang/protobuf/proto"
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// format formats a single object, for pretty-printing. It also returns whether
+// the value is a non-zero value.
+func format(graph uint64, depth int, object *pb.Object, html bool) (string, bool) {
+ switch x := object.GetValue().(type) {
+ case *pb.Object_BoolValue:
+ return fmt.Sprintf("%t", x.BoolValue), x.BoolValue != false
+ case *pb.Object_StringValue:
+ return fmt.Sprintf("\"%s\"", string(x.StringValue)), len(x.StringValue) != 0
+ case *pb.Object_Int64Value:
+ return fmt.Sprintf("%d", x.Int64Value), x.Int64Value != 0
+ case *pb.Object_Uint64Value:
+ return fmt.Sprintf("%du", x.Uint64Value), x.Uint64Value != 0
+ case *pb.Object_DoubleValue:
+ return fmt.Sprintf("%f", x.DoubleValue), x.DoubleValue != 0.0
+ case *pb.Object_RefValue:
+ if x.RefValue == 0 {
+ return "nil", false
+ }
+ ref := fmt.Sprintf("g%dr%d", graph, x.RefValue)
+ if html {
+ ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
+ }
+ return ref, true
+ case *pb.Object_SliceValue:
+ if x.SliceValue.RefValue == 0 {
+ return "nil", false
+ }
+ ref := fmt.Sprintf("g%dr%d", graph, x.SliceValue.RefValue)
+ if html {
+ ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
+ }
+ return fmt.Sprintf("%s[:%d:%d]", ref, x.SliceValue.Length, x.SliceValue.Capacity), true
+ case *pb.Object_ArrayValue:
+ if len(x.ArrayValue.Contents) == 0 {
+ return "[]", false
+ }
+ items := make([]string, 0, len(x.ArrayValue.Contents)+2)
+ zeros := make([]string, 0) // used to eliminate zero entries.
+ items = append(items, "[")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.ArrayValue.Contents); i++ {
+ item, ok := format(graph, depth+1, x.ArrayValue.Contents[i], html)
+ if ok {
+ if len(zeros) > 0 {
+ items = append(items, zeros...)
+ zeros = nil
+ }
+ items = append(items, fmt.Sprintf("\t%s,", item))
+ } else {
+ zeros = append(zeros, fmt.Sprintf("\t%s,", item))
+ }
+ }
+ if len(zeros) > 0 {
+ items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros)))
+ }
+ items = append(items, "]")
+ return strings.Join(items, tabs), len(zeros) < len(x.ArrayValue.Contents)
+ case *pb.Object_StructValue:
+ if len(x.StructValue.Fields) == 0 {
+ return "struct{}", false
+ }
+ items := make([]string, 0, len(x.StructValue.Fields)+2)
+ items = append(items, "struct{")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ allZero := true
+ for _, field := range x.StructValue.Fields {
+ element, ok := format(graph, depth+1, field.Value, html)
+ allZero = allZero && !ok
+ items = append(items, fmt.Sprintf("\t%s: %s,", field.Name, element))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), !allZero
+ case *pb.Object_MapValue:
+ if len(x.MapValue.Keys) == 0 {
+ return "map{}", false
+ }
+ items := make([]string, 0, len(x.MapValue.Keys)+2)
+ items = append(items, "map{")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.MapValue.Keys); i++ {
+ key, _ := format(graph, depth+1, x.MapValue.Keys[i], html)
+ value, _ := format(graph, depth+1, x.MapValue.Values[i], html)
+ items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true
+ case *pb.Object_InterfaceValue:
+ if x.InterfaceValue.Type == "" {
+ return "interface(nil){}", false
+ }
+ element, _ := format(graph, depth+1, x.InterfaceValue.Value, html)
+ return fmt.Sprintf("interface(\"%s\"){%s}", x.InterfaceValue.Type, element), true
+ case *pb.Object_ByteArrayValue:
+ return printArray(reflect.ValueOf(x.ByteArrayValue))
+ case *pb.Object_Uint16ArrayValue:
+ return printArray(reflect.ValueOf(x.Uint16ArrayValue.Values))
+ case *pb.Object_Uint32ArrayValue:
+ return printArray(reflect.ValueOf(x.Uint32ArrayValue.Values))
+ case *pb.Object_Uint64ArrayValue:
+ return printArray(reflect.ValueOf(x.Uint64ArrayValue.Values))
+ case *pb.Object_UintptrArrayValue:
+ return printArray(castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
+ case *pb.Object_Int8ArrayValue:
+ return printArray(castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
+ case *pb.Object_Int16ArrayValue:
+ return printArray(reflect.ValueOf(x.Int16ArrayValue.Values))
+ case *pb.Object_Int32ArrayValue:
+ return printArray(reflect.ValueOf(x.Int32ArrayValue.Values))
+ case *pb.Object_Int64ArrayValue:
+ return printArray(reflect.ValueOf(x.Int64ArrayValue.Values))
+ case *pb.Object_BoolArrayValue:
+ return printArray(reflect.ValueOf(x.BoolArrayValue.Values))
+ case *pb.Object_Float64ArrayValue:
+ return printArray(reflect.ValueOf(x.Float64ArrayValue.Values))
+ case *pb.Object_Float32ArrayValue:
+ return printArray(reflect.ValueOf(x.Float32ArrayValue.Values))
+ }
+
+ // Should not happen, but tolerate.
+ return fmt.Sprintf("(unknown proto type: %T)", object.GetValue()), true
+}
+
+// PrettyPrint reads the state stream from r, and pretty prints to w.
+func PrettyPrint(w io.Writer, r io.Reader, html bool) error {
+ var (
+ // current graph ID.
+ graph uint64
+
+ // current object ID.
+ id uint64
+ )
+
+ if html {
+ fmt.Fprintf(w, "<pre>")
+ defer fmt.Fprintf(w, "</pre>")
+ }
+
+ for {
+ // Find the first object to begin generation.
+ length, object, err := ReadHeader(r)
+ if err == io.EOF {
+ // Nothing else to do.
+ break
+ } else if err != nil {
+ return err
+ }
+ if !object {
+ // Increment the graph number & reset the ID.
+ graph++
+ id = 0
+ if length > 0 {
+ fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
+ io.Copy(ioutil.Discard, &io.LimitedReader{
+ R: r,
+ N: int64(length),
+ })
+ }
+ continue
+ }
+
+ // Read & unmarshal the object.
+ buf := make([]byte, length)
+ for done := 0; done < len(buf); {
+ n, err := r.Read(buf[done:])
+ done += n
+ if n == 0 && err != nil {
+ return err
+ }
+ }
+ obj := new(pb.Object)
+ if err := proto.Unmarshal(buf, obj); err != nil {
+ return err
+ }
+
+ id++ // First object must be one.
+ str, _ := format(graph, 0, obj, html)
+ tag := fmt.Sprintf("g%dr%d", graph, id)
+ if html {
+ tag = fmt.Sprintf("<a name=%s>%s</a>", tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func printArray(s reflect.Value) (string, bool) {
+ zero := reflect.Zero(s.Type().Elem()).Interface()
+ z := "0"
+ switch s.Type().Elem().Kind() {
+ case reflect.Bool:
+ z = "false"
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ case reflect.Float32, reflect.Float64:
+ default:
+ return fmt.Sprintf("unexpected non-primitive type array: %#v", s.Interface()), true
+ }
+
+ zeros := 0
+ items := make([]string, 0, s.Len())
+ for i := 0; i <= s.Len(); i++ {
+ if i < s.Len() && reflect.DeepEqual(s.Index(i).Interface(), zero) {
+ zeros++
+ continue
+ }
+ if zeros > 0 {
+ if zeros <= 4 {
+ for ; zeros > 0; zeros-- {
+ items = append(items, z)
+ }
+ } else {
+ items = append(items, fmt.Sprintf("(%d %ss)", zeros, z))
+ zeros = 0
+ }
+ }
+ if i < s.Len() {
+ items = append(items, fmt.Sprintf("%v", s.Index(i).Interface()))
+ }
+ }
+ return "[" + strings.Join(items, ",") + "]", zeros < s.Len()
+}
diff --git a/pkg/state/state.go b/pkg/state/state.go
new file mode 100644
index 000000000..cf7df803a
--- /dev/null
+++ b/pkg/state/state.go
@@ -0,0 +1,359 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package state provides functionality related to saving and loading object
+// graphs. For most types, it provides a set of default saving / loading logic
+// that will be invoked automatically if custom logic is not defined.
+//
+// Kind Support
+// ---- -------
+// Bool default
+// Int default
+// Int8 default
+// Int16 default
+// Int32 default
+// Int64 default
+// Uint default
+// Uint8 default
+// Uint16 default
+// Uint32 default
+// Uint64 default
+// Float32 default
+// Float64 default
+// Complex64 custom
+// Complex128 custom
+// Array default
+// Chan custom
+// Func custom
+// Interface custom
+// Map default (*)
+// Ptr default
+// Slice default
+// String default
+// Struct custom
+// UnsafePointer custom
+//
+// (*) Maps are treated as value types by this package, even if they are
+// pointers internally. If you want to save two independent references
+// to the same map value, you must explicitly use a pointer to a map.
+package state
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+ "runtime"
+
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// ErrState is returned when an error is encountered during encode/decode.
+type ErrState struct {
+ // err is the underlying error.
+ err error
+
+ // path is the visit path from root to the current object.
+ path string
+
+ // trace is the stack trace.
+ trace string
+}
+
+// Error returns a sensible description of the state error.
+func (e *ErrState) Error() string {
+ return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace)
+}
+
+// UnwrapErrState returns the underlying error in ErrState.
+//
+// If err is not *ErrState, err is returned directly.
+func UnwrapErrState(err error) error {
+ if e, ok := err.(*ErrState); ok {
+ return e.err
+ }
+ return err
+}
+
+// Save saves the given object state.
+func Save(w io.Writer, rootPtr interface{}, stats *Stats) error {
+ // Create the encoding state.
+ es := &encodeState{
+ idsByObject: make(map[uintptr]uint64),
+ w: w,
+ stats: stats,
+ }
+
+ // Perform the encoding.
+ return es.safely(func() {
+ es.Serialize(reflect.ValueOf(rootPtr).Elem())
+ })
+}
+
+// Load loads a checkpoint.
+func Load(r io.Reader, rootPtr interface{}, stats *Stats) error {
+ // Create the decoding state.
+ ds := &decodeState{
+ objectsByID: make(map[uint64]*objectState),
+ deferred: make(map[uint64]*pb.Object),
+ r: r,
+ stats: stats,
+ }
+
+ // Attempt our decode.
+ return ds.safely(func() {
+ ds.Deserialize(reflect.ValueOf(rootPtr).Elem())
+ })
+}
+
+// Fns are the state dispatch functions.
+type Fns struct {
+ // Save is a function like Save(concreteType, Map).
+ Save interface{}
+
+ // Load is a function like Load(concreteType, Map).
+ Load interface{}
+}
+
+// Save executes the save function.
+func (fns *Fns) invokeSave(obj reflect.Value, m Map) {
+ reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+}
+
+// Load executes the load function.
+func (fns *Fns) invokeLoad(obj reflect.Value, m Map) {
+ reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+}
+
+// validateStateFn ensures types are correct.
+func validateStateFn(fn interface{}, typ reflect.Type) bool {
+ fnTyp := reflect.TypeOf(fn)
+ if fnTyp.Kind() != reflect.Func {
+ return false
+ }
+ if fnTyp.NumIn() != 2 {
+ return false
+ }
+ if fnTyp.NumOut() != 0 {
+ return false
+ }
+ if fnTyp.In(0) != typ {
+ return false
+ }
+ if fnTyp.In(1) != reflect.TypeOf(Map{}) {
+ return false
+ }
+ return true
+}
+
+// Validate validates all state functions.
+func (fns *Fns) Validate(typ reflect.Type) bool {
+ return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ)
+}
+
+type typeDatabase struct {
+ // nameToType is a forward lookup table.
+ nameToType map[string]reflect.Type
+
+ // typeToName is the reverse lookup table.
+ typeToName map[reflect.Type]string
+
+ // typeToFns is the function lookup table.
+ typeToFns map[reflect.Type]Fns
+}
+
+// registeredTypes is a database used for SaveInterface and LoadInterface.
+var registeredTypes = typeDatabase{
+ nameToType: make(map[string]reflect.Type),
+ typeToName: make(map[reflect.Type]string),
+ typeToFns: make(map[reflect.Type]Fns),
+}
+
+// register registers a type under the given name. This will generally be
+// called via init() methods, and therefore uses panic to propagate errors.
+func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) {
+ // We can't allow name collisions.
+ if ot, ok := t.nameToType[name]; ok {
+ panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name()))
+ }
+
+ // Or multiple registrations.
+ if on, ok := t.typeToName[typ]; ok {
+ panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on))
+ }
+
+ t.nameToType[name] = typ
+ t.typeToName[typ] = name
+ t.typeToFns[typ] = fns
+}
+
+// lookupType finds a type given a name.
+func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) {
+ typ, ok := t.nameToType[name]
+ return typ, ok
+}
+
+// lookupName finds a name given a type.
+func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) {
+ name, ok := t.typeToName[typ]
+ return name, ok
+}
+
+// lookupFns finds functions given a type.
+func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) {
+ fns, ok := t.typeToFns[typ]
+ return fns, ok
+}
+
+// Register must be called for any interface implementation types that
+// implements Loader.
+//
+// Register should be called either immediately after startup or via init()
+// methods. Double registration of either names or types will result in a panic.
+//
+// No synchronization is provided; this should only be called in init.
+//
+// Example usage:
+//
+// state.Register("Foo", (*Foo)(nil), state.Fns{
+// Save: (*Foo).Save,
+// Load: (*Foo).Load,
+// })
+//
+func Register(name string, instance interface{}, fns Fns) {
+ registeredTypes.register(name, reflect.TypeOf(instance), fns)
+}
+
+// IsZeroValue checks if the given value is the zero value.
+//
+// This function is used by the stateify tool.
+func IsZeroValue(val interface{}) bool {
+ if val == nil {
+ return true
+ }
+ return reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface())
+}
+
+// step captures one encoding / decoding step. On each step, there is up to one
+// choice made, which is captured by non-nil param. We intentionally do not
+// eagerly create the final path string, as that will only be needed upon panic.
+type step struct {
+ // dereference indicate if the current object is obtained by
+ // dereferencing a pointer.
+ dereference bool
+
+ // format is the formatting string that takes param below, if
+ // non-nil. For example, in array indexing case, we have "[%d]".
+ format string
+
+ // param stores the choice made at the current encoding / decoding step.
+ // For eaxmple, in array indexing case, param stores the index. When no
+ // choice is made, e.g. dereference, param should be nil.
+ param interface{}
+}
+
+// recoverable is the state encoding / decoding panic recovery facility. It is
+// also used to store encoding / decoding steps as well as the reference to the
+// original queued object from which the current object is dispatched. The
+// complete encoding / decoding path is synthesised from the steps in all queued
+// objects leading to the current object.
+type recoverable struct {
+ from *recoverable
+ steps []step
+}
+
+// push enters a new context level.
+func (sr *recoverable) push(dereference bool, format string, param interface{}) {
+ sr.steps = append(sr.steps, step{dereference, format, param})
+}
+
+// pop exits the current context level.
+func (sr *recoverable) pop() {
+ if len(sr.steps) <= 1 {
+ return
+ }
+ sr.steps = sr.steps[:len(sr.steps)-1]
+}
+
+// path returns the complete encoding / decoding path from root. This is only
+// called upon panic.
+func (sr *recoverable) path() string {
+ if sr.from == nil {
+ return "root"
+ }
+ p := sr.from.path()
+ for _, s := range sr.steps {
+ if s.dereference {
+ p = fmt.Sprintf("*(%s)", p)
+ }
+ if s.param == nil {
+ p += s.format
+ } else {
+ p += fmt.Sprintf(s.format, s.param)
+ }
+ }
+ return p
+}
+
+func (sr *recoverable) copy() recoverable {
+ return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)}
+}
+
+// safely executes the given function, catching a panic and unpacking as an error.
+//
+// The error flow through the state package uses panic and recover. There are
+// two important reasons for this:
+//
+// 1) Many of the reflection methods will already panic with invalid data or
+// violated assumptions. We would want to recover anyways here.
+//
+// 2) It allows us to eliminate boilerplate within Save() and Load() functions.
+// In nearly all cases, when the low-level serialization functions fail, you
+// will want the checkpoint to fail anyways. Plumbing errors through every
+// method doesn't add a lot of value. If there are specific error conditions
+// that you'd like to handle, you should add appropriate functionality to
+// objects themselves prior to calling Save() and Load().
+func (sr *recoverable) safely(fn func()) (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ es := new(ErrState)
+ if e, ok := r.(error); ok {
+ es.err = e
+ } else {
+ es.err = fmt.Errorf("%v", r)
+ }
+
+ es.path = sr.path()
+
+ // Make a stack. We don't know how big it will be ahead
+ // of time, but want to make sure we get the whole
+ // thing. So we just do a stupid brute force approach.
+ var stack []byte
+ for sz := 1024; ; sz *= 2 {
+ stack = make([]byte, sz)
+ n := runtime.Stack(stack, false)
+ if n < sz {
+ es.trace = string(stack[:n])
+ break
+ }
+ }
+
+ // Set the error.
+ err = es
+ }
+ }()
+
+ // Execute the function.
+ fn()
+ return nil
+}
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
new file mode 100644
index 000000000..ad4e3b43e
--- /dev/null
+++ b/pkg/state/statefile/statefile.go
@@ -0,0 +1,232 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package statefile defines the state file data stream.
+//
+// This package currently does not include any details regarding the state
+// encoding itself, only details regarding state metadata and data layout.
+//
+// The file format is defined as follows.
+//
+// /------------------------------------------------------\
+// | header (8-bytes) |
+// +------------------------------------------------------+
+// | metadata length (8-bytes) |
+// +------------------------------------------------------+
+// | metadata |
+// +------------------------------------------------------+
+// | data |
+// \------------------------------------------------------/
+//
+// First, it includes a 8-byte magic header which is the following
+// sequence of bytes [0x67, 0x56, 0x69, 0x73, 0x6f, 0x72, 0x53, 0x46]
+//
+// This header is followed by an 8-byte length N (big endian), and an
+// ASCII-encoded JSON map that is exactly N bytes long.
+//
+// This map includes only strings for keys and strings for values. Keys in the
+// map that begin with "_" are for internal use only. They may be read, but may
+// not be provided by the user. In the future, this metadata may contain some
+// information relating to the state encoding itself.
+//
+// After the map, the remainder of the file is the state data.
+package statefile
+
+import (
+ "bytes"
+ "compress/flate"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/json"
+ "fmt"
+ "hash"
+ "io"
+ "strings"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/compressio"
+)
+
+// keySize is the AES-256 key length.
+const keySize = 32
+
+// compressionChunkSize is the chunk size for compression.
+const compressionChunkSize = 1024 * 1024
+
+// maxMetadataSize is the size limit of metadata section.
+const maxMetadataSize = 16 * 1024 * 1024
+
+// magicHeader is the byte sequence beginning each file.
+var magicHeader = []byte("\x67\x56\x69\x73\x6f\x72\x53\x46")
+
+// ErrBadMagic is returned if the header does not match.
+var ErrBadMagic = fmt.Errorf("bad magic header")
+
+// ErrMetadataMissing is returned if the state file is missing mandatory metadata.
+var ErrMetadataMissing = fmt.Errorf("missing metadata")
+
+// ErrInvalidMetadataLength is returned if the metadata length is too large.
+var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size is %d", maxMetadataSize)
+
+// ErrMetadataInvalid is returned if passed metadata is invalid.
+var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _")
+
+// NewWriter returns a state data writer for a statefile.
+//
+// Note that the returned WriteCloser must be closed.
+func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteCloser, error) {
+ if metadata == nil {
+ metadata = make(map[string]string)
+ }
+ for k := range metadata {
+ if strings.HasPrefix(k, "_") {
+ return nil, ErrMetadataInvalid
+ }
+ }
+
+ // Create our HMAC function.
+ h := hmac.New(sha256.New, key)
+ mw := io.MultiWriter(w, h)
+
+ // First, write the header.
+ if _, err := mw.Write(magicHeader); err != nil {
+ return nil, err
+ }
+
+ // Generate a timestamp, for convenience only.
+ metadata["_timestamp"] = time.Now().UTC().String()
+ defer delete(metadata, "_timestamp")
+
+ // Write the metadata.
+ b, err := json.Marshal(metadata)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(b) > maxMetadataSize {
+ return nil, ErrInvalidMetadataLength
+ }
+
+ // Metadata length.
+ if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil {
+ return nil, err
+ }
+ // Metadata bytes; io.MultiWriter will return a short write error if
+ // any of the writers returns < n.
+ if _, err := mw.Write(b); err != nil {
+ return nil, err
+ }
+ // Write the current hash.
+ cur := h.Sum(nil)
+ for done := 0; done < len(cur); {
+ n, err := mw.Write(cur[done:])
+ done += n
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Wrap in compression. We always use "best speed" mode here. When using
+ // "best compression" mode, there is usually only a little gain in file
+ // size reduction, which translate to even smaller gain in restore
+ // latency reduction, while inccuring much more CPU usage at save time.
+ return compressio.NewWriter(w, key, compressionChunkSize, flate.BestSpeed)
+}
+
+// MetadataUnsafe reads out the metadata from a state file without verifying any
+// HMAC. This function shouldn't be called for untrusted input files.
+func MetadataUnsafe(r io.Reader) (map[string]string, error) {
+ return metadata(r, nil)
+}
+
+// metadata validates the magic header and reads out the metadata from a state
+// data stream.
+func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
+ if h != nil {
+ r = io.TeeReader(r, h)
+ }
+
+ // Read and validate magic header.
+ b := make([]byte, len(magicHeader))
+ if _, err := r.Read(b); err != nil {
+ return nil, err
+ }
+ if !bytes.Equal(b, magicHeader) {
+ return nil, ErrBadMagic
+ }
+
+ // Read and validate metadata.
+ b, err := func() (b []byte, err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ b = nil
+ err = fmt.Errorf("%v", r)
+ }
+ }()
+
+ metadataLen, err := binary.ReadUint64(r, binary.BigEndian)
+ if err != nil {
+ return nil, err
+ }
+ if metadataLen > maxMetadataSize {
+ return nil, ErrInvalidMetadataLength
+ }
+ b = make([]byte, int(metadataLen))
+ if _, err := io.ReadFull(r, b); err != nil {
+ return nil, err
+ }
+ return b, nil
+ }()
+ if err != nil {
+ return nil, err
+ }
+
+ if h != nil {
+ // Check the hash prior to decoding.
+ cur := h.Sum(nil)
+ buf := make([]byte, len(cur))
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return nil, err
+ }
+ if !hmac.Equal(cur, buf) {
+ return nil, compressio.ErrHashMismatch
+ }
+ }
+
+ // Decode the metadata.
+ metadata := make(map[string]string)
+ if err := json.Unmarshal(b, &metadata); err != nil {
+ return nil, err
+ }
+
+ return metadata, nil
+}
+
+// NewReader returns a reader for a statefile.
+func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
+ // Read the metadata with the hash.
+ h := hmac.New(sha256.New, key)
+ metadata, err := metadata(r, h)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Wrap in compression.
+ rc, err := compressio.NewReader(r, key)
+ if err != nil {
+ return nil, nil, err
+ }
+ return rc, metadata, nil
+}
diff --git a/pkg/state/statefile/statefile_state_autogen.go b/pkg/state/statefile/statefile_state_autogen.go
new file mode 100755
index 000000000..438c485ca
--- /dev/null
+++ b/pkg/state/statefile/statefile_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package statefile
+
diff --git a/pkg/state/stats.go b/pkg/state/stats.go
new file mode 100644
index 000000000..eb51cda47
--- /dev/null
+++ b/pkg/state/stats.go
@@ -0,0 +1,152 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "sort"
+ "time"
+)
+
+type statEntry struct {
+ count uint
+ total time.Duration
+}
+
+// Stats tracks encode / decode timing.
+//
+// This currently provides a meaningful String function and no other way to
+// extract stats about individual types.
+//
+// All exported receivers accept nil.
+type Stats struct {
+ // byType contains a breakdown of time spent by type.
+ byType map[reflect.Type]*statEntry
+
+ // stack contains objects in progress.
+ stack []reflect.Type
+
+ // last is the last start time.
+ last time.Time
+}
+
+// sample adds the samples to the given object.
+func (s *Stats) sample(typ reflect.Type) {
+ now := time.Now()
+ s.byType[typ].total += now.Sub(s.last)
+ s.last = now
+}
+
+// Add adds a sample count.
+func (s *Stats) Add(obj reflect.Value) {
+ if s == nil {
+ return
+ }
+ if s.byType == nil {
+ s.byType = make(map[reflect.Type]*statEntry)
+ }
+ typ := obj.Type()
+ entry, ok := s.byType[typ]
+ if !ok {
+ entry = new(statEntry)
+ s.byType[typ] = entry
+ }
+ entry.count++
+}
+
+// Remove removes a sample count. It should only be called after a previous
+// Add().
+func (s *Stats) Remove(obj reflect.Value) {
+ if s == nil {
+ return
+ }
+ typ := obj.Type()
+ entry := s.byType[typ]
+ entry.count--
+}
+
+// Start starts a sample.
+func (s *Stats) Start(obj reflect.Value) {
+ if s == nil {
+ return
+ }
+ if len(s.stack) > 0 {
+ last := s.stack[len(s.stack)-1]
+ s.sample(last)
+ } else {
+ // First time sample.
+ s.last = time.Now()
+ }
+ s.stack = append(s.stack, obj.Type())
+}
+
+// Done finishes the current sample.
+func (s *Stats) Done() {
+ if s == nil {
+ return
+ }
+ last := s.stack[len(s.stack)-1]
+ s.sample(last)
+ s.stack = s.stack[:len(s.stack)-1]
+}
+
+type sliceEntry struct {
+ typ reflect.Type
+ entry *statEntry
+}
+
+// String returns a table representation of the stats.
+func (s *Stats) String() string {
+ if s == nil || len(s.byType) == 0 {
+ return "(no data)"
+ }
+
+ // Build a list of stat entries.
+ ss := make([]sliceEntry, 0, len(s.byType))
+ for typ, entry := range s.byType {
+ ss = append(ss, sliceEntry{
+ typ: typ,
+ entry: entry,
+ })
+ }
+
+ // Sort by total time (descending).
+ sort.Slice(ss, func(i, j int) bool {
+ return ss[i].entry.total > ss[j].entry.total
+ })
+
+ // Print the stat results.
+ var (
+ buf bytes.Buffer
+ count uint
+ total time.Duration
+ )
+ buf.WriteString("\n")
+ buf.WriteString(fmt.Sprintf("%12s | %8s | %8s | %s\n", "total", "count", "per", "type"))
+ buf.WriteString("-------------+----------+----------+-------------\n")
+ for _, se := range ss {
+ count += se.entry.count
+ total += se.entry.total
+ per := se.entry.total / time.Duration(se.entry.count)
+ buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | %s\n",
+ se.entry.total, se.entry.count, per, se.typ.String()))
+ }
+ buf.WriteString("-------------+----------+----------+-------------\n")
+ buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | [all]",
+ total, count, total/time.Duration(count)))
+ return string(buf.Bytes())
+}
diff --git a/pkg/syserr/host_linux.go b/pkg/syserr/host_linux.go
new file mode 100644
index 000000000..fc6ef60a1
--- /dev/null
+++ b/pkg/syserr/host_linux.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package syserr
+
+import (
+ "fmt"
+ "syscall"
+)
+
+const maxErrno = 134
+
+type linuxHostTranslation struct {
+ err *Error
+ ok bool
+}
+
+var linuxHostTranslations [maxErrno]linuxHostTranslation
+
+// FromHost translates a syscall.Errno to a corresponding Error value.
+func FromHost(err syscall.Errno) *Error {
+ if err < 0 || int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok {
+ panic(fmt.Sprintf("unknown host errno %q (%d)", err.Error(), err))
+ }
+ return linuxHostTranslations[err].err
+}
+
+func addLinuxHostTranslation(host syscall.Errno, trans *Error) {
+ if linuxHostTranslations[host].ok {
+ panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host))
+ }
+ linuxHostTranslations[host] = linuxHostTranslation{err: trans, ok: true}
+}
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
new file mode 100644
index 000000000..bd489b424
--- /dev/null
+++ b/pkg/syserr/netstack.go
@@ -0,0 +1,102 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syserr
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// Mapping for tcpip.Error types.
+var (
+ ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL)
+ ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL)
+ ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV)
+ ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT)
+ ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST)
+ ErrDuplicateAddress = New(tcpip.ErrDuplicateAddress.String(), linux.EEXIST)
+ ErrBadLinkEndpoint = New(tcpip.ErrBadLinkEndpoint.String(), linux.EINVAL)
+ ErrAlreadyBound = New(tcpip.ErrAlreadyBound.String(), linux.EINVAL)
+ ErrInvalidEndpointState = New(tcpip.ErrInvalidEndpointState.String(), linux.EINVAL)
+ ErrAlreadyConnecting = New(tcpip.ErrAlreadyConnecting.String(), linux.EALREADY)
+ ErrNoPortAvailable = New(tcpip.ErrNoPortAvailable.String(), linux.EAGAIN)
+ ErrPortInUse = New(tcpip.ErrPortInUse.String(), linux.EADDRINUSE)
+ ErrBadLocalAddress = New(tcpip.ErrBadLocalAddress.String(), linux.EADDRNOTAVAIL)
+ ErrClosedForSend = New(tcpip.ErrClosedForSend.String(), linux.EPIPE)
+ ErrClosedForReceive = New(tcpip.ErrClosedForReceive.String(), nil)
+ ErrTimeout = New(tcpip.ErrTimeout.String(), linux.ETIMEDOUT)
+ ErrAborted = New(tcpip.ErrAborted.String(), linux.EPIPE)
+ ErrConnectStarted = New(tcpip.ErrConnectStarted.String(), linux.EINPROGRESS)
+ ErrDestinationRequired = New(tcpip.ErrDestinationRequired.String(), linux.EDESTADDRREQ)
+ ErrNotSupported = New(tcpip.ErrNotSupported.String(), linux.EOPNOTSUPP)
+ ErrQueueSizeNotSupported = New(tcpip.ErrQueueSizeNotSupported.String(), linux.ENOTTY)
+ ErrNoSuchFile = New(tcpip.ErrNoSuchFile.String(), linux.ENOENT)
+ ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL)
+ ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES)
+ ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM)
+)
+
+var netstackErrorTranslations = map[*tcpip.Error]*Error{
+ tcpip.ErrUnknownProtocol: ErrUnknownProtocol,
+ tcpip.ErrUnknownNICID: ErrUnknownNICID,
+ tcpip.ErrUnknownDevice: ErrUnknownDevice,
+ tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption,
+ tcpip.ErrDuplicateNICID: ErrDuplicateNICID,
+ tcpip.ErrDuplicateAddress: ErrDuplicateAddress,
+ tcpip.ErrNoRoute: ErrNoRoute,
+ tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint,
+ tcpip.ErrAlreadyBound: ErrAlreadyBound,
+ tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState,
+ tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting,
+ tcpip.ErrAlreadyConnected: ErrAlreadyConnected,
+ tcpip.ErrNoPortAvailable: ErrNoPortAvailable,
+ tcpip.ErrPortInUse: ErrPortInUse,
+ tcpip.ErrBadLocalAddress: ErrBadLocalAddress,
+ tcpip.ErrClosedForSend: ErrClosedForSend,
+ tcpip.ErrClosedForReceive: ErrClosedForReceive,
+ tcpip.ErrWouldBlock: ErrWouldBlock,
+ tcpip.ErrConnectionRefused: ErrConnectionRefused,
+ tcpip.ErrTimeout: ErrTimeout,
+ tcpip.ErrAborted: ErrAborted,
+ tcpip.ErrConnectStarted: ErrConnectStarted,
+ tcpip.ErrDestinationRequired: ErrDestinationRequired,
+ tcpip.ErrNotSupported: ErrNotSupported,
+ tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported,
+ tcpip.ErrNotConnected: ErrNotConnected,
+ tcpip.ErrConnectionReset: ErrConnectionReset,
+ tcpip.ErrConnectionAborted: ErrConnectionAborted,
+ tcpip.ErrNoSuchFile: ErrNoSuchFile,
+ tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue,
+ tcpip.ErrNoLinkAddress: ErrHostDown,
+ tcpip.ErrBadAddress: ErrBadAddress,
+ tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
+ tcpip.ErrMessageTooLong: ErrMessageTooLong,
+ tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
+ tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
+ tcpip.ErrNotPermitted: ErrNotPermittedNet,
+}
+
+// TranslateNetstackError converts an error from the tcpip package to a sentry
+// internal error.
+func TranslateNetstackError(err *tcpip.Error) *Error {
+ if err == nil {
+ return nil
+ }
+ se, ok := netstackErrorTranslations[err]
+ if !ok {
+ panic("Unknown error: " + err.String())
+ }
+ return se
+}
diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go
new file mode 100644
index 000000000..4ddbd3322
--- /dev/null
+++ b/pkg/syserr/syserr.go
@@ -0,0 +1,293 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syserr contains sandbox-internal errors. These errors are distinct
+// from both the errors returned by host system calls and the errors returned
+// to sandboxed applications.
+package syserr
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// Error represents an internal error.
+type Error struct {
+ // message is the human readable form of this Error.
+ message string
+
+ // noTranslation indicates that this Error cannot be translated to a
+ // linux.Errno.
+ noTranslation bool
+
+ // errno is the linux.Errno this Error should be translated to. nil means
+ // that this Error should be translated to a nil linux.Errno.
+ errno *linux.Errno
+}
+
+// New creates a new Error and adds a translation for it.
+//
+// New must only be called at init.
+func New(message string, linuxTranslation *linux.Errno) *Error {
+ err := &Error{message: message, errno: linuxTranslation}
+
+ if linuxTranslation == nil {
+ return err
+ }
+
+ // TODO(b/34162363): Remove this.
+ errno := linuxTranslation.Number()
+ if errno <= 0 || errno >= len(linuxBackwardsTranslations) {
+ panic(fmt.Sprint("invalid errno: ", errno))
+ }
+
+ e := error(syscall.Errno(errno))
+ // syserror.ErrWouldBlock gets translated to syserror.EWOULDBLOCK and
+ // enables proper blocking semantics. This should temporary address the
+ // class of blocking bugs that keep popping up with the current state of
+ // the error space.
+ if e == syserror.EWOULDBLOCK {
+ e = syserror.ErrWouldBlock
+ }
+ linuxBackwardsTranslations[errno] = linuxBackwardsTranslation{err: e, ok: true}
+
+ return err
+}
+
+// NewDynamic creates a new error with a dynamic error message and an errno
+// translation.
+//
+// NewDynamic should only be used sparingly and not be used for static error
+// messages. Errors with static error messages should be declared with New as
+// global variables.
+func NewDynamic(message string, linuxTranslation *linux.Errno) *Error {
+ return &Error{message: message, errno: linuxTranslation}
+}
+
+// NewWithoutTranslation creates a new Error. If translation is attempted on
+// the error, translation will fail.
+//
+// NewWithoutTranslation may be called at any time, but static errors should
+// be declared as global variables and dynamic errors should be used sparingly.
+func NewWithoutTranslation(message string) *Error {
+ return &Error{message: message, noTranslation: true}
+}
+
+func newWithHost(message string, linuxTranslation *linux.Errno, hostErrno syscall.Errno) *Error {
+ e := New(message, linuxTranslation)
+ addLinuxHostTranslation(hostErrno, e)
+ return e
+}
+
+// String implements fmt.Stringer.String.
+func (e *Error) String() string {
+ if e == nil {
+ return "<nil>"
+ }
+ return e.message
+}
+
+type linuxBackwardsTranslation struct {
+ err error
+ ok bool
+}
+
+// TODO(b/34162363): Remove this.
+var linuxBackwardsTranslations [maxErrno]linuxBackwardsTranslation
+
+// ToError translates an Error to a corresponding error value.
+//
+// TODO(b/34162363): Remove this.
+func (e *Error) ToError() error {
+ if e == nil {
+ return nil
+ }
+ if e.noTranslation {
+ panic(fmt.Sprintf("error %q does not support translation", e.message))
+ }
+ if e.errno == nil {
+ return nil
+ }
+ errno := e.errno.Number()
+ if errno <= 0 || errno >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[errno].ok {
+ panic(fmt.Sprintf("unknown error %q (%d)", e.message, errno))
+ }
+ return linuxBackwardsTranslations[errno].err
+}
+
+// ToLinux converts the Error to a Linux ABI error that can be returned to the
+// application.
+func (e *Error) ToLinux() *linux.Errno {
+ if e.noTranslation {
+ panic(fmt.Sprintf("No Linux ABI translation available for %q", e.message))
+ }
+ return e.errno
+}
+
+// TODO(b/34162363): Remove or replace most of these errors.
+//
+// Some of the errors should be replaced with package specific errors and
+// others should be removed entirely.
+var (
+ ErrNotPermitted = newWithHost("operation not permitted", linux.EPERM, syscall.EPERM)
+ ErrNoFileOrDir = newWithHost("no such file or directory", linux.ENOENT, syscall.ENOENT)
+ ErrNoProcess = newWithHost("no such process", linux.ESRCH, syscall.ESRCH)
+ ErrInterrupted = newWithHost("interrupted system call", linux.EINTR, syscall.EINTR)
+ ErrIO = newWithHost("I/O error", linux.EIO, syscall.EIO)
+ ErrDeviceOrAddress = newWithHost("no such device or address", linux.ENXIO, syscall.ENXIO)
+ ErrTooManyArgs = newWithHost("argument list too long", linux.E2BIG, syscall.E2BIG)
+ ErrEcec = newWithHost("exec format error", linux.ENOEXEC, syscall.ENOEXEC)
+ ErrBadFD = newWithHost("bad file number", linux.EBADF, syscall.EBADF)
+ ErrNoChild = newWithHost("no child processes", linux.ECHILD, syscall.ECHILD)
+ ErrTryAgain = newWithHost("try again", linux.EAGAIN, syscall.EAGAIN)
+ ErrNoMemory = newWithHost("out of memory", linux.ENOMEM, syscall.ENOMEM)
+ ErrPermissionDenied = newWithHost("permission denied", linux.EACCES, syscall.EACCES)
+ ErrBadAddress = newWithHost("bad address", linux.EFAULT, syscall.EFAULT)
+ ErrNotBlockDevice = newWithHost("block device required", linux.ENOTBLK, syscall.ENOTBLK)
+ ErrBusy = newWithHost("device or resource busy", linux.EBUSY, syscall.EBUSY)
+ ErrExists = newWithHost("file exists", linux.EEXIST, syscall.EEXIST)
+ ErrCrossDeviceLink = newWithHost("cross-device link", linux.EXDEV, syscall.EXDEV)
+ ErrNoDevice = newWithHost("no such device", linux.ENODEV, syscall.ENODEV)
+ ErrNotDir = newWithHost("not a directory", linux.ENOTDIR, syscall.ENOTDIR)
+ ErrIsDir = newWithHost("is a directory", linux.EISDIR, syscall.EISDIR)
+ ErrInvalidArgument = newWithHost("invalid argument", linux.EINVAL, syscall.EINVAL)
+ ErrFileTableOverflow = newWithHost("file table overflow", linux.ENFILE, syscall.ENFILE)
+ ErrTooManyOpenFiles = newWithHost("too many open files", linux.EMFILE, syscall.EMFILE)
+ ErrNotTTY = newWithHost("not a typewriter", linux.ENOTTY, syscall.ENOTTY)
+ ErrTestFileBusy = newWithHost("text file busy", linux.ETXTBSY, syscall.ETXTBSY)
+ ErrFileTooBig = newWithHost("file too large", linux.EFBIG, syscall.EFBIG)
+ ErrNoSpace = newWithHost("no space left on device", linux.ENOSPC, syscall.ENOSPC)
+ ErrIllegalSeek = newWithHost("illegal seek", linux.ESPIPE, syscall.ESPIPE)
+ ErrReadOnlyFS = newWithHost("read-only file system", linux.EROFS, syscall.EROFS)
+ ErrTooManyLinks = newWithHost("too many links", linux.EMLINK, syscall.EMLINK)
+ ErrBrokenPipe = newWithHost("broken pipe", linux.EPIPE, syscall.EPIPE)
+ ErrDomain = newWithHost("math argument out of domain of func", linux.EDOM, syscall.EDOM)
+ ErrRange = newWithHost("math result not representable", linux.ERANGE, syscall.ERANGE)
+ ErrDeadlock = newWithHost("resource deadlock would occur", linux.EDEADLOCK, syscall.EDEADLOCK)
+ ErrNameTooLong = newWithHost("file name too long", linux.ENAMETOOLONG, syscall.ENAMETOOLONG)
+ ErrNoLocksAvailable = newWithHost("no record locks available", linux.ENOLCK, syscall.ENOLCK)
+ ErrInvalidSyscall = newWithHost("invalid system call number", linux.ENOSYS, syscall.ENOSYS)
+ ErrDirNotEmpty = newWithHost("directory not empty", linux.ENOTEMPTY, syscall.ENOTEMPTY)
+ ErrLinkLoop = newWithHost("too many symbolic links encountered", linux.ELOOP, syscall.ELOOP)
+ ErrNoMessage = newWithHost("no message of desired type", linux.ENOMSG, syscall.ENOMSG)
+ ErrIdentifierRemoved = newWithHost("identifier removed", linux.EIDRM, syscall.EIDRM)
+ ErrChannelOutOfRange = newWithHost("channel number out of range", linux.ECHRNG, syscall.ECHRNG)
+ ErrLevelTwoNotSynced = newWithHost("level 2 not synchronized", linux.EL2NSYNC, syscall.EL2NSYNC)
+ ErrLevelThreeHalted = newWithHost("level 3 halted", linux.EL3HLT, syscall.EL3HLT)
+ ErrLevelThreeReset = newWithHost("level 3 reset", linux.EL3RST, syscall.EL3RST)
+ ErrLinkNumberOutOfRange = newWithHost("link number out of range", linux.ELNRNG, syscall.ELNRNG)
+ ErrProtocolDriverNotAttached = newWithHost("protocol driver not attached", linux.EUNATCH, syscall.EUNATCH)
+ ErrNoCSIAvailable = newWithHost("no CSI structure available", linux.ENOCSI, syscall.ENOCSI)
+ ErrLevelTwoHalted = newWithHost("level 2 halted", linux.EL2HLT, syscall.EL2HLT)
+ ErrInvalidExchange = newWithHost("invalid exchange", linux.EBADE, syscall.EBADE)
+ ErrInvalidRequestDescriptor = newWithHost("invalid request descriptor", linux.EBADR, syscall.EBADR)
+ ErrExchangeFull = newWithHost("exchange full", linux.EXFULL, syscall.EXFULL)
+ ErrNoAnode = newWithHost("no anode", linux.ENOANO, syscall.ENOANO)
+ ErrInvalidRequestCode = newWithHost("invalid request code", linux.EBADRQC, syscall.EBADRQC)
+ ErrInvalidSlot = newWithHost("invalid slot", linux.EBADSLT, syscall.EBADSLT)
+ ErrBadFontFile = newWithHost("bad font file format", linux.EBFONT, syscall.EBFONT)
+ ErrNotStream = newWithHost("device not a stream", linux.ENOSTR, syscall.ENOSTR)
+ ErrNoDataAvailable = newWithHost("no data available", linux.ENODATA, syscall.ENODATA)
+ ErrTimerExpired = newWithHost("timer expired", linux.ETIME, syscall.ETIME)
+ ErrStreamsResourceDepleted = newWithHost("out of streams resources", linux.ENOSR, syscall.ENOSR)
+ ErrMachineNotOnNetwork = newWithHost("machine is not on the network", linux.ENONET, syscall.ENONET)
+ ErrPackageNotInstalled = newWithHost("package not installed", linux.ENOPKG, syscall.ENOPKG)
+ ErrIsRemote = newWithHost("object is remote", linux.EREMOTE, syscall.EREMOTE)
+ ErrNoLink = newWithHost("link has been severed", linux.ENOLINK, syscall.ENOLINK)
+ ErrAdvertise = newWithHost("advertise error", linux.EADV, syscall.EADV)
+ ErrSRMount = newWithHost("srmount error", linux.ESRMNT, syscall.ESRMNT)
+ ErrSendCommunication = newWithHost("communication error on send", linux.ECOMM, syscall.ECOMM)
+ ErrProtocol = newWithHost("protocol error", linux.EPROTO, syscall.EPROTO)
+ ErrMultihopAttempted = newWithHost("multihop attempted", linux.EMULTIHOP, syscall.EMULTIHOP)
+ ErrRFS = newWithHost("RFS specific error", linux.EDOTDOT, syscall.EDOTDOT)
+ ErrInvalidDataMessage = newWithHost("not a data message", linux.EBADMSG, syscall.EBADMSG)
+ ErrOverflow = newWithHost("value too large for defined data type", linux.EOVERFLOW, syscall.EOVERFLOW)
+ ErrNetworkNameNotUnique = newWithHost("name not unique on network", linux.ENOTUNIQ, syscall.ENOTUNIQ)
+ ErrFDInBadState = newWithHost("file descriptor in bad state", linux.EBADFD, syscall.EBADFD)
+ ErrRemoteAddressChanged = newWithHost("remote address changed", linux.EREMCHG, syscall.EREMCHG)
+ ErrSharedLibraryInaccessible = newWithHost("can not access a needed shared library", linux.ELIBACC, syscall.ELIBACC)
+ ErrCorruptedSharedLibrary = newWithHost("accessing a corrupted shared library", linux.ELIBBAD, syscall.ELIBBAD)
+ ErrLibSectionCorrupted = newWithHost(".lib section in a.out corrupted", linux.ELIBSCN, syscall.ELIBSCN)
+ ErrTooManySharedLibraries = newWithHost("attempting to link in too many shared libraries", linux.ELIBMAX, syscall.ELIBMAX)
+ ErrSharedLibraryExeced = newWithHost("cannot exec a shared library directly", linux.ELIBEXEC, syscall.ELIBEXEC)
+ ErrIllegalByteSequence = newWithHost("illegal byte sequence", linux.EILSEQ, syscall.EILSEQ)
+ ErrShouldRestart = newWithHost("interrupted system call should be restarted", linux.ERESTART, syscall.ERESTART)
+ ErrStreamPipe = newWithHost("streams pipe error", linux.ESTRPIPE, syscall.ESTRPIPE)
+ ErrTooManyUsers = newWithHost("too many users", linux.EUSERS, syscall.EUSERS)
+ ErrNotASocket = newWithHost("socket operation on non-socket", linux.ENOTSOCK, syscall.ENOTSOCK)
+ ErrDestinationAddressRequired = newWithHost("destination address required", linux.EDESTADDRREQ, syscall.EDESTADDRREQ)
+ ErrMessageTooLong = newWithHost("message too long", linux.EMSGSIZE, syscall.EMSGSIZE)
+ ErrWrongProtocolForSocket = newWithHost("protocol wrong type for socket", linux.EPROTOTYPE, syscall.EPROTOTYPE)
+ ErrProtocolNotAvailable = newWithHost("protocol not available", linux.ENOPROTOOPT, syscall.ENOPROTOOPT)
+ ErrProtocolNotSupported = newWithHost("protocol not supported", linux.EPROTONOSUPPORT, syscall.EPROTONOSUPPORT)
+ ErrSocketNotSupported = newWithHost("socket type not supported", linux.ESOCKTNOSUPPORT, syscall.ESOCKTNOSUPPORT)
+ ErrEndpointOperation = newWithHost("operation not supported on transport endpoint", linux.EOPNOTSUPP, syscall.EOPNOTSUPP)
+ ErrProtocolFamilyNotSupported = newWithHost("protocol family not supported", linux.EPFNOSUPPORT, syscall.EPFNOSUPPORT)
+ ErrAddressFamilyNotSupported = newWithHost("address family not supported by protocol", linux.EAFNOSUPPORT, syscall.EAFNOSUPPORT)
+ ErrAddressInUse = newWithHost("address already in use", linux.EADDRINUSE, syscall.EADDRINUSE)
+ ErrAddressNotAvailable = newWithHost("cannot assign requested address", linux.EADDRNOTAVAIL, syscall.EADDRNOTAVAIL)
+ ErrNetworkDown = newWithHost("network is down", linux.ENETDOWN, syscall.ENETDOWN)
+ ErrNetworkUnreachable = newWithHost("network is unreachable", linux.ENETUNREACH, syscall.ENETUNREACH)
+ ErrNetworkReset = newWithHost("network dropped connection because of reset", linux.ENETRESET, syscall.ENETRESET)
+ ErrConnectionAborted = newWithHost("software caused connection abort", linux.ECONNABORTED, syscall.ECONNABORTED)
+ ErrConnectionReset = newWithHost("connection reset by peer", linux.ECONNRESET, syscall.ECONNRESET)
+ ErrNoBufferSpace = newWithHost("no buffer space available", linux.ENOBUFS, syscall.ENOBUFS)
+ ErrAlreadyConnected = newWithHost("transport endpoint is already connected", linux.EISCONN, syscall.EISCONN)
+ ErrNotConnected = newWithHost("transport endpoint is not connected", linux.ENOTCONN, syscall.ENOTCONN)
+ ErrShutdown = newWithHost("cannot send after transport endpoint shutdown", linux.ESHUTDOWN, syscall.ESHUTDOWN)
+ ErrTooManyRefs = newWithHost("too many references: cannot splice", linux.ETOOMANYREFS, syscall.ETOOMANYREFS)
+ ErrTimedOut = newWithHost("connection timed out", linux.ETIMEDOUT, syscall.ETIMEDOUT)
+ ErrConnectionRefused = newWithHost("connection refused", linux.ECONNREFUSED, syscall.ECONNREFUSED)
+ ErrHostDown = newWithHost("host is down", linux.EHOSTDOWN, syscall.EHOSTDOWN)
+ ErrNoRoute = newWithHost("no route to host", linux.EHOSTUNREACH, syscall.EHOSTUNREACH)
+ ErrAlreadyInProgress = newWithHost("operation already in progress", linux.EALREADY, syscall.EALREADY)
+ ErrInProgress = newWithHost("operation now in progress", linux.EINPROGRESS, syscall.EINPROGRESS)
+ ErrStaleFileHandle = newWithHost("stale file handle", linux.ESTALE, syscall.ESTALE)
+ ErrStructureNeedsCleaning = newWithHost("structure needs cleaning", linux.EUCLEAN, syscall.EUCLEAN)
+ ErrIsNamedFile = newWithHost("is a named type file", linux.ENOTNAM, syscall.ENOTNAM)
+ ErrRemoteIO = newWithHost("remote I/O error", linux.EREMOTEIO, syscall.EREMOTEIO)
+ ErrQuotaExceeded = newWithHost("quota exceeded", linux.EDQUOT, syscall.EDQUOT)
+ ErrNoMedium = newWithHost("no medium found", linux.ENOMEDIUM, syscall.ENOMEDIUM)
+ ErrWrongMediumType = newWithHost("wrong medium type", linux.EMEDIUMTYPE, syscall.EMEDIUMTYPE)
+ ErrCanceled = newWithHost("operation canceled", linux.ECANCELED, syscall.ECANCELED)
+ ErrNoKey = newWithHost("required key not available", linux.ENOKEY, syscall.ENOKEY)
+ ErrKeyExpired = newWithHost("key has expired", linux.EKEYEXPIRED, syscall.EKEYEXPIRED)
+ ErrKeyRevoked = newWithHost("key has been revoked", linux.EKEYREVOKED, syscall.EKEYREVOKED)
+ ErrKeyRejected = newWithHost("key was rejected by service", linux.EKEYREJECTED, syscall.EKEYREJECTED)
+ ErrOwnerDied = newWithHost("owner died", linux.EOWNERDEAD, syscall.EOWNERDEAD)
+ ErrNotRecoverable = newWithHost("state not recoverable", linux.ENOTRECOVERABLE, syscall.ENOTRECOVERABLE)
+
+ // ErrWouldBlock translates to EWOULDBLOCK which is the same as EAGAIN
+ // on Linux.
+ ErrWouldBlock = New("operation would block", linux.EWOULDBLOCK)
+)
+
+// FromError converts a generic error to an *Error.
+//
+// TODO(b/34162363): Remove this function.
+func FromError(err error) *Error {
+ if err == nil {
+ return nil
+ }
+ if errno, ok := err.(syscall.Errno); ok {
+ return FromHost(errno)
+ }
+ if errno, ok := syserror.TranslateError(err); ok {
+ return FromHost(errno)
+ }
+ panic("unknown error: " + err.Error())
+}
diff --git a/pkg/syserr/syserr_state_autogen.go b/pkg/syserr/syserr_state_autogen.go
new file mode 100755
index 000000000..f34cb096b
--- /dev/null
+++ b/pkg/syserr/syserr_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package syserr
+
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
new file mode 100644
index 000000000..345653544
--- /dev/null
+++ b/pkg/syserror/syserror.go
@@ -0,0 +1,153 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syserror contains syscall error codes exported as error interface
+// instead of Errno. This allows for fast comparison and returns when the
+// comparand or return value is of type error because there is no need to
+// convert from Errno to an interface, i.e., runtime.convT2I isn't called.
+package syserror
+
+import (
+ "errors"
+ "syscall"
+)
+
+// The following variables have the same meaning as their syscall equivalent.
+var (
+ E2BIG = error(syscall.E2BIG)
+ EACCES = error(syscall.EACCES)
+ EAGAIN = error(syscall.EAGAIN)
+ EBADF = error(syscall.EBADF)
+ EBUSY = error(syscall.EBUSY)
+ ECHILD = error(syscall.ECHILD)
+ ECONNREFUSED = error(syscall.ECONNREFUSED)
+ ECONNRESET = error(syscall.ECONNRESET)
+ EDEADLK = error(syscall.EDEADLK)
+ EEXIST = error(syscall.EEXIST)
+ EFAULT = error(syscall.EFAULT)
+ EFBIG = error(syscall.EFBIG)
+ EIDRM = error(syscall.EIDRM)
+ EINTR = error(syscall.EINTR)
+ EINVAL = error(syscall.EINVAL)
+ EIO = error(syscall.EIO)
+ EISDIR = error(syscall.EISDIR)
+ ELIBBAD = error(syscall.ELIBBAD)
+ ELOOP = error(syscall.ELOOP)
+ EMFILE = error(syscall.EMFILE)
+ EMSGSIZE = error(syscall.EMSGSIZE)
+ ENAMETOOLONG = error(syscall.ENAMETOOLONG)
+ ENOATTR = ENODATA
+ ENODATA = error(syscall.ENODATA)
+ ENODEV = error(syscall.ENODEV)
+ ENOENT = error(syscall.ENOENT)
+ ENOEXEC = error(syscall.ENOEXEC)
+ ENOLCK = error(syscall.ENOLCK)
+ ENOLINK = error(syscall.ENOLINK)
+ ENOMEM = error(syscall.ENOMEM)
+ ENOSPC = error(syscall.ENOSPC)
+ ENOSYS = error(syscall.ENOSYS)
+ ENOTDIR = error(syscall.ENOTDIR)
+ ENOTEMPTY = error(syscall.ENOTEMPTY)
+ ENOTSUP = error(syscall.ENOTSUP)
+ ENOTTY = error(syscall.ENOTTY)
+ ENXIO = error(syscall.ENXIO)
+ EOPNOTSUPP = error(syscall.EOPNOTSUPP)
+ EOVERFLOW = error(syscall.EOVERFLOW)
+ EPERM = error(syscall.EPERM)
+ EPIPE = error(syscall.EPIPE)
+ ERANGE = error(syscall.ERANGE)
+ EROFS = error(syscall.EROFS)
+ ESPIPE = error(syscall.ESPIPE)
+ ESRCH = error(syscall.ESRCH)
+ ETIMEDOUT = error(syscall.ETIMEDOUT)
+ EUSERS = error(syscall.EUSERS)
+ EWOULDBLOCK = error(syscall.EWOULDBLOCK)
+ EXDEV = error(syscall.EXDEV)
+)
+
+var (
+ // ErrWouldBlock is an internal error used to indicate that an operation
+ // cannot be satisfied immediately, and should be retried at a later
+ // time, possibly when the caller has received a notification that the
+ // operation may be able to complete. It is used by implementations of
+ // the kio.File interface.
+ ErrWouldBlock = errors.New("request would block")
+
+ // ErrInterrupted is returned if a request is interrupted before it can
+ // complete.
+ ErrInterrupted = errors.New("request was interrupted")
+
+ // ErrExceedsFileSizeLimit is returned if a request would exceed the
+ // file's size limit.
+ ErrExceedsFileSizeLimit = errors.New("exceeds file size limit")
+)
+
+// errorMap is the map used to convert generic errors into errnos.
+var errorMap = map[error]syscall.Errno{}
+
+// errorUnwrappers is an array of unwrap functions to extract typed errors.
+var errorUnwrappers = []func(error) (syscall.Errno, bool){}
+
+// AddErrorTranslation allows modules to populate the error map by adding their
+// own translations during initialization. Returns if the error translation is
+// accepted or not. A pre-existing translation will not be overwritten by the
+// new translation.
+func AddErrorTranslation(from error, to syscall.Errno) bool {
+ if _, ok := errorMap[from]; ok {
+ return false
+ }
+
+ errorMap[from] = to
+ return true
+}
+
+// AddErrorUnwrapper registers an unwrap method that can extract a concrete error
+// from a typed, but not initialized, error.
+func AddErrorUnwrapper(unwrap func(e error) (syscall.Errno, bool)) {
+ errorUnwrappers = append(errorUnwrappers, unwrap)
+}
+
+// TranslateError translates errors to errnos, it will return false if
+// the error was not registered.
+func TranslateError(from error) (syscall.Errno, bool) {
+ err, ok := errorMap[from]
+ if ok {
+ return err, ok
+ }
+ // Try to unwrap the error if we couldn't match an error
+ // exactly. This might mean that a package has its own
+ // error type.
+ for _, unwrap := range errorUnwrappers {
+ err, ok := unwrap(from)
+ if ok {
+ return err, ok
+ }
+ }
+ return 0, false
+}
+
+// ConvertIntr converts the provided error code (err) to another one (intr) if
+// the first error corresponds to an interrupted operation.
+func ConvertIntr(err, intr error) error {
+ if err == ErrInterrupted {
+ return intr
+ }
+ return err
+}
+
+func init() {
+ AddErrorTranslation(ErrWouldBlock, syscall.EWOULDBLOCK)
+ AddErrorTranslation(ErrInterrupted, syscall.EINTR)
+ AddErrorTranslation(ErrExceedsFileSizeLimit, syscall.EFBIG)
+}
diff --git a/pkg/syserror/syserror_state_autogen.go b/pkg/syserror/syserror_state_autogen.go
new file mode 100755
index 000000000..5691e4f5e
--- /dev/null
+++ b/pkg/syserror/syserror_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package syserror
+
diff --git a/pkg/tcpip/buffer/buffer_state_autogen.go b/pkg/tcpip/buffer/buffer_state_autogen.go
new file mode 100755
index 000000000..7e51a28e8
--- /dev/null
+++ b/pkg/tcpip/buffer/buffer_state_autogen.go
@@ -0,0 +1,24 @@
+// automatically generated by stateify.
+
+package buffer
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *VectorisedView) beforeSave() {}
+func (x *VectorisedView) save(m state.Map) {
+ x.beforeSave()
+ m.Save("views", &x.views)
+ m.Save("size", &x.size)
+}
+
+func (x *VectorisedView) afterLoad() {}
+func (x *VectorisedView) load(m state.Map) {
+ m.Load("views", &x.views)
+ m.Load("size", &x.size)
+}
+
+func init() {
+ state.Register("buffer.VectorisedView", (*VectorisedView)(nil), state.Fns{Save: (*VectorisedView).save, Load: (*VectorisedView).load})
+}
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
new file mode 100644
index 000000000..4287464f3
--- /dev/null
+++ b/pkg/tcpip/buffer/prependable.go
@@ -0,0 +1,74 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+// Prependable is a buffer that grows backwards, that is, more data can be
+// prepended to it. It is useful when building networking packets, where each
+// protocol adds its own headers to the front of the higher-level protocol
+// header and payload; for example, TCP would prepend its header to the payload,
+// then IP would prepend its own, then ethernet.
+type Prependable struct {
+ // Buf is the buffer backing the prependable buffer.
+ buf View
+
+ // usedIdx is the index where the used part of the buffer begins.
+ usedIdx int
+}
+
+// NewPrependable allocates a new prependable buffer with the given size.
+func NewPrependable(size int) Prependable {
+ return Prependable{buf: NewView(size), usedIdx: size}
+}
+
+// NewPrependableFromView creates an entirely-used Prependable from a View.
+//
+// NewPrependableFromView takes ownership of v. Note that since the entire
+// prependable is used, further attempts to call Prepend will note that size >
+// p.usedIdx and return nil.
+func NewPrependableFromView(v View) Prependable {
+ return Prependable{buf: v, usedIdx: 0}
+}
+
+// View returns a View of the backing buffer that contains all prepended
+// data so far.
+func (p Prependable) View() View {
+ return p.buf[p.usedIdx:]
+}
+
+// UsedLength returns the number of bytes used so far.
+func (p Prependable) UsedLength() int {
+ return len(p.buf) - p.usedIdx
+}
+
+// AvailableLength returns the number of bytes used so far.
+func (p Prependable) AvailableLength() int {
+ return p.usedIdx
+}
+
+// TrimBack removes size bytes from the end.
+func (p *Prependable) TrimBack(size int) {
+ p.buf = p.buf[:len(p.buf)-size]
+}
+
+// Prepend reserves the requested space in front of the buffer, returning a
+// slice that represents the reserved space.
+func (p *Prependable) Prepend(size int) []byte {
+ if size > p.usedIdx {
+ return nil
+ }
+
+ p.usedIdx -= size
+ return p.View()[:size:size]
+}
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
new file mode 100644
index 000000000..1a9d40778
--- /dev/null
+++ b/pkg/tcpip/buffer/view.go
@@ -0,0 +1,158 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package buffer provides the implementation of a buffer view.
+package buffer
+
+// View is a slice of a buffer, with convenience methods.
+type View []byte
+
+// NewView allocates a new buffer and returns an initialized view that covers
+// the whole buffer.
+func NewView(size int) View {
+ return make(View, size)
+}
+
+// NewViewFromBytes allocates a new buffer and copies in the given bytes.
+func NewViewFromBytes(b []byte) View {
+ return append(View(nil), b...)
+}
+
+// TrimFront removes the first "count" bytes from the visible section of the
+// buffer.
+func (v *View) TrimFront(count int) {
+ *v = (*v)[count:]
+}
+
+// CapLength irreversibly reduces the length of the visible section of the
+// buffer to the value specified.
+func (v *View) CapLength(length int) {
+ // We also set the slice cap because if we don't, one would be able to
+ // expand the view back to include the region just excluded. We want to
+ // prevent that to avoid potential data leak if we have uninitialized
+ // data in excluded region.
+ *v = (*v)[:length:length]
+}
+
+// ToVectorisedView returns a VectorisedView containing the receiver.
+func (v View) ToVectorisedView() VectorisedView {
+ return NewVectorisedView(len(v), []View{v})
+}
+
+// VectorisedView is a vectorised version of View using non contigous memory.
+// It supports all the convenience methods supported by View.
+//
+// +stateify savable
+type VectorisedView struct {
+ views []View
+ size int
+}
+
+// NewVectorisedView creates a new vectorised view from an already-allocated slice
+// of View and sets its size.
+func NewVectorisedView(size int, views []View) VectorisedView {
+ return VectorisedView{views: views, size: size}
+}
+
+// TrimFront removes the first "count" bytes of the vectorised view.
+func (vv *VectorisedView) TrimFront(count int) {
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ vv.views[0].TrimFront(count)
+ return
+ }
+ count -= len(vv.views[0])
+ vv.RemoveFirst()
+ }
+}
+
+// CapLength irreversibly reduces the length of the vectorised view.
+func (vv *VectorisedView) CapLength(length int) {
+ if length < 0 {
+ length = 0
+ }
+ if vv.size < length {
+ return
+ }
+ vv.size = length
+ for i := range vv.views {
+ v := &vv.views[i]
+ if len(*v) >= length {
+ if length == 0 {
+ vv.views = vv.views[:i]
+ } else {
+ v.CapLength(length)
+ vv.views = vv.views[:i+1]
+ }
+ return
+ }
+ length -= len(*v)
+ }
+}
+
+// Clone returns a clone of this VectorisedView.
+// If the buffer argument is large enough to contain all the Views of this VectorisedView,
+// the method will avoid allocations and use the buffer to store the Views of the clone.
+func (vv VectorisedView) Clone(buffer []View) VectorisedView {
+ return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
+}
+
+// First returns the first view of the vectorised view.
+func (vv VectorisedView) First() View {
+ if len(vv.views) == 0 {
+ return nil
+ }
+ return vv.views[0]
+}
+
+// RemoveFirst removes the first view of the vectorised view.
+func (vv *VectorisedView) RemoveFirst() {
+ if len(vv.views) == 0 {
+ return
+ }
+ vv.size -= len(vv.views[0])
+ vv.views = vv.views[1:]
+}
+
+// Size returns the size in bytes of the entire content stored in the vectorised view.
+func (vv VectorisedView) Size() int {
+ return vv.size
+}
+
+// ToView returns a single view containing the content of the vectorised view.
+//
+// If the vectorised view contains a single view, that view will be returned
+// directly.
+func (vv VectorisedView) ToView() View {
+ if len(vv.views) == 1 {
+ return vv.views[0]
+ }
+ u := make([]byte, 0, vv.size)
+ for _, v := range vv.views {
+ u = append(u, v...)
+ }
+ return u
+}
+
+// Views returns the slice containing the all views.
+func (vv VectorisedView) Views() []View {
+ return vv.views
+}
+
+// Append appends the views in a vectorised view to this vectorised view.
+func (vv *VectorisedView) Append(vv2 VectorisedView) {
+ vv.views = append(vv.views, vv2.views...)
+ vv.size += vv2.size
+}
diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go
new file mode 100644
index 000000000..52c22230e
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins.go
@@ -0,0 +1,80 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package jenkins implements Jenkins's one_at_a_time, non-cryptographic hash
+// functions created by by Bob Jenkins.
+//
+// See https://en.wikipedia.org/wiki/Jenkins_hash_function#cite_note-dobbsx-1
+//
+package jenkins
+
+import (
+ "hash"
+)
+
+// Sum32 represents Jenkins's one_at_a_time hash.
+//
+// Use the Sum32 type directly (as opposed to New32 below)
+// to avoid allocations.
+type Sum32 uint32
+
+// New32 returns a new 32-bit Jenkins's one_at_a_time hash.Hash.
+//
+// Its Sum method will lay the value out in big-endian byte order.
+func New32() hash.Hash32 {
+ var s Sum32
+ return &s
+}
+
+// Reset resets the hash to its initial state.
+func (s *Sum32) Reset() { *s = 0 }
+
+// Sum32 returns the hash value
+func (s *Sum32) Sum32() uint32 {
+ hash := *s
+
+ hash += (hash << 3)
+ hash ^= hash >> 11
+ hash += hash << 15
+
+ return uint32(hash)
+}
+
+// Write adds more data to the running hash.
+//
+// It never returns an error.
+func (s *Sum32) Write(data []byte) (int, error) {
+ hash := *s
+ for _, b := range data {
+ hash += Sum32(b)
+ hash += hash << 10
+ hash ^= hash >> 6
+ }
+ *s = hash
+ return len(data), nil
+}
+
+// Size returns the number of bytes Sum will return.
+func (s *Sum32) Size() int { return 4 }
+
+// BlockSize returns the hash's underlying block size.
+func (s *Sum32) BlockSize() int { return 1 }
+
+// Sum appends the current hash to in and returns the resulting slice.
+//
+// It does not change the underlying hash state.
+func (s *Sum32) Sum(in []byte) []byte {
+ v := s.Sum32()
+ return append(in, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
diff --git a/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go
new file mode 100755
index 000000000..310f0ee6d
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package jenkins
+
diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go
new file mode 100644
index 000000000..55fe7292c
--- /dev/null
+++ b/pkg/tcpip/header/arp.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "gvisor.googlesource.com/gvisor/pkg/tcpip"
+
+const (
+ // ARPProtocolNumber is the ARP network protocol number.
+ ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
+
+ // ARPSize is the size of an IPv4-over-Ethernet ARP packet.
+ ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
+)
+
+// ARPOp is an ARP opcode.
+type ARPOp uint16
+
+// Typical ARP opcodes defined in RFC 826.
+const (
+ ARPRequest ARPOp = 1
+ ARPReply ARPOp = 2
+)
+
+// ARP is an ARP packet stored in a byte array as described in RFC 826.
+type ARP []byte
+
+func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
+func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
+func (a ARP) hardwareAddressSize() int { return int(a[4]) }
+func (a ARP) protocolAddressSize() int { return int(a[5]) }
+
+// Op is the ARP opcode.
+func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
+
+// SetOp sets the ARP opcode.
+func (a ARP) SetOp(op ARPOp) {
+ a[6] = uint8(op >> 8)
+ a[7] = uint8(op)
+}
+
+// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet.
+func (a ARP) SetIPv4OverEthernet() {
+ a[0], a[1] = 0, 1 // htypeEthernet
+ a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
+ a[4] = 6 // macSize
+ a[5] = uint8(IPv4AddressSize)
+}
+
+// HardwareAddressSender is the link address of the sender.
+// It is a view on to the ARP packet so it can be used to set the value.
+func (a ARP) HardwareAddressSender() []byte {
+ const s = 8
+ return a[s : s+6]
+}
+
+// ProtocolAddressSender is the protocol address of the sender.
+// It is a view on to the ARP packet so it can be used to set the value.
+func (a ARP) ProtocolAddressSender() []byte {
+ const s = 8 + 6
+ return a[s : s+4]
+}
+
+// HardwareAddressTarget is the link address of the target.
+// It is a view on to the ARP packet so it can be used to set the value.
+func (a ARP) HardwareAddressTarget() []byte {
+ const s = 8 + 6 + 4
+ return a[s : s+6]
+}
+
+// ProtocolAddressTarget is the protocol address of the target.
+// It is a view on to the ARP packet so it can be used to set the value.
+func (a ARP) ProtocolAddressTarget() []byte {
+ const s = 8 + 6 + 4 + 6
+ return a[s : s+4]
+}
+
+// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
+func (a ARP) IsValid() bool {
+ if len(a) < ARPSize {
+ return false
+ }
+ const htypeEthernet = 1
+ const macSize = 6
+ return a.hardwareAddressSpace() == htypeEthernet &&
+ a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
+ a.hardwareAddressSize() == macSize &&
+ a.protocolAddressSize() == IPv4AddressSize
+}
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
new file mode 100644
index 000000000..2eaa7938a
--- /dev/null
+++ b/pkg/tcpip/header/checksum.go
@@ -0,0 +1,94 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package header provides the implementation of the encoding and decoding of
+// network protocol headers.
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func calculateChecksum(buf []byte, initial uint32) uint16 {
+ v := initial
+
+ l := len(buf)
+ if l&1 != 0 {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+
+ for i := 0; i < l; i += 2 {
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16))
+}
+
+// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
+// given byte array.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func Checksum(buf []byte, initial uint16) uint16 {
+ return calculateChecksum(buf, uint32(initial))
+}
+
+// ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given VectorizedView.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
+ var odd bool
+ sum := initial
+ for _, v := range vv.Views() {
+ if len(v) == 0 {
+ continue
+ }
+ s := uint32(sum)
+ if odd {
+ s += uint32(v[0])
+ v = v[1:]
+ }
+ odd = len(v)&1 != 0
+ sum = calculateChecksum(v, s)
+ }
+ return sum
+}
+
+// ChecksumCombine combines the two uint16 to form their checksum. This is done
+// by adding them and the carry.
+//
+// Note that checksum a must have been computed on an even number of bytes.
+func ChecksumCombine(a, b uint16) uint16 {
+ v := uint32(a) + uint32(b)
+ return uint16(v + v>>16)
+}
+
+// PseudoHeaderChecksum calculates the pseudo-header checksum for the given
+// destination protocol and network address. Pseudo-headers are needed by
+// transport layers when calculating their own checksum.
+func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, totalLen uint16) uint16 {
+ xsum := Checksum([]byte(srcAddr), 0)
+ xsum = Checksum([]byte(dstAddr), xsum)
+
+ // Add the length portion of the checksum to the pseudo-checksum.
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ xsum = Checksum(tmp, xsum)
+
+ return Checksum([]byte{0, uint8(protocol)}, xsum)
+}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
new file mode 100644
index 000000000..76143f454
--- /dev/null
+++ b/pkg/tcpip/header/eth.go
@@ -0,0 +1,74 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ dstMAC = 0
+ srcMAC = 6
+ ethType = 12
+)
+
+// EthernetFields contains the fields of an ethernet frame header. It is used to
+// describe the fields of a frame that needs to be encoded.
+type EthernetFields struct {
+ // SrcAddr is the "MAC source" field of an ethernet frame header.
+ SrcAddr tcpip.LinkAddress
+
+ // DstAddr is the "MAC destination" field of an ethernet frame header.
+ DstAddr tcpip.LinkAddress
+
+ // Type is the "ethertype" field of an ethernet frame header.
+ Type tcpip.NetworkProtocolNumber
+}
+
+// Ethernet represents an ethernet frame header stored in a byte array.
+type Ethernet []byte
+
+const (
+ // EthernetMinimumSize is the minimum size of a valid ethernet frame.
+ EthernetMinimumSize = 14
+
+ // EthernetAddressSize is the size, in bytes, of an ethernet address.
+ EthernetAddressSize = 6
+)
+
+// SourceAddress returns the "MAC source" field of the ethernet frame header.
+func (b Ethernet) SourceAddress() tcpip.LinkAddress {
+ return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
+}
+
+// DestinationAddress returns the "MAC destination" field of the ethernet frame
+// header.
+func (b Ethernet) DestinationAddress() tcpip.LinkAddress {
+ return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize])
+}
+
+// Type returns the "ethertype" field of the ethernet frame header.
+func (b Ethernet) Type() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:]))
+}
+
+// Encode encodes all the fields of the ethernet frame header.
+func (b Ethernet) Encode(e *EthernetFields) {
+ binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type))
+ copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
+ copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
+}
diff --git a/pkg/tcpip/header/gue.go b/pkg/tcpip/header/gue.go
new file mode 100644
index 000000000..10d358c0e
--- /dev/null
+++ b/pkg/tcpip/header/gue.go
@@ -0,0 +1,73 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+const (
+ typeHLen = 0
+ encapProto = 1
+)
+
+// GUEFields contains the fields of a GUE packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type GUEFields struct {
+ // Type is the "type" field of the GUE header.
+ Type uint8
+
+ // Control is the "control" field of the GUE header.
+ Control bool
+
+ // HeaderLength is the "header length" field of the GUE header. It must
+ // be at least 4 octets, and a multiple of 4 as well.
+ HeaderLength uint8
+
+ // Protocol is the "protocol" field of the GUE header. This is one of
+ // the IPPROTO_* values.
+ Protocol uint8
+}
+
+// GUE represents a Generic UDP Encapsulation header stored in a byte array, the
+// fields are described in https://tools.ietf.org/html/draft-ietf-nvo3-gue-01.
+type GUE []byte
+
+const (
+ // GUEMinimumSize is the minimum size of a valid GUE packet.
+ GUEMinimumSize = 4
+)
+
+// TypeAndControl returns the GUE packet type (top 3 bits of the first byte,
+// which includes the control bit).
+func (b GUE) TypeAndControl() uint8 {
+ return b[typeHLen] >> 5
+}
+
+// HeaderLength returns the total length of the GUE header.
+func (b GUE) HeaderLength() uint8 {
+ return 4 + 4*(b[typeHLen]&0x1f)
+}
+
+// Protocol returns the protocol field of the GUE header.
+func (b GUE) Protocol() uint8 {
+ return b[encapProto]
+}
+
+// Encode encodes all the fields of the GUE header.
+func (b GUE) Encode(i *GUEFields) {
+ ctl := uint8(0)
+ if i.Control {
+ ctl = 1 << 5
+ }
+ b[typeHLen] = ctl | i.Type<<6 | (i.HeaderLength-4)/4
+ b[encapProto] = i.Protocol
+}
diff --git a/pkg/tcpip/header/header_state_autogen.go b/pkg/tcpip/header/header_state_autogen.go
new file mode 100755
index 000000000..a8f4c4693
--- /dev/null
+++ b/pkg/tcpip/header/header_state_autogen.go
@@ -0,0 +1,42 @@
+// automatically generated by stateify.
+
+package header
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *SACKBlock) beforeSave() {}
+func (x *SACKBlock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *SACKBlock) afterLoad() {}
+func (x *SACKBlock) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func (x *TCPOptions) beforeSave() {}
+func (x *TCPOptions) save(m state.Map) {
+ x.beforeSave()
+ m.Save("TS", &x.TS)
+ m.Save("TSVal", &x.TSVal)
+ m.Save("TSEcr", &x.TSEcr)
+ m.Save("SACKBlocks", &x.SACKBlocks)
+}
+
+func (x *TCPOptions) afterLoad() {}
+func (x *TCPOptions) load(m state.Map) {
+ m.Load("TS", &x.TS)
+ m.Load("TSVal", &x.TSVal)
+ m.Load("TSEcr", &x.TSEcr)
+ m.Load("SACKBlocks", &x.SACKBlocks)
+}
+
+func init() {
+ state.Register("header.SACKBlock", (*SACKBlock)(nil), state.Fns{Save: (*SACKBlock).save, Load: (*SACKBlock).load})
+ state.Register("header.TCPOptions", (*TCPOptions)(nil), state.Fns{Save: (*TCPOptions).save, Load: (*TCPOptions).load})
+}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
new file mode 100644
index 000000000..782e1053c
--- /dev/null
+++ b/pkg/tcpip/header/icmpv4.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// ICMPv4 represents an ICMPv4 header stored in a byte array.
+type ICMPv4 []byte
+
+const (
+ // ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
+ ICMPv4MinimumSize = 4
+
+ // ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+ ICMPv4EchoMinimumSize = 6
+
+ // ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP
+ // destination unreachable packet.
+ ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4
+
+ // ICMPv4ProtocolNumber is the ICMP transport protocol number.
+ ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
+)
+
+// ICMPv4Type is the ICMP type field described in RFC 792.
+type ICMPv4Type byte
+
+// Typical values of ICMPv4Type defined in RFC 792.
+const (
+ ICMPv4EchoReply ICMPv4Type = 0
+ ICMPv4DstUnreachable ICMPv4Type = 3
+ ICMPv4SrcQuench ICMPv4Type = 4
+ ICMPv4Redirect ICMPv4Type = 5
+ ICMPv4Echo ICMPv4Type = 8
+ ICMPv4TimeExceeded ICMPv4Type = 11
+ ICMPv4ParamProblem ICMPv4Type = 12
+ ICMPv4Timestamp ICMPv4Type = 13
+ ICMPv4TimestampReply ICMPv4Type = 14
+ ICMPv4InfoRequest ICMPv4Type = 15
+ ICMPv4InfoReply ICMPv4Type = 16
+)
+
+// Values for ICMP code as defined in RFC 792.
+const (
+ ICMPv4PortUnreachable = 3
+ ICMPv4FragmentationNeeded = 4
+)
+
+// Type is the ICMP type field.
+func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv4) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv4) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field.
+func (b ICMPv4) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[2:])
+}
+
+// SetChecksum sets the ICMP checksum field.
+func (b ICMPv4) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[2:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort.
+func (ICMPv4) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv4) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv4) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv4) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv4) Payload() []byte {
+ return b[ICMPv4MinimumSize:]
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
new file mode 100644
index 000000000..d0b10d849
--- /dev/null
+++ b/pkg/tcpip/header/icmpv6.go
@@ -0,0 +1,121 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// ICMPv6 represents an ICMPv6 header stored in a byte array.
+type ICMPv6 []byte
+
+const (
+ // ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
+ ICMPv6MinimumSize = 4
+
+ // ICMPv6ProtocolNumber is the ICMP transport protocol number.
+ ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
+
+ // ICMPv6NeighborSolicitMinimumSize is the minimum size of a
+ // neighbor solicitation packet.
+ ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16
+
+ // ICMPv6NeighborAdvertSize is size of a neighbor advertisement.
+ ICMPv6NeighborAdvertSize = 32
+
+ // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+ ICMPv6EchoMinimumSize = 8
+
+ // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
+ // destination unreachable packet.
+ ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4
+
+ // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
+ // packet-too-big packet.
+ ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4
+)
+
+// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
+type ICMPv6Type byte
+
+// Typical values of ICMPv6Type defined in RFC 4443.
+const (
+ ICMPv6DstUnreachable ICMPv6Type = 1
+ ICMPv6PacketTooBig ICMPv6Type = 2
+ ICMPv6TimeExceeded ICMPv6Type = 3
+ ICMPv6ParamProblem ICMPv6Type = 4
+ ICMPv6EchoRequest ICMPv6Type = 128
+ ICMPv6EchoReply ICMPv6Type = 129
+
+ // Neighbor Discovery Protocol (NDP) messages, see RFC 4861.
+
+ ICMPv6RouterSolicit ICMPv6Type = 133
+ ICMPv6RouterAdvert ICMPv6Type = 134
+ ICMPv6NeighborSolicit ICMPv6Type = 135
+ ICMPv6NeighborAdvert ICMPv6Type = 136
+ ICMPv6RedirectMsg ICMPv6Type = 137
+)
+
+// Values for ICMP code as defined in RFC 4443.
+const (
+ ICMPv6PortUnreachable = 4
+)
+
+// Type is the ICMP type field.
+func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv6) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv6) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field.
+func (b ICMPv6) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[2:])
+}
+
+// SetChecksum calculates and sets the ICMP checksum field.
+func (b ICMPv6) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[2:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort.
+func (ICMPv6) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv6) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv6) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv6) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv6) Payload() []byte {
+ return b[ICMPv6MinimumSize:]
+}
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
new file mode 100644
index 000000000..fb250ea30
--- /dev/null
+++ b/pkg/tcpip/header/interfaces.go
@@ -0,0 +1,92 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ // MaxIPPacketSize is the maximum supported IP packet size, excluding
+ // jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
+ // in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
+ // 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
+ // m is the minimum IPv6 header size; we leave room for some potential
+ // IP options.
+ MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
+)
+
+// Transport offers generic methods to query and/or update the fields of the
+// header of a transport protocol buffer.
+type Transport interface {
+ // SourcePort returns the value of the "source port" field.
+ SourcePort() uint16
+
+ // Destination returns the value of the "destination port" field.
+ DestinationPort() uint16
+
+ // Checksum returns the value of the "checksum" field.
+ Checksum() uint16
+
+ // SetSourcePort sets the value of the "source port" field.
+ SetSourcePort(uint16)
+
+ // SetDestinationPort sets the value of the "destination port" field.
+ SetDestinationPort(uint16)
+
+ // SetChecksum sets the value of the "checksum" field.
+ SetChecksum(uint16)
+
+ // Payload returns the data carried in the transport buffer.
+ Payload() []byte
+}
+
+// Network offers generic methods to query and/or update the fields of the
+// header of a network protocol buffer.
+type Network interface {
+ // SourceAddress returns the value of the "source address" field.
+ SourceAddress() tcpip.Address
+
+ // DestinationAddress returns the value of the "destination address"
+ // field.
+ DestinationAddress() tcpip.Address
+
+ // Checksum returns the value of the "checksum" field.
+ Checksum() uint16
+
+ // SetSourceAddress sets the value of the "source address" field.
+ SetSourceAddress(tcpip.Address)
+
+ // SetDestinationAddress sets the value of the "destination address"
+ // field.
+ SetDestinationAddress(tcpip.Address)
+
+ // SetChecksum sets the value of the "checksum" field.
+ SetChecksum(uint16)
+
+ // TransportProtocol returns the number of the transport protocol
+ // stored in the payload.
+ TransportProtocol() tcpip.TransportProtocolNumber
+
+ // Payload returns a byte slice containing the payload of the network
+ // packet.
+ Payload() []byte
+
+ // TOS returns the values of the "type of service" and "flow label" fields.
+ TOS() (uint8, uint32)
+
+ // SetTOS sets the values of the "type of service" and "flow label" fields.
+ SetTOS(t uint8, l uint32)
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
new file mode 100644
index 000000000..96e461491
--- /dev/null
+++ b/pkg/tcpip/header/ipv4.go
@@ -0,0 +1,282 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ versIHL = 0
+ tos = 1
+ totalLen = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksum = 10
+ srcAddr = 12
+ dstAddr = 16
+)
+
+// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv4Fields struct {
+ // IHL is the "internet header length" field of an IPv4 packet.
+ IHL uint8
+
+ // TOS is the "type of service" field of an IPv4 packet.
+ TOS uint8
+
+ // TotalLength is the "total length" field of an IPv4 packet.
+ TotalLength uint16
+
+ // ID is the "identification" field of an IPv4 packet.
+ ID uint16
+
+ // Flags is the "flags" field of an IPv4 packet.
+ Flags uint8
+
+ // FragmentOffset is the "fragment offset" field of an IPv4 packet.
+ FragmentOffset uint16
+
+ // TTL is the "time to live" field of an IPv4 packet.
+ TTL uint8
+
+ // Protocol is the "protocol" field of an IPv4 packet.
+ Protocol uint8
+
+ // Checksum is the "checksum" field of an IPv4 packet.
+ Checksum uint16
+
+ // SrcAddr is the "source ip address" of an IPv4 packet.
+ SrcAddr tcpip.Address
+
+ // DstAddr is the "destination ip address" of an IPv4 packet.
+ DstAddr tcpip.Address
+}
+
+// IPv4 represents an ipv4 header stored in a byte array.
+// Most of the methods of IPv4 access to the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv4 before using other methods.
+type IPv4 []byte
+
+const (
+ // IPv4MinimumSize is the minimum size of a valid IPv4 packet.
+ IPv4MinimumSize = 20
+
+ // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
+ // that there are only 4 bits to represents the header length in 32-bit
+ // units, the header cannot exceed 15*4 = 60 bytes.
+ IPv4MaximumHeaderSize = 60
+
+ // IPv4AddressSize is the size, in bytes, of an IPv4 address.
+ IPv4AddressSize = 4
+
+ // IPv4ProtocolNumber is IPv4's network protocol number.
+ IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800
+
+ // IPv4Version is the version of the ipv4 protocol.
+ IPv4Version = 4
+
+ // IPv4Broadcast is the broadcast address of the IPv4 procotol.
+ IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
+
+ // IPv4Any is the non-routable IPv4 "any" meta address.
+ IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+)
+
+// Flags that may be set in an IPv4 packet.
+const (
+ IPv4FlagMoreFragments = 1 << iota
+ IPv4FlagDontFragment
+)
+
+// IPVersion returns the version of IP used in the given packet. It returns -1
+// if the packet is not large enough to contain the version field.
+func IPVersion(b []byte) int {
+ // Length must be at least offset+length of version field.
+ if len(b) < versIHL+1 {
+ return -1
+ }
+ return int(b[versIHL] >> 4)
+}
+
+// HeaderLength returns the value of the "header length" field of the ipv4
+// header.
+func (b IPv4) HeaderLength() uint8 {
+ return (b[versIHL] & 0xf) * 4
+}
+
+// ID returns the value of the identifier field of the ipv4 header.
+func (b IPv4) ID() uint16 {
+ return binary.BigEndian.Uint16(b[id:])
+}
+
+// Protocol returns the value of the protocol field of the ipv4 header.
+func (b IPv4) Protocol() uint8 {
+ return b[protocol]
+}
+
+// Flags returns the "flags" field of the ipv4 header.
+func (b IPv4) Flags() uint8 {
+ return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
+}
+
+// TTL returns the "TTL" field of the ipv4 header.
+func (b IPv4) TTL() uint8 {
+ return b[ttl]
+}
+
+// FragmentOffset returns the "fragment offset" field of the ipv4 header.
+func (b IPv4) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[flagsFO:]) << 3
+}
+
+// TotalLength returns the "total length" field of the ipv4 header.
+func (b IPv4) TotalLength() uint16 {
+ return binary.BigEndian.Uint16(b[totalLen:])
+}
+
+// Checksum returns the checksum field of the ipv4 header.
+func (b IPv4) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[checksum:])
+}
+
+// SourceAddress returns the "source address" field of the ipv4 header.
+func (b IPv4) SourceAddress() tcpip.Address {
+ return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize])
+}
+
+// DestinationAddress returns the "destination address" field of the ipv4
+// header.
+func (b IPv4) DestinationAddress() tcpip.Address {
+ return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.Protocol())
+}
+
+// Payload implements Network.Payload.
+func (b IPv4) Payload() []byte {
+ return b[b.HeaderLength():][:b.PayloadLength()]
+}
+
+// PayloadLength returns the length of the payload portion of the ipv4 packet.
+func (b IPv4) PayloadLength() uint16 {
+ return b.TotalLength() - uint16(b.HeaderLength())
+}
+
+// TOS returns the "type of service" field of the ipv4 header.
+func (b IPv4) TOS() (uint8, uint32) {
+ return b[tos], 0
+}
+
+// SetTOS sets the "type of service" field of the ipv4 header.
+func (b IPv4) SetTOS(v uint8, _ uint32) {
+ b[tos] = v
+}
+
+// SetTotalLength sets the "total length" field of the ipv4 header.
+func (b IPv4) SetTotalLength(totalLength uint16) {
+ binary.BigEndian.PutUint16(b[totalLen:], totalLength)
+}
+
+// SetChecksum sets the checksum field of the ipv4 header.
+func (b IPv4) SetChecksum(v uint16) {
+ binary.BigEndian.PutUint16(b[checksum:], v)
+}
+
+// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
+// ipv4 header.
+func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
+ v := (uint16(flags) << 13) | (offset >> 3)
+ binary.BigEndian.PutUint16(b[flagsFO:], v)
+}
+
+// SetID sets the identification field.
+func (b IPv4) SetID(v uint16) {
+ binary.BigEndian.PutUint16(b[id:], v)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv4 header.
+func (b IPv4) SetSourceAddress(addr tcpip.Address) {
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv4
+// header.
+func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], addr)
+}
+
+// CalculateChecksum calculates the checksum of the ipv4 header.
+func (b IPv4) CalculateChecksum() uint16 {
+ return Checksum(b[:b.HeaderLength()], 0)
+}
+
+// Encode encodes all the fields of the ipv4 header.
+func (b IPv4) Encode(i *IPv4Fields) {
+ b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b[tos] = i.TOS
+ b.SetTotalLength(i.TotalLength)
+ binary.BigEndian.PutUint16(b[id:], i.ID)
+ b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
+ b[ttl] = i.TTL
+ b[protocol] = i.Protocol
+ b.SetChecksum(i.Checksum)
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
+}
+
+// EncodePartial updates the total length and checksum fields of ipv4 header,
+// taking in the partial checksum, which is the checksum of the header without
+// the total length and checksum fields. It is useful in cases when similar
+// packets are produced.
+func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) {
+ b.SetTotalLength(totalLength)
+ checksum := Checksum(b[totalLen:totalLen+2], partialChecksum)
+ b.SetChecksum(^checksum)
+}
+
+// IsValid performs basic validation on the packet.
+func (b IPv4) IsValid(pktSize int) bool {
+ if len(b) < IPv4MinimumSize {
+ return false
+ }
+
+ hlen := int(b.HeaderLength())
+ tlen := int(b.TotalLength())
+ if hlen > tlen || tlen > pktSize {
+ return false
+ }
+
+ return true
+}
+
+// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
+// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
+// will be 1110 = 0xe0.
+func IsV4MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv4AddressSize {
+ return false
+ }
+ return (addr[0] & 0xf0) == 0xe0
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
new file mode 100644
index 000000000..66820a466
--- /dev/null
+++ b/pkg/tcpip/header/ipv6.go
@@ -0,0 +1,248 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ versTCFL = 0
+ payloadLen = 4
+ nextHdr = 6
+ hopLimit = 7
+ v6SrcAddr = 8
+ v6DstAddr = 24
+)
+
+// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv6Fields struct {
+ // TrafficClass is the "traffic class" field of an IPv6 packet.
+ TrafficClass uint8
+
+ // FlowLabel is the "flow label" field of an IPv6 packet.
+ FlowLabel uint32
+
+ // PayloadLength is the "payload length" field of an IPv6 packet.
+ PayloadLength uint16
+
+ // NextHeader is the "next header" field of an IPv6 packet.
+ NextHeader uint8
+
+ // HopLimit is the "hop limit" field of an IPv6 packet.
+ HopLimit uint8
+
+ // SrcAddr is the "source ip address" of an IPv6 packet.
+ SrcAddr tcpip.Address
+
+ // DstAddr is the "destination ip address" of an IPv6 packet.
+ DstAddr tcpip.Address
+}
+
+// IPv6 represents an ipv6 header stored in a byte array.
+// Most of the methods of IPv6 access to the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv6 before using other methods.
+type IPv6 []byte
+
+const (
+ // IPv6MinimumSize is the minimum size of a valid IPv6 packet.
+ IPv6MinimumSize = 40
+
+ // IPv6AddressSize is the size, in bytes, of an IPv6 address.
+ IPv6AddressSize = 16
+
+ // IPv6ProtocolNumber is IPv6's network protocol number.
+ IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
+
+ // IPv6Version is the version of the ipv6 protocol.
+ IPv6Version = 6
+
+ // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
+ // section 5.
+ IPv6MinimumMTU = 1280
+
+ // IPv6Any is the non-routable IPv6 "any" meta address.
+ IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+)
+
+// PayloadLength returns the value of the "payload length" field of the ipv6
+// header.
+func (b IPv6) PayloadLength() uint16 {
+ return binary.BigEndian.Uint16(b[payloadLen:])
+}
+
+// HopLimit returns the value of the "hop limit" field of the ipv6 header.
+func (b IPv6) HopLimit() uint8 {
+ return b[hopLimit]
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 header.
+func (b IPv6) NextHeader() uint8 {
+ return b[nextHdr]
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.NextHeader())
+}
+
+// Payload implements Network.Payload.
+func (b IPv6) Payload() []byte {
+ return b[IPv6MinimumSize:][:b.PayloadLength()]
+}
+
+// SourceAddress returns the "source address" field of the ipv6 header.
+func (b IPv6) SourceAddress() tcpip.Address {
+ return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
+}
+
+// DestinationAddress returns the "destination address" field of the ipv6
+// header.
+func (b IPv6) DestinationAddress() tcpip.Address {
+ return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
+}
+
+// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
+// checksum, it just returns 0.
+func (IPv6) Checksum() uint16 {
+ return 0
+}
+
+// TOS returns the "traffic class" and "flow label" fields of the ipv6 header.
+func (b IPv6) TOS() (uint8, uint32) {
+ v := binary.BigEndian.Uint32(b[versTCFL:])
+ return uint8(v >> 20), v & 0xfffff
+}
+
+// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header.
+func (b IPv6) SetTOS(t uint8, l uint32) {
+ vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
+ binary.BigEndian.PutUint32(b[versTCFL:], vtf)
+}
+
+// SetPayloadLength sets the "payload length" field of the ipv6 header.
+func (b IPv6) SetPayloadLength(payloadLength uint16) {
+ binary.BigEndian.PutUint16(b[payloadLen:], payloadLength)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv6 header.
+func (b IPv6) SetSourceAddress(addr tcpip.Address) {
+ copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv6
+// header.
+func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
+ copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
+}
+
+// SetNextHeader sets the value of the "next header" field of the ipv6 header.
+func (b IPv6) SetNextHeader(v uint8) {
+ b[nextHdr] = v
+}
+
+// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
+// checksum, it is empty.
+func (IPv6) SetChecksum(uint16) {
+}
+
+// Encode encodes all the fields of the ipv6 header.
+func (b IPv6) Encode(i *IPv6Fields) {
+ b.SetTOS(i.TrafficClass, i.FlowLabel)
+ b.SetPayloadLength(i.PayloadLength)
+ b[nextHdr] = i.NextHeader
+ b[hopLimit] = i.HopLimit
+ copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
+ copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
+}
+
+// IsValid performs basic validation on the packet.
+func (b IPv6) IsValid(pktSize int) bool {
+ if len(b) < IPv6MinimumSize {
+ return false
+ }
+
+ dlen := int(b.PayloadLength())
+ if dlen > pktSize-IPv6MinimumSize {
+ return false
+ }
+
+ return true
+}
+
+// IsV4MappedAddress determines if the provided address is an IPv4 mapped
+// address by checking if its prefix is 0:0:0:0:0:ffff::/96.
+func IsV4MappedAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+
+ return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
+}
+
+// IsV6MulticastAddress determines if the provided address is an IPv6
+// multicast address (anything starting with FF).
+func IsV6MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ return addr[0] == 0xff
+}
+
+// SolicitedNodeAddr computes the solicited-node multicast address. This is
+// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
+// address.
+func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
+ const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
+ return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
+}
+
+// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
+// (MAC) address.
+func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
+ // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local
+ // header, FE80::.
+ //
+ // The conversion is very nearly:
+ // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
+ // Note the capital A. The conversion aa->Aa involves a bit flip.
+ lladdrb := [16]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ 8: linkAddr[0] ^ 2,
+ 9: linkAddr[1],
+ 10: linkAddr[2],
+ 11: 0xFF,
+ 12: 0xFE,
+ 13: linkAddr[3],
+ 14: linkAddr[4],
+ 15: linkAddr[5],
+ }
+ return tcpip.Address(lladdrb[:])
+}
+
+// IsV6LinkLocalAddress determines if the provided address is an IPv6
+// link-local address (fe80::/10).
+func IsV6LinkLocalAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
+}
diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go
new file mode 100644
index 000000000..6d896355a
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_fragment.go
@@ -0,0 +1,146 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ nextHdrFrag = 0
+ fragOff = 2
+ more = 3
+ idV6 = 4
+)
+
+// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv6FragmentFields struct {
+ // NextHeader is the "next header" field of an IPv6 fragment.
+ NextHeader uint8
+
+ // FragmentOffset is the "fragment offset" field of an IPv6 fragment.
+ FragmentOffset uint16
+
+ // M is the "more" field of an IPv6 fragment.
+ M bool
+
+ // Identification is the "identification" field of an IPv6 fragment.
+ Identification uint32
+}
+
+// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
+// Most of the methods of IPv6Fragment access to the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv6Fragment before using other methods.
+type IPv6Fragment []byte
+
+const (
+ // IPv6FragmentHeader header is the number used to specify that the next
+ // header is a fragment header, per RFC 2460.
+ IPv6FragmentHeader = 44
+
+ // IPv6FragmentHeaderSize is the size of the fragment header.
+ IPv6FragmentHeaderSize = 8
+)
+
+// Encode encodes all the fields of the ipv6 fragment.
+func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
+ b[nextHdrFrag] = i.NextHeader
+ binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
+ if i.M {
+ b[more] |= 1
+ }
+ binary.BigEndian.PutUint32(b[idV6:], i.Identification)
+}
+
+// IsValid performs basic validation on the fragment header.
+func (b IPv6Fragment) IsValid() bool {
+ return len(b) >= IPv6FragmentHeaderSize
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 fragment.
+func (b IPv6Fragment) NextHeader() uint8 {
+ return b[nextHdrFrag]
+}
+
+// FragmentOffset returns the "fragment offset" field of the ipv6 fragment.
+func (b IPv6Fragment) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[fragOff:]) >> 3
+}
+
+// More returns the "more" field of the ipv6 fragment.
+func (b IPv6Fragment) More() bool {
+ return b[more]&1 > 0
+}
+
+// Payload implements Network.Payload.
+func (b IPv6Fragment) Payload() []byte {
+ return b[IPv6FragmentHeaderSize:]
+}
+
+// ID returns the value of the identifier field of the ipv6 fragment.
+func (b IPv6Fragment) ID() uint32 {
+ return binary.BigEndian.Uint32(b[idV6:])
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.NextHeader())
+}
+
+// The functions below have been added only to satisfy the Network interface.
+
+// Checksum is not supported by IPv6Fragment.
+func (b IPv6Fragment) Checksum() uint16 {
+ panic("not supported")
+}
+
+// SourceAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) SourceAddress() tcpip.Address {
+ panic("not supported")
+}
+
+// DestinationAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) DestinationAddress() tcpip.Address {
+ panic("not supported")
+}
+
+// SetSourceAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetSourceAddress(tcpip.Address) {
+ panic("not supported")
+}
+
+// SetDestinationAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) {
+ panic("not supported")
+}
+
+// SetChecksum is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetChecksum(uint16) {
+ panic("not supported")
+}
+
+// TOS is not supported by IPv6Fragment.
+func (b IPv6Fragment) TOS() (uint8, uint32) {
+ panic("not supported")
+}
+
+// SetTOS is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetTOS(t uint8, l uint32) {
+ panic("not supported")
+}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
new file mode 100644
index 000000000..0cd89b992
--- /dev/null
+++ b/pkg/tcpip/header/tcp.go
@@ -0,0 +1,543 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "github.com/google/btree"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+// These constants are the offsets of the respective fields in the TCP header.
+const (
+ TCPSrcPortOffset = 0
+ TCPDstPortOffset = 2
+ TCPSeqNumOffset = 4
+ TCPAckNumOffset = 8
+ TCPDataOffset = 12
+ TCPFlagsOffset = 13
+ TCPWinSizeOffset = 14
+ TCPChecksumOffset = 16
+ TCPUrgentPtrOffset = 18
+)
+
+const (
+ // MaxWndScale is maximum allowed window scaling, as described in
+ // RFC 1323, section 2.3, page 11.
+ MaxWndScale = 14
+
+ // TCPMaxSACKBlocks is the maximum number of SACK blocks that can
+ // be encoded in a TCP option field.
+ TCPMaxSACKBlocks = 4
+)
+
+// Flags that may be set in a TCP segment.
+const (
+ TCPFlagFin = 1 << iota
+ TCPFlagSyn
+ TCPFlagRst
+ TCPFlagPsh
+ TCPFlagAck
+ TCPFlagUrg
+)
+
+// Options that may be present in a TCP segment.
+const (
+ TCPOptionEOL = 0
+ TCPOptionNOP = 1
+ TCPOptionMSS = 2
+ TCPOptionWS = 3
+ TCPOptionTS = 8
+ TCPOptionSACKPermitted = 4
+ TCPOptionSACK = 5
+)
+
+// TCPFields contains the fields of a TCP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type TCPFields struct {
+ // SrcPort is the "source port" field of a TCP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a TCP packet.
+ DstPort uint16
+
+ // SeqNum is the "sequence number" field of a TCP packet.
+ SeqNum uint32
+
+ // AckNum is the "acknowledgement number" field of a TCP packet.
+ AckNum uint32
+
+ // DataOffset is the "data offset" field of a TCP packet.
+ DataOffset uint8
+
+ // Flags is the "flags" field of a TCP packet.
+ Flags uint8
+
+ // WindowSize is the "window size" field of a TCP packet.
+ WindowSize uint16
+
+ // Checksum is the "checksum" field of a TCP packet.
+ Checksum uint16
+
+ // UrgentPointer is the "urgent pointer" field of a TCP packet.
+ UrgentPointer uint16
+}
+
+// TCPSynOptions is used to return the parsed TCP Options in a syn
+// segment.
+type TCPSynOptions struct {
+ // MSS is the maximum segment size provided by the peer in the SYN.
+ MSS uint16
+
+ // WS is the window scale option provided by the peer in the SYN.
+ //
+ // Set to -1 if no window scale option was provided.
+ WS int
+
+ // TS is true if the timestamp option was provided in the syn/syn-ack.
+ TS bool
+
+ // TSVal is the value of the TSVal field in the timestamp option.
+ TSVal uint32
+
+ // TSEcr is the value of the TSEcr field in the timestamp option.
+ TSEcr uint32
+
+ // SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK.
+ SACKPermitted bool
+}
+
+// SACKBlock represents a single contiguous SACK block.
+//
+// +stateify savable
+type SACKBlock struct {
+ // Start indicates the lowest sequence number in the block.
+ Start seqnum.Value
+
+ // End indicates the sequence number immediately following the last
+ // sequence number of this block.
+ End seqnum.Value
+}
+
+// Less returns true if r.Start < b.Start.
+func (r SACKBlock) Less(b btree.Item) bool {
+ return r.Start.LessThan(b.(SACKBlock).Start)
+}
+
+// Contains returns true if b is completely contained in r.
+func (r SACKBlock) Contains(b SACKBlock) bool {
+ return r.Start.LessThanEq(b.Start) && b.End.LessThanEq(r.End)
+}
+
+// TCPOptions are used to parse and cache the TCP segment options for a non
+// syn/syn-ack segment.
+//
+// +stateify savable
+type TCPOptions struct {
+ // TS is true if the TimeStamp option is enabled.
+ TS bool
+
+ // TSVal is the value in the TSVal field of the segment.
+ TSVal uint32
+
+ // TSEcr is the value in the TSEcr field of the segment.
+ TSEcr uint32
+
+ // SACKBlocks are the SACK blocks specified in the segment.
+ SACKBlocks []SACKBlock
+}
+
+// TCP represents a TCP header stored in a byte array.
+type TCP []byte
+
+const (
+ // TCPMinimumSize is the minimum size of a valid TCP packet.
+ TCPMinimumSize = 20
+
+ // TCPOptionsMaximumSize is the maximum size of TCP options.
+ TCPOptionsMaximumSize = 40
+
+ // TCPHeaderMaximumSize is the maximum header size of a TCP packet.
+ TCPHeaderMaximumSize = TCPMinimumSize + TCPOptionsMaximumSize
+
+ // TCPProtocolNumber is TCP's transport protocol number.
+ TCPProtocolNumber tcpip.TransportProtocolNumber = 6
+)
+
+// SourcePort returns the "source port" field of the tcp header.
+func (b TCP) SourcePort() uint16 {
+ return binary.BigEndian.Uint16(b[TCPSrcPortOffset:])
+}
+
+// DestinationPort returns the "destination port" field of the tcp header.
+func (b TCP) DestinationPort() uint16 {
+ return binary.BigEndian.Uint16(b[TCPDstPortOffset:])
+}
+
+// SequenceNumber returns the "sequence number" field of the tcp header.
+func (b TCP) SequenceNumber() uint32 {
+ return binary.BigEndian.Uint32(b[TCPSeqNumOffset:])
+}
+
+// AckNumber returns the "ack number" field of the tcp header.
+func (b TCP) AckNumber() uint32 {
+ return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
+}
+
+// DataOffset returns the "data offset" field of the tcp header.
+func (b TCP) DataOffset() uint8 {
+ return (b[TCPDataOffset] >> 4) * 4
+}
+
+// Payload returns the data in the tcp packet.
+func (b TCP) Payload() []byte {
+ return b[b.DataOffset():]
+}
+
+// Flags returns the flags field of the tcp header.
+func (b TCP) Flags() uint8 {
+ return b[TCPFlagsOffset]
+}
+
+// WindowSize returns the "window size" field of the tcp header.
+func (b TCP) WindowSize() uint16 {
+ return binary.BigEndian.Uint16(b[TCPWinSizeOffset:])
+}
+
+// Checksum returns the "checksum" field of the tcp header.
+func (b TCP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
+}
+
+// SetSourcePort sets the "source port" field of the tcp header.
+func (b TCP) SetSourcePort(port uint16) {
+ binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
+}
+
+// SetDestinationPort sets the "destination port" field of the tcp header.
+func (b TCP) SetDestinationPort(port uint16) {
+ binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port)
+}
+
+// SetChecksum sets the checksum field of the tcp header.
+func (b TCP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum)
+}
+
+// CalculateChecksum calculates the checksum of the tcp segment.
+// partialChecksum is the checksum of the network-layer pseudo-header
+// and the checksum of the segment data.
+func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
+ // Calculate the rest of the checksum.
+ return Checksum(b[:b.DataOffset()], partialChecksum)
+}
+
+// Options returns a slice that holds the unparsed TCP options in the segment.
+func (b TCP) Options() []byte {
+ return b[TCPMinimumSize:b.DataOffset()]
+}
+
+// ParsedOptions returns a TCPOptions structure which parses and caches the TCP
+// option values in the TCP segment. NOTE: Invoking this function repeatedly is
+// expensive as it reparses the options on each invocation.
+func (b TCP) ParsedOptions() TCPOptions {
+ return ParseTCPOptions(b.Options())
+}
+
+func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) {
+ binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq)
+ binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack)
+ b[TCPFlagsOffset] = flags
+ binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
+}
+
+// Encode encodes all the fields of the tcp header.
+func (b TCP) Encode(t *TCPFields) {
+ b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
+ binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], t.SrcPort)
+ binary.BigEndian.PutUint16(b[TCPDstPortOffset:], t.DstPort)
+ b[TCPDataOffset] = (t.DataOffset / 4) << 4
+ binary.BigEndian.PutUint16(b[TCPChecksumOffset:], t.Checksum)
+ binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], t.UrgentPointer)
+}
+
+// EncodePartial updates a subset of the fields of the tcp header. It is useful
+// in cases when similar segments are produced.
+func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) {
+ // Add the total length and "flags" field contributions to the checksum.
+ // We don't use the flags field directly from the header because it's a
+ // one-byte field with an odd offset, so it would be accounted for
+ // incorrectly by the Checksum routine.
+ tmp := make([]byte, 4)
+ binary.BigEndian.PutUint16(tmp, length)
+ binary.BigEndian.PutUint16(tmp[2:], uint16(flags))
+ checksum := Checksum(tmp, partialChecksum)
+
+ // Encode the passed-in fields.
+ b.encodeSubset(seqnum, acknum, flags, rcvwnd)
+
+ // Add the contributions of the passed-in fields to the checksum.
+ checksum = Checksum(b[TCPSeqNumOffset:TCPSeqNumOffset+8], checksum)
+ checksum = Checksum(b[TCPWinSizeOffset:TCPWinSizeOffset+2], checksum)
+
+ // Encode the checksum.
+ b.SetChecksum(^checksum)
+}
+
+// ParseSynOptions parses the options received in a SYN segment and returns the
+// relevant ones. opts should point to the option part of the TCP Header.
+func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
+ limit := len(opts)
+
+ synOpts := TCPSynOptions{
+ // Per RFC 1122, page 85: "If an MSS option is not received at
+ // connection setup, TCP MUST assume a default send MSS of 536."
+ MSS: 536,
+ // If no window scale option is specified, WS in options is
+ // returned as -1; this is because the absence of the option
+ // indicates that the we cannot use window scaling on the
+ // receive end either.
+ WS: -1,
+ }
+
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case TCPOptionEOL:
+ i = limit
+ case TCPOptionNOP:
+ i++
+ case TCPOptionMSS:
+ if i+4 > limit || opts[i+1] != 4 {
+ return synOpts
+ }
+ mss := uint16(opts[i+2])<<8 | uint16(opts[i+3])
+ if mss == 0 {
+ return synOpts
+ }
+ synOpts.MSS = mss
+ i += 4
+
+ case TCPOptionWS:
+ if i+3 > limit || opts[i+1] != 3 {
+ return synOpts
+ }
+ ws := int(opts[i+2])
+ if ws > MaxWndScale {
+ ws = MaxWndScale
+ }
+ synOpts.WS = ws
+ i += 3
+
+ case TCPOptionTS:
+ if i+10 > limit || opts[i+1] != 10 {
+ return synOpts
+ }
+ synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:])
+ if isAck {
+ // If the segment is a SYN-ACK then store the Timestamp Echo Reply
+ // in the segment.
+ synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:])
+ }
+ synOpts.TS = true
+ i += 10
+ case TCPOptionSACKPermitted:
+ if i+2 > limit || opts[i+1] != 2 {
+ return synOpts
+ }
+ synOpts.SACKPermitted = true
+ i += 2
+
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ return synOpts
+ }
+ l := int(opts[i+1])
+ // If the length is incorrect or if l+i overflows the
+ // total options length then return false.
+ if l < 2 || i+l > limit {
+ return synOpts
+ }
+ i += l
+ }
+ }
+
+ return synOpts
+}
+
+// ParseTCPOptions extracts and stores all known options in the provided byte
+// slice in a TCPOptions structure.
+func ParseTCPOptions(b []byte) TCPOptions {
+ opts := TCPOptions{}
+ limit := len(b)
+ for i := 0; i < limit; {
+ switch b[i] {
+ case TCPOptionEOL:
+ i = limit
+ case TCPOptionNOP:
+ i++
+ case TCPOptionTS:
+ if i+10 > limit || (b[i+1] != 10) {
+ return opts
+ }
+ opts.TS = true
+ opts.TSVal = binary.BigEndian.Uint32(b[i+2:])
+ opts.TSEcr = binary.BigEndian.Uint32(b[i+6:])
+ i += 10
+ case TCPOptionSACK:
+ if i+2 > limit {
+ // Malformed SACK block, just return and stop parsing.
+ return opts
+ }
+ sackOptionLen := int(b[i+1])
+ if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
+ // Malformed SACK block, just return and stop parsing.
+ return opts
+ }
+ numBlocks := (sackOptionLen - 2) / 8
+ opts.SACKBlocks = []SACKBlock{}
+ for j := 0; j < numBlocks; j++ {
+ start := binary.BigEndian.Uint32(b[i+2+j*8:])
+ end := binary.BigEndian.Uint32(b[i+2+j*8+4:])
+ opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{
+ Start: seqnum.Value(start),
+ End: seqnum.Value(end),
+ })
+ }
+ i += sackOptionLen
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ return opts
+ }
+ l := int(b[i+1])
+ // If the length is incorrect or if l+i overflows the
+ // total options length then return false.
+ if l < 2 || i+l > limit {
+ return opts
+ }
+ i += l
+ }
+ }
+ return opts
+}
+
+// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in
+// the supplied buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer.
+func EncodeMSSOption(mss uint32, b []byte) int {
+ // mssOptionSize is the number of bytes in a valid MSS option.
+ const mssOptionSize = 4
+
+ if len(b) < mssOptionSize {
+ return 0
+ }
+ b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss)
+ return mssOptionSize
+}
+
+// EncodeWSOption encodes the WS TCP option with the WS value in the
+// provided buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer.
+func EncodeWSOption(ws int, b []byte) int {
+ if len(b) < 3 {
+ return 0
+ }
+ b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws)
+ return int(b[1])
+}
+
+// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp
+// option into the provided buffer. If the buffer is smaller than expected it
+// just returns without encoding anything. It returns the number of bytes
+// written to the provided buffer.
+func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
+ if len(b) < 10 {
+ return 0
+ }
+ b[0], b[1] = TCPOptionTS, 10
+ binary.BigEndian.PutUint32(b[2:], tsVal)
+ binary.BigEndian.PutUint32(b[6:], tsEcr)
+ return int(b[1])
+}
+
+// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided
+// buffer. If the buffer is smaller than required it just returns without
+// encoding anything. It returns the number of bytes written to the provided
+// buffer.
+func EncodeSACKPermittedOption(b []byte) int {
+ if len(b) < 2 {
+ return 0
+ }
+
+ b[0], b[1] = TCPOptionSACKPermitted, 2
+ return int(b[1])
+}
+
+// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block
+// in the provided slice. It tries to fit in as many blocks as possible based on
+// number of bytes available in the provided buffer. It returns the number of
+// bytes written to the provided buffer.
+func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int {
+ if len(sackBlocks) == 0 {
+ return 0
+ }
+ l := len(sackBlocks)
+ if l > TCPMaxSACKBlocks {
+ l = TCPMaxSACKBlocks
+ }
+ if ll := (len(b) - 2) / 8; ll < l {
+ l = ll
+ }
+ if l == 0 {
+ // There is not enough space in the provided buffer to add
+ // any SACK blocks.
+ return 0
+ }
+ b[0] = TCPOptionSACK
+ b[1] = byte(l*8 + 2)
+ for i := 0; i < l; i++ {
+ binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start))
+ binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End))
+ }
+ return int(b[1])
+}
+
+// EncodeNOP adds an explicit NOP to the option list.
+func EncodeNOP(b []byte) int {
+ if len(b) == 0 {
+ return 0
+ }
+ b[0] = TCPOptionNOP
+ return 1
+}
+
+// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align
+// the option buffer. It adds padding bytes after the offset specified and
+// returns the number of padding bytes added. The passed in options slice
+// must have space for the padding bytes.
+func AddTCPOptionPadding(options []byte, offset int) int {
+ paddingToAdd := -offset & 3
+ // Now add any padding bytes that might be required to quad align the
+ // options.
+ for i := offset; i < offset+paddingToAdd; i++ {
+ options[i] = TCPOptionNOP
+ }
+ return paddingToAdd
+}
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
new file mode 100644
index 000000000..2205fec18
--- /dev/null
+++ b/pkg/tcpip/header/udp.go
@@ -0,0 +1,110 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ udpSrcPort = 0
+ udpDstPort = 2
+ udpLength = 4
+ udpChecksum = 6
+)
+
+// UDPFields contains the fields of a UDP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type UDPFields struct {
+ // SrcPort is the "source port" field of a UDP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a UDP packet.
+ DstPort uint16
+
+ // Length is the "length" field of a UDP packet.
+ Length uint16
+
+ // Checksum is the "checksum" field of a UDP packet.
+ Checksum uint16
+}
+
+// UDP represents a UDP header stored in a byte array.
+type UDP []byte
+
+const (
+ // UDPMinimumSize is the minimum size of a valid UDP packet.
+ UDPMinimumSize = 8
+
+ // UDPProtocolNumber is UDP's transport protocol number.
+ UDPProtocolNumber tcpip.TransportProtocolNumber = 17
+)
+
+// SourcePort returns the "source port" field of the udp header.
+func (b UDP) SourcePort() uint16 {
+ return binary.BigEndian.Uint16(b[udpSrcPort:])
+}
+
+// DestinationPort returns the "destination port" field of the udp header.
+func (b UDP) DestinationPort() uint16 {
+ return binary.BigEndian.Uint16(b[udpDstPort:])
+}
+
+// Length returns the "length" field of the udp header.
+func (b UDP) Length() uint16 {
+ return binary.BigEndian.Uint16(b[udpLength:])
+}
+
+// Payload returns the data contained in the UDP datagram.
+func (b UDP) Payload() []byte {
+ return b[UDPMinimumSize:]
+}
+
+// Checksum returns the "checksum" field of the udp header.
+func (b UDP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[udpChecksum:])
+}
+
+// SetSourcePort sets the "source port" field of the udp header.
+func (b UDP) SetSourcePort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], port)
+}
+
+// SetDestinationPort sets the "destination port" field of the udp header.
+func (b UDP) SetDestinationPort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpDstPort:], port)
+}
+
+// SetChecksum sets the "checksum" field of the udp header.
+func (b UDP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
+}
+
+// CalculateChecksum calculates the checksum of the udp packet, given the
+// checksum of the network-layer pseudo-header and the checksum of the payload.
+func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
+ // Calculate the rest of the checksum.
+ return Checksum(b[:UDPMinimumSize], partialChecksum)
+}
+
+// Encode encodes all the fields of the udp header.
+func (b UDP) Encode(u *UDPFields) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
+ binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
+ binary.BigEndian.PutUint16(b[udpLength:], u.Length)
+ binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
new file mode 100644
index 000000000..1f889c2a0
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -0,0 +1,372 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package fdbased provides the implemention of data-link layer endpoints
+// backed by boundary-preserving file descriptors (e.g., TUN devices,
+// seqpacket/datagram sockets).
+//
+// FD based endpoints can be used in the networking stack by calling New() to
+// create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+package fdbased
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// linkDispatcher reads packets from the link FD and dispatches them to the
+// NetworkDispatcher.
+type linkDispatcher interface {
+ dispatch() (bool, *tcpip.Error)
+}
+
+// PacketDispatchMode are the various supported methods of receiving and
+// dispatching packets from the underlying FD.
+type PacketDispatchMode int
+
+const (
+ // Readv is the default dispatch mode and is the least performant of the
+ // dispatch options but the one that is supported by all underlying FD
+ // types.
+ Readv PacketDispatchMode = iota
+ // RecvMMsg enables use of recvmmsg() syscall instead of readv() to
+ // read inbound packets. This reduces # of syscalls needed to process
+ // packets.
+ //
+ // NOTE: recvmmsg() is only supported for sockets, so if the underlying
+ // FD is not a socket then the code will still fall back to the readv()
+ // path.
+ RecvMMsg
+ // PacketMMap enables use of PACKET_RX_RING to receive packets from the
+ // NIC. PacketMMap requires that the underlying FD be an AF_PACKET. The
+ // primary use-case for this is runsc which uses an AF_PACKET FD to
+ // receive packets from the veth device.
+ PacketMMap
+)
+
+type endpoint struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ mtu uint32
+
+ // hdrSize specifies the link-layer header size. If set to 0, no header
+ // is added/removed; otherwise an ethernet header is used.
+ hdrSize int
+
+ // addr is the address of the endpoint.
+ addr tcpip.LinkAddress
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // closed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ closed func(*tcpip.Error)
+
+ inboundDispatcher linkDispatcher
+ dispatcher stack.NetworkDispatcher
+
+ // packetDispatchMode controls the packet dispatcher used by this
+ // endpoint.
+ packetDispatchMode PacketDispatchMode
+
+ // gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
+ // disabled.
+ gsoMaxSize uint32
+}
+
+// Options specify the details about the fd-based endpoint to be created.
+type Options struct {
+ FD int
+ MTU uint32
+ EthernetHeader bool
+ ClosedFunc func(*tcpip.Error)
+ Address tcpip.LinkAddress
+ SaveRestore bool
+ DisconnectOk bool
+ GSOMaxSize uint32
+ PacketDispatchMode PacketDispatchMode
+ TXChecksumOffload bool
+ RXChecksumOffload bool
+}
+
+// New creates a new fd-based endpoint.
+//
+// Makes fd non-blocking, but does not take ownership of fd, which must remain
+// open for the lifetime of the returned endpoint.
+func New(opts *Options) (tcpip.LinkEndpointID, error) {
+ if err := syscall.SetNonblock(opts.FD, true); err != nil {
+ return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", opts.FD, err)
+ }
+
+ caps := stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ hdrSize := 0
+ if opts.EthernetHeader {
+ hdrSize = header.EthernetMinimumSize
+ caps |= stack.CapabilityResolutionRequired
+ }
+
+ if opts.SaveRestore {
+ caps |= stack.CapabilitySaveRestore
+ }
+
+ if opts.DisconnectOk {
+ caps |= stack.CapabilityDisconnectOk
+ }
+
+ e := &endpoint{
+ fd: opts.FD,
+ mtu: opts.MTU,
+ caps: caps,
+ closed: opts.ClosedFunc,
+ addr: opts.Address,
+ hdrSize: hdrSize,
+ packetDispatchMode: opts.PacketDispatchMode,
+ }
+
+ isSocket, err := isSocketFD(e.fd)
+ if err != nil {
+ return 0, err
+ }
+ if isSocket {
+ if opts.GSOMaxSize != 0 {
+ e.caps |= stack.CapabilityGSO
+ e.gsoMaxSize = opts.GSOMaxSize
+ }
+ }
+ e.inboundDispatcher, err = createInboundDispatcher(e, isSocket)
+ if err != nil {
+ return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+ }
+
+ return stack.RegisterLinkEndpoint(e), nil
+}
+
+func createInboundDispatcher(e *endpoint, isSocket bool) (linkDispatcher, error) {
+ // By default use the readv() dispatcher as it works with all kinds of
+ // FDs (tap/tun/unix domain sockets and af_packet).
+ inboundDispatcher, err := newReadVDispatcher(e.fd, e)
+ if err != nil {
+ return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", e.fd, e, err)
+ }
+
+ if isSocket {
+ switch e.packetDispatchMode {
+ case PacketMMap:
+ inboundDispatcher, err = newPacketMMapDispatcher(e.fd, e)
+ if err != nil {
+ return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", e.fd, e, err)
+ }
+ case RecvMMsg:
+ // If the provided FD is a socket then we optimize
+ // packet reads by using recvmmsg() instead of read() to
+ // read packets in a batch.
+ inboundDispatcher, err = newRecvMMsgDispatcher(e.fd, e)
+ if err != nil {
+ return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", e.fd, e, err)
+ }
+ }
+ }
+ return inboundDispatcher, nil
+}
+
+func isSocketFD(fd int) (bool, error) {
+ var stat syscall.Stat_t
+ if err := syscall.Fstat(fd, &stat); err != nil {
+ return false, fmt.Errorf("syscall.Fstat(%v,...) failed: %v", fd, err)
+ }
+ return (stat.Mode & syscall.S_IFSOCK) == syscall.S_IFSOCK, nil
+}
+
+// Attach launches the goroutine that reads packets from the file descriptor and
+// dispatches them via the provided dispatcher.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ // Link endpoints are not savable. When transportation endpoints are
+ // saved, they stop sending outgoing packets and all incoming packets
+ // are rejected.
+ go e.dispatchLoop() // S/R-SAFE: See above.
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *endpoint) MTU() uint32 {
+ return e.mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength returns the maximum size of the link-layer header.
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// virtioNetHdr is declared in linux/virtio_net.h.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+// These constants are declared in linux/virtio_net.h.
+const (
+ _VIRTIO_NET_HDR_F_NEEDS_CSUM = 1
+
+ _VIRTIO_NET_HDR_GSO_TCPV4 = 1
+ _VIRTIO_NET_HDR_GSO_TCPV6 = 4
+)
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ }
+
+ if e.Capabilities()&stack.CapabilityGSO != 0 {
+ vnetHdr := virtioNetHdr{}
+ vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
+ if gso != nil {
+ vnetHdr.hdrLen = uint16(hdr.UsedLength())
+ if gso.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
+ vnetHdr.csumOffset = gso.CsumOffset
+ }
+ if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS {
+ switch gso.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", gso.Type))
+ }
+ vnetHdr.gsoSize = gso.MSS
+ }
+ }
+
+ return rawfile.NonBlockingWrite3(e.fd, vnetHdrBuf, hdr.View(), payload.ToView())
+ }
+
+ if payload.Size() == 0 {
+ return rawfile.NonBlockingWrite(e.fd, hdr.View())
+ }
+
+ return rawfile.NonBlockingWrite3(e.fd, hdr.View(), payload.ToView(), nil)
+}
+
+// WriteRawPacket writes a raw packet directly to the file descriptor.
+func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
+ return rawfile.NonBlockingWrite(e.fd, packet)
+}
+
+// dispatchLoop reads packets from the file descriptor in a loop and dispatches
+// them to the network stack.
+func (e *endpoint) dispatchLoop() *tcpip.Error {
+ for {
+ cont, err := e.inboundDispatcher.dispatch()
+ if err != nil || !cont {
+ if e.closed != nil {
+ e.closed(err)
+ }
+ return err
+ }
+ }
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ return e.gsoMaxSize
+}
+
+// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes
+// to the FD, but does not read from it. All reads come from injected packets.
+type InjectableEndpoint struct {
+ endpoint
+
+ dispatcher stack.NetworkDispatcher
+}
+
+// Attach saves the stack network-layer dispatcher for use later when packets
+// are injected.
+func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// Inject injects an inbound packet.
+func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+}
+
+// NewInjectable creates a new fd-based InjectableEndpoint.
+func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) {
+ syscall.SetNonblock(fd, true)
+
+ e := &InjectableEndpoint{endpoint: endpoint{
+ fd: fd,
+ mtu: mtu,
+ caps: capabilities,
+ }}
+
+ return stack.RegisterLinkEndpoint(e), e
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
new file mode 100644
index 000000000..97a477b61
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
@@ -0,0 +1,32 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package fdbased
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{}))
+
+func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ sh.Data = uintptr(unsafe.Pointer(hdr))
+ sh.Len = virtioNetHdrSize
+ sh.Cap = virtioNetHdrSize
+ return
+}
diff --git a/pkg/tcpip/link/fdbased/fdbased_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go
new file mode 100755
index 000000000..0555db528
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package fdbased
+
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
new file mode 100644
index 000000000..6b7f2a185
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -0,0 +1,25 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux !amd64
+
+package fdbased
+
+import "gvisor.googlesource.com/gvisor/pkg/tcpip"
+
+// Stubbed out version for non-linux/non-amd64 platforms.
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, *tcpip.Error) {
+ return nil, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go
new file mode 100644
index 000000000..1c2d8c468
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap_amd64.go
@@ -0,0 +1,194 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64
+
+package fdbased
+
+import (
+ "encoding/binary"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+)
+
+const (
+ tPacketAlignment = uintptr(16)
+ tpStatusKernel = 0
+ tpStatusUser = 1
+ tpStatusCopy = 2
+ tpStatusLosing = 4
+)
+
+// We overallocate the frame size to accommodate space for the
+// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
+//
+// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
+//
+// NOTE:
+// Frames need to be aligned at 16 byte boundaries.
+// BlockSize needs to be page aligned.
+//
+// For details see PACKET_MMAP setting constraints in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+const (
+ tpFrameSize = 65536 + 128
+ tpBlockSize = tpFrameSize * 32
+ tpBlockNR = 1
+ tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
+)
+
+// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
+// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
+func tPacketAlign(v uintptr) uintptr {
+ return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
+}
+
+// tPacketReq is the tpacket_req structure as described in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+type tPacketReq struct {
+ tpBlockSize uint32
+ tpBlockNR uint32
+ tpFrameSize uint32
+ tpFrameNR uint32
+}
+
+// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
+type tPacketHdr []byte
+
+const (
+ tpStatusOffset = 0
+ tpLenOffset = 8
+ tpSnapLenOffset = 12
+ tpMacOffset = 16
+ tpNetOffset = 18
+ tpSecOffset = 20
+ tpUSecOffset = 24
+)
+
+func (t tPacketHdr) tpLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpLenOffset:])
+}
+
+func (t tPacketHdr) tpSnapLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
+}
+
+func (t tPacketHdr) tpMac() uint16 {
+ return binary.LittleEndian.Uint16(t[tpMacOffset:])
+}
+
+func (t tPacketHdr) tpNet() uint16 {
+ return binary.LittleEndian.Uint16(t[tpNetOffset:])
+}
+
+func (t tPacketHdr) tpSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSecOffset:])
+}
+
+func (t tPacketHdr) tpUSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpUSecOffset:])
+}
+
+func (t tPacketHdr) Payload() []byte {
+ return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
+}
+
+// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
+// See: mmap_amd64_unsafe.go for implementation details.
+type packetMMapDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // ringBuffer is only used when PacketMMap dispatcher is used and points
+ // to the start of the mmapped PACKET_RX_RING buffer.
+ ringBuffer []byte
+
+ // ringOffset is the current offset into the ring buffer where the next
+ // inbound packet will be placed by the kernel.
+ ringOffset int
+}
+
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+ hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ for hdr.tpStatus()&tpStatusUser == 0 {
+ event := rawfile.PollEvent{
+ FD: int32(d.fd),
+ Events: unix.POLLIN | unix.POLLERR,
+ }
+ if _, errno := rawfile.BlockingPoll(&event, 1, -1); errno != 0 {
+ if errno == syscall.EINTR {
+ continue
+ }
+ return nil, rawfile.TranslateErrno(errno)
+ }
+ if hdr.tpStatus()&tpStatusCopy != 0 {
+ // This frame is truncated so skip it after flipping the
+ // buffer to the kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ continue
+ }
+ }
+
+ // Copy out the packet from the mmapped frame to a locally owned buffer.
+ pkt := make([]byte, hdr.tpSnapLen())
+ copy(pkt, hdr.Payload())
+ // Release packet to kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ return pkt, nil
+}
+
+// dispatch reads packets from an mmaped ring buffer and dispatches them to the
+// network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+ pkt, err := d.readMMappedPacket()
+ if err != nil {
+ return false, err
+ }
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(pkt)
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(pkt) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ pkt = pkt[d.e.hdrSize:]
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ return true, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go
new file mode 100644
index 000000000..47cb1d1cc
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go
@@ -0,0 +1,84 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64
+
+package fdbased
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// tPacketHdrlen is the TPACKET_HDRLEN variable defined in <linux/if_packet.h>.
+var tPacketHdrlen = tPacketAlign(unsafe.Sizeof(tPacketHdr{}) + unsafe.Sizeof(syscall.RawSockaddrLinklayer{}))
+
+// tpStatus returns the frame status field.
+// The status is concurrently updated by the kernel as a result we must
+// use atomic operations to prevent races.
+func (t tPacketHdr) tpStatus() uint32 {
+ hdr := unsafe.Pointer(&t[0])
+ statusPtr := unsafe.Pointer(uintptr(hdr) + uintptr(tpStatusOffset))
+ return atomic.LoadUint32((*uint32)(statusPtr))
+}
+
+// setTPStatus set's the frame status to the provided status.
+// The status is concurrently updated by the kernel as a result we must
+// use atomic operations to prevent races.
+func (t tPacketHdr) setTPStatus(status uint32) {
+ hdr := unsafe.Pointer(&t[0])
+ statusPtr := unsafe.Pointer(uintptr(hdr) + uintptr(tpStatusOffset))
+ atomic.StoreUint32((*uint32)(statusPtr), status)
+}
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &packetMMapDispatcher{
+ fd: fd,
+ e: e,
+ }
+ pageSize := unix.Getpagesize()
+ if tpBlockSize%pageSize != 0 {
+ return nil, fmt.Errorf("tpBlockSize: %d is not page aligned, pagesize: %d", tpBlockSize, pageSize)
+ }
+ tReq := tPacketReq{
+ tpBlockSize: uint32(tpBlockSize),
+ tpBlockNR: uint32(tpBlockNR),
+ tpFrameSize: uint32(tpFrameSize),
+ tpFrameNR: uint32(tpFrameNR),
+ }
+ // Setup PACKET_RX_RING.
+ if err := setsockopt(d.fd, syscall.SOL_PACKET, syscall.PACKET_RX_RING, unsafe.Pointer(&tReq), unsafe.Sizeof(tReq)); err != nil {
+ return nil, fmt.Errorf("failed to enable PACKET_RX_RING: %v", err)
+ }
+ // Let's mmap the blocks.
+ sz := tpBlockSize * tpBlockNR
+ buf, err := syscall.Mmap(d.fd, 0, sz, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
+ if err != nil {
+ return nil, fmt.Errorf("syscall.Mmap(...,0, %v, ...) failed = %v", sz, err)
+ }
+ d.ringBuffer = buf
+ return d, nil
+}
+
+func setsockopt(fd, level, name int, val unsafe.Pointer, vallen uintptr) error {
+ if _, _, errno := syscall.Syscall6(syscall.SYS_SETSOCKOPT, uintptr(fd), uintptr(level), uintptr(name), uintptr(val), vallen, 0); errno != 0 {
+ return error(errno)
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
new file mode 100644
index 000000000..1ae0e3359
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -0,0 +1,309 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package fdbased
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// BufConfig defines the shape of the vectorised view used to read packets from the NIC.
+var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768}
+
+// readVDispatcher uses readv() system call to read inbound packets and
+// dispatches them.
+type readVDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // views are the actual buffers that hold the packet contents.
+ views []buffer.View
+
+ // iovecs are initialized with base pointers/len of the corresponding
+ // entries in the views defined above, except when GSO is enabled then
+ // the first iovec points to a buffer for the vnet header which is
+ // stripped before the views are passed up the stack for further
+ // processing.
+ iovecs []syscall.Iovec
+}
+
+func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &readVDispatcher{fd: fd, e: e}
+ d.views = make([]buffer.View, len(BufConfig))
+ iovLen := len(BufConfig)
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ iovLen++
+ }
+ d.iovecs = make([]syscall.Iovec, iovLen)
+ return d, nil
+}
+
+func (d *readVDispatcher) allocateViews(bufConfig []int) {
+ var vnetHdr [virtioNetHdrSize]byte
+ vnetHdrOff := 0
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ // The kernel adds virtioNetHdr before each packet, but
+ // we don't use it, so so we allocate a buffer for it,
+ // add it in iovecs but don't add it in a view.
+ d.iovecs[0] = syscall.Iovec{
+ Base: &vnetHdr[0],
+ Len: uint64(virtioNetHdrSize),
+ }
+ vnetHdrOff++
+ }
+ for i := 0; i < len(bufConfig); i++ {
+ if d.views[i] != nil {
+ break
+ }
+ b := buffer.NewView(bufConfig[i])
+ d.views[i] = b
+ d.iovecs[i+vnetHdrOff] = syscall.Iovec{
+ Base: &b[0],
+ Len: uint64(len(b)),
+ }
+ }
+}
+
+func (d *readVDispatcher) capViews(n int, buffers []int) int {
+ c := 0
+ for i, s := range buffers {
+ c += s
+ if c >= n {
+ d.views[i].CapLength(s - (c - n))
+ return i + 1
+ }
+ }
+ return len(buffers)
+}
+
+// dispatch reads one packet from the file descriptor and dispatches it.
+func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
+ d.allocateViews(BufConfig)
+
+ n, err := rawfile.BlockingReadv(d.fd, d.iovecs)
+ if err != nil {
+ return false, err
+ }
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ // Skip virtioNetHdr which is added before each packet, it
+ // isn't used and it isn't in a view.
+ n -= virtioNetHdrSize
+ }
+ if n <= d.e.hdrSize {
+ return false, nil
+ }
+
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(d.views[0])
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(d.views[0]) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ used := d.capViews(n, BufConfig)
+ vv := buffer.NewVectorisedView(n, d.views[:used])
+ vv.TrimFront(d.e.hdrSize)
+
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+
+ // Prepare e.views for another packet: release used views.
+ for i := 0; i < used; i++ {
+ d.views[i] = nil
+ }
+
+ return true, nil
+}
+
+// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
+// dispatches them.
+type recvMMsgDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // views is an array of array of buffers that contain packet contents.
+ views [][]buffer.View
+
+ // iovecs is an array of array of iovec records where each iovec base
+ // pointer and length are initialzed to the corresponding view above,
+ // except when GSO is neabled then the first iovec in each array of
+ // iovecs points to a buffer for the vnet header which is stripped
+ // before the views are passed up the stack for further processing.
+ iovecs [][]syscall.Iovec
+
+ // msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to
+ // reference an array of iovecs in the iovecs field defined above. This
+ // array is passed as the parameter to recvmmsg call to retrieve
+ // potentially more than 1 packet per syscall.
+ msgHdrs []rawfile.MMsgHdr
+}
+
+const (
+ // MaxMsgsPerRecv is the maximum number of packets we want to retrieve
+ // in a single RecvMMsg call.
+ MaxMsgsPerRecv = 8
+)
+
+func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &recvMMsgDispatcher{
+ fd: fd,
+ e: e,
+ }
+ d.views = make([][]buffer.View, MaxMsgsPerRecv)
+ for i := range d.views {
+ d.views[i] = make([]buffer.View, len(BufConfig))
+ }
+ d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv)
+ iovLen := len(BufConfig)
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ // virtioNetHdr is prepended before each packet.
+ iovLen++
+ }
+ for i := range d.iovecs {
+ d.iovecs[i] = make([]syscall.Iovec, iovLen)
+ }
+ d.msgHdrs = make([]rawfile.MMsgHdr, MaxMsgsPerRecv)
+ for i := range d.msgHdrs {
+ d.msgHdrs[i].Msg.Iov = &d.iovecs[i][0]
+ d.msgHdrs[i].Msg.Iovlen = uint64(iovLen)
+ }
+ return d, nil
+}
+
+func (d *recvMMsgDispatcher) capViews(k, n int, buffers []int) int {
+ c := 0
+ for i, s := range buffers {
+ c += s
+ if c >= n {
+ d.views[k][i].CapLength(s - (c - n))
+ return i + 1
+ }
+ }
+ return len(buffers)
+}
+
+func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) {
+ for k := 0; k < len(d.views); k++ {
+ var vnetHdr [virtioNetHdrSize]byte
+ vnetHdrOff := 0
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ // The kernel adds virtioNetHdr before each packet, but
+ // we don't use it, so so we allocate a buffer for it,
+ // add it in iovecs but don't add it in a view.
+ d.iovecs[k][0] = syscall.Iovec{
+ Base: &vnetHdr[0],
+ Len: uint64(virtioNetHdrSize),
+ }
+ vnetHdrOff++
+ }
+ for i := 0; i < len(bufConfig); i++ {
+ if d.views[k][i] != nil {
+ break
+ }
+ b := buffer.NewView(bufConfig[i])
+ d.views[k][i] = b
+ d.iovecs[k][i+vnetHdrOff] = syscall.Iovec{
+ Base: &b[0],
+ Len: uint64(len(b)),
+ }
+ }
+ }
+}
+
+// recvMMsgDispatch reads more than one packet at a time from the file
+// descriptor and dispatches it.
+func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
+ d.allocateViews(BufConfig)
+
+ nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs)
+ if err != nil {
+ return false, err
+ }
+ // Process each of received packets.
+ for k := 0; k < nMsgs; k++ {
+ n := int(d.msgHdrs[k].Len)
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ n -= virtioNetHdrSize
+ }
+ if n <= d.e.hdrSize {
+ return false, nil
+ }
+
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(d.views[k][0])
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(d.views[k][0]) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ used := d.capViews(k, int(n), BufConfig)
+ vv := buffer.NewVectorisedView(int(n), d.views[k][:used])
+ vv.TrimFront(d.e.hdrSize)
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+
+ // Prepare e.views for another packet: release used views.
+ for i := 0; i < used; i++ {
+ d.views[k][i] = nil
+ }
+ }
+
+ for k := 0; k < nMsgs; k++ {
+ d.msgHdrs[k].Len = 0
+ }
+
+ return true, nil
+}
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
new file mode 100644
index 000000000..2c1148123
--- /dev/null
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -0,0 +1,87 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package loopback provides the implemention of loopback data-link layer
+// endpoints. Such endpoints just turn outbound packets into inbound ones.
+//
+// Loopback endpoints can be used in the networking stack by calling New() to
+// create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+package loopback
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ dispatcher stack.NetworkDispatcher
+}
+
+// New creates a new loopback endpoint. This link-layer endpoint just turns
+// outbound packets into inbound packets.
+func New() tcpip.LinkEndpointID {
+ return stack.RegisterLinkEndpoint(&endpoint{})
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network-
+// layer dispatcher for later use when packets need to be dispatched.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns a constant that matches the
+// linux loopback interface.
+func (*endpoint) MTU() uint32 {
+ return 65536
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. Loopback advertises
+// itself as supporting checksum offload, but in reality it's just omitted.
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return stack.CapabilityRXChecksumOffload | stack.CapabilityTXChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityLoopback
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. Given that the
+// loopback interface doesn't have a header, it just returns 0.
+func (*endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*endpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
+// packets to the network-layer dispatcher.
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+
+ // Because we're immediately turning around and writing the packet back to the
+ // rx path, we intentionally don't preserve the remote and local link
+ // addresses from the stack.Route we're passed.
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+
+ return nil
+}
diff --git a/pkg/tcpip/link/loopback/loopback_state_autogen.go b/pkg/tcpip/link/loopback/loopback_state_autogen.go
new file mode 100755
index 000000000..87ec8cfc7
--- /dev/null
+++ b/pkg/tcpip/link/loopback/loopback_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package loopback
+
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
new file mode 100644
index 000000000..b54131573
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the poll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
+ CALL ·callEntersyscallblock(SB)
+ MOVQ fds+0(FP), DI
+ MOVQ nfds+8(FP), SI
+ MOVQ timeout+16(FP), DX
+ MOVQ $0x7, AX // SYS_POLL
+ SYSCALL
+ CMPQ AX, $0xfffffffffffff001
+ JLS ok
+ MOVQ $-1, n+24(FP)
+ NEGQ AX
+ MOVQ AX, err+32(FP)
+ CALL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVQ AX, n+24(FP)
+ MOVQ $0, err+32(FP)
+ CALL ·callExitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
new file mode 100644
index 000000000..c87268610
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64
+// +build go1.12
+// +build !go1.14
+
+// Check go:linkname function signatures when updating Go version.
+
+package rawfile
+
+import (
+ "syscall"
+ _ "unsafe" // for go:linkname
+)
+
+//go:noescape
+func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno)
+
+// Use go:linkname to call into the runtime. As of Go 1.12 this has to
+// be done from Go code so that we make an ABIInternal call to an
+// ABIInternal function; see https://golang.org/issue/27539.
+
+// We need to call both entersyscallblock and exitsyscall this way so
+// that the runtime's check on the stack pointer lines up.
+
+// Note that calling an unexported function in the runtime package is
+// unsafe and this hack is likely to break in future Go releases.
+
+//go:linkname entersyscallblock runtime.entersyscallblock
+func entersyscallblock()
+
+//go:linkname exitsyscall runtime.exitsyscall
+func exitsyscall()
+
+// These forwarding functions must be nosplit because 1) we must
+// disallow preemption between entersyscallblock and exitsyscall, and
+// 2) we have an untyped assembly frame on the stack which can not be
+// grown or moved.
+
+//go:nosplit
+func callEntersyscallblock() {
+ entersyscallblock()
+}
+
+//go:nosplit
+func callExitsyscall() {
+ exitsyscall()
+}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
new file mode 100644
index 000000000..4eab77c74
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,!amd64
+
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// BlockingPoll is just a stub function that forwards to the poll() system call
+// on non-amd64 platforms.
+func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(timeout))
+ return int(n), e
+}
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
new file mode 100644
index 000000000..8bde41637
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const maxErrno = 134
+
+var translations [maxErrno]*tcpip.Error
+
+// TranslateErrno translate an errno from the syscall package into a
+// *tcpip.Error.
+//
+// Valid, but unreconigized errnos will be translated to
+// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
+func TranslateErrno(e syscall.Errno) *tcpip.Error {
+ if err := translations[e]; err != nil {
+ return err
+ }
+ return tcpip.ErrInvalidEndpointState
+}
+
+func addTranslation(host syscall.Errno, trans *tcpip.Error) {
+ if translations[host] != nil {
+ panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host))
+ }
+ translations[host] = trans
+}
+
+func init() {
+ addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress)
+ addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute)
+ addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState)
+ addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting)
+ addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected)
+ addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse)
+ addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress)
+ addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend)
+ addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock)
+ addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused)
+ addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout)
+ addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted)
+ addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired)
+ addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported)
+ addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported)
+ addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected)
+ addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset)
+ addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted)
+ addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong)
+ addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace)
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go
new file mode 100755
index 000000000..662c04444
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package rawfile
+
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
new file mode 100644
index 000000000..86db7a487
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -0,0 +1,182 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package rawfile contains utilities for using the netstack with raw host
+// files on Linux hosts.
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// GetMTU determines the MTU of a network interface device.
+func GetMTU(name string) (uint32, error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ return 0, err
+ }
+
+ defer syscall.Close(fd)
+
+ var ifreq struct {
+ name [16]byte
+ mtu int32
+ _ [20]byte
+ }
+
+ copy(ifreq.name[:], name)
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq)))
+ if errno != 0 {
+ return 0, errno
+ }
+
+ return uint32(ifreq.mtu), nil
+}
+
+// NonBlockingWrite writes the given buffer to a file descriptor. It fails if
+// partial data is written.
+func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
+ var ptr unsafe.Pointer
+ if len(buf) > 0 {
+ ptr = unsafe.Pointer(&buf[0])
+ }
+
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), uintptr(ptr), uintptr(len(buf)))
+ if e != 0 {
+ return TranslateErrno(e)
+ }
+
+ return nil
+}
+
+// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a
+// single syscall. It fails if partial data is written.
+func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
+ // If the is no second buffer, issue a regular write.
+ if len(b2) == 0 {
+ return NonBlockingWrite(fd, b1)
+ }
+
+ // We have two buffers. Build the iovec that represents them and issue
+ // a writev syscall.
+ iovec := [3]syscall.Iovec{
+ {
+ Base: &b1[0],
+ Len: uint64(len(b1)),
+ },
+ {
+ Base: &b2[0],
+ Len: uint64(len(b2)),
+ },
+ }
+ iovecLen := uintptr(2)
+
+ if len(b3) > 0 {
+ iovecLen++
+ iovec[2].Base = &b3[0]
+ iovec[2].Len = uint64(len(b3))
+ }
+
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
+ if e != 0 {
+ return TranslateErrno(e)
+ }
+
+ return nil
+}
+
+// PollEvent represents the pollfd structure passed to a poll() system call.
+type PollEvent struct {
+ FD int32
+ Events int16
+ Revents int16
+}
+
+// BlockingRead reads from a file descriptor that is set up as non-blocking. If
+// no data is available, it will block in a poll() syscall until the file
+// descirptor becomes readable.
+func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
+ if e == 0 {
+ return int(n), nil
+ }
+
+ event := PollEvent{
+ FD: int32(fd),
+ Events: 1, // POLLIN
+ }
+
+ _, e = BlockingPoll(&event, 1, -1)
+ if e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
+
+// BlockingReadv reads from a file descriptor that is set up as non-blocking and
+// stores the data in a list of iovecs buffers. If no data is available, it will
+// block in a poll() syscall until the file descriptor becomes readable.
+func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
+ if e == 0 {
+ return int(n), nil
+ }
+
+ event := PollEvent{
+ FD: int32(fd),
+ Events: 1, // POLLIN
+ }
+
+ _, e = BlockingPoll(&event, 1, -1)
+ if e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
+
+// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
+type MMsgHdr struct {
+ Msg syscall.Msghdr
+ Len uint32
+ _ [4]byte
+}
+
+// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking
+// and stores the received messages in a slice of MMsgHdr structures. If no data
+// is available, it will block in a poll() syscall until the file descriptor
+// becomes readable.
+func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
+ if e == 0 {
+ return int(n), nil
+ }
+
+ event := PollEvent{
+ FD: int32(fd),
+ Events: 1, // POLLIN
+ }
+
+ if _, e := BlockingPoll(&event, 1, -1); e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go
new file mode 100644
index 000000000..c16c19647
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/pcap.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sniffer
+
+import "time"
+
+type pcapHeader struct {
+ // MagicNumber is the file magic number.
+ MagicNumber uint32
+
+ // VersionMajor is the major version number.
+ VersionMajor uint16
+
+ // VersionMinor is the minor version number.
+ VersionMinor uint16
+
+ // Thiszone is the GMT to local correction.
+ Thiszone int32
+
+ // Sigfigs is the accuracy of timestamps.
+ Sigfigs uint32
+
+ // Snaplen is the max length of captured packets, in octets.
+ Snaplen uint32
+
+ // Network is the data link type.
+ Network uint32
+}
+
+const pcapPacketHeaderLen = 16
+
+type pcapPacketHeader struct {
+ // Seconds is the timestamp seconds.
+ Seconds uint32
+
+ // Microseconds is the timestamp microseconds.
+ Microseconds uint32
+
+ // IncludedLength is the number of octets of packet saved in file.
+ IncludedLength uint32
+
+ // OriginalLength is the actual length of packet.
+ OriginalLength uint32
+}
+
+func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader {
+ now := time.Now()
+ return pcapPacketHeader{
+ Seconds: uint32(now.Unix()),
+ Microseconds: uint32(now.Nanosecond() / 1000),
+ IncludedLength: incLen,
+ OriginalLength: orgLen,
+ }
+}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
new file mode 100644
index 000000000..fccabd554
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -0,0 +1,408 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sniffer provides the implementation of data-link layer endpoints that
+// wrap another endpoint and logs inbound and outbound packets.
+//
+// Sniffer endpoints can be used in the networking stack by calling New(eID) to
+// create a new endpoint, where eID is the ID of the endpoint being wrapped,
+// and then passing it as an argument to Stack.CreateNIC().
+package sniffer
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "os"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// LogPackets is a flag used to enable or disable packet logging via the log
+// package. Valid values are 0 or 1.
+//
+// LogPackets must be accessed atomically.
+var LogPackets uint32 = 1
+
+// LogPacketsToFile is a flag used to enable or disable logging packets to a
+// pcap file. Valid values are 0 or 1. A file must have been specified when the
+// sniffer was created for this flag to have effect.
+//
+// LogPacketsToFile must be accessed atomically.
+var LogPacketsToFile uint32 = 1
+
+type endpoint struct {
+ dispatcher stack.NetworkDispatcher
+ lower stack.LinkEndpoint
+ file *os.File
+ maxPCAPLen uint32
+}
+
+// New creates a new sniffer link-layer endpoint. It wraps around another
+// endpoint and logs packets and they traverse the endpoint.
+func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID {
+ return stack.RegisterLinkEndpoint(&endpoint{
+ lower: stack.FindLinkEndpoint(lower),
+ })
+}
+
+func zoneOffset() (int32, error) {
+ loc, err := time.LoadLocation("Local")
+ if err != nil {
+ return 0, err
+ }
+ date := time.Date(0, 0, 0, 0, 0, 0, 0, loc)
+ _, offset := date.Zone()
+ return int32(offset), nil
+}
+
+func writePCAPHeader(w io.Writer, maxLen uint32) error {
+ offset, err := zoneOffset()
+ if err != nil {
+ return err
+ }
+ return binary.Write(w, binary.BigEndian, pcapHeader{
+ // From https://wiki.wireshark.org/Development/LibpcapFileFormat
+ MagicNumber: 0xa1b2c3d4,
+
+ VersionMajor: 2,
+ VersionMinor: 4,
+ Thiszone: offset,
+ Sigfigs: 0,
+ Snaplen: maxLen,
+ Network: 101, // LINKTYPE_RAW
+ })
+}
+
+// NewWithFile creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets and they traverse the endpoint.
+//
+// Packets can be logged to file in the pcap format. A sniffer created
+// with this function will not emit packets using the standard log
+// package.
+//
+// snapLen is the maximum amount of a packet to be saved. Packets with a length
+// less than or equal too snapLen will be saved in their entirety. Longer
+// packets will be truncated to snapLen.
+func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) {
+ if err := writePCAPHeader(file, snapLen); err != nil {
+ return 0, err
+ }
+ return stack.RegisterLinkEndpoint(&endpoint{
+ lower: stack.FindLinkEndpoint(lower),
+ file: file,
+ maxPCAPLen: snapLen,
+ }), nil
+}
+
+// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
+// called by the link-layer endpoint being wrapped when a packet arrives, and
+// logs the packet before forwarding to the actual dispatcher.
+func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
+ logPacket("recv", protocol, vv.First())
+ }
+ if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+ vs := vv.Views()
+ length := vv.Size()
+ if length > int(e.maxPCAPLen) {
+ length = int(e.maxPCAPLen)
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil {
+ panic(err)
+ }
+ for _, v := range vs {
+ if length == 0 {
+ break
+ }
+ if len(v) > length {
+ v = v[:length]
+ }
+ if _, err := buf.Write([]byte(v)); err != nil {
+ panic(err)
+ }
+ length -= len(v)
+ }
+ if _, err := e.file.Write(buf.Bytes()); err != nil {
+ panic(err)
+ }
+ }
+ e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv)
+}
+
+// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
+// and registers with the lower endpoint as its dispatcher so that "e" is called
+// for inbound packets.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ e.lower.Attach(e)
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
+// lower endpoint.
+func (e *endpoint) MTU() uint32 {
+ return e.lower.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
+// request to the lower endpoint.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.lower.Capabilities()
+}
+
+// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards
+// the request to the lower endpoint.
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.lower.MaxHeaderLength()
+}
+
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.lower.LinkAddress()
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.lower.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// WritePacket implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and forwards
+// the request to the lower endpoint.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
+ logPacket("send", protocol, hdr.View())
+ }
+ if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+ hdrBuf := hdr.View()
+ length := len(hdrBuf) + payload.Size()
+ if length > int(e.maxPCAPLen) {
+ length = int(e.maxPCAPLen)
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil {
+ panic(err)
+ }
+ if len(hdrBuf) > length {
+ hdrBuf = hdrBuf[:length]
+ }
+ if _, err := buf.Write(hdrBuf); err != nil {
+ panic(err)
+ }
+ length -= len(hdrBuf)
+ if length > 0 {
+ for _, v := range payload.Views() {
+ if len(v) > length {
+ v = v[:length]
+ }
+ n, err := buf.Write(v)
+ if err != nil {
+ panic(err)
+ }
+ length -= n
+ if length == 0 {
+ break
+ }
+ }
+ }
+ if _, err := e.file.Write(buf.Bytes()); err != nil {
+ panic(err)
+ }
+ }
+ return e.lower.WritePacket(r, gso, hdr, payload, protocol)
+}
+
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) {
+ // Figure out the network layer info.
+ var transProto uint8
+ src := tcpip.Address("unknown")
+ dst := tcpip.Address("unknown")
+ id := 0
+ size := uint16(0)
+ var fragmentOffset uint16
+ var moreFragments bool
+ switch protocol {
+ case header.IPv4ProtocolNumber:
+ ipv4 := header.IPv4(b)
+ fragmentOffset = ipv4.FragmentOffset()
+ moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
+ src = ipv4.SourceAddress()
+ dst = ipv4.DestinationAddress()
+ transProto = ipv4.Protocol()
+ size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
+ b = b[ipv4.HeaderLength():]
+ id = int(ipv4.ID())
+
+ case header.IPv6ProtocolNumber:
+ ipv6 := header.IPv6(b)
+ src = ipv6.SourceAddress()
+ dst = ipv6.DestinationAddress()
+ transProto = ipv6.NextHeader()
+ size = ipv6.PayloadLength()
+ b = b[header.IPv6MinimumSize:]
+
+ case header.ARPProtocolNumber:
+ arp := header.ARP(b)
+ log.Infof(
+ "%s arp %v (%v) -> %v (%v) valid:%v",
+ prefix,
+ tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
+ tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
+ arp.IsValid(),
+ )
+ return
+ default:
+ log.Infof("%s unknown network protocol", prefix)
+ return
+ }
+
+ // Figure out the transport layer info.
+ transName := "unknown"
+ srcPort := uint16(0)
+ dstPort := uint16(0)
+ details := ""
+ switch tcpip.TransportProtocolNumber(transProto) {
+ case header.ICMPv4ProtocolNumber:
+ transName = "icmp"
+ icmp := header.ICMPv4(b)
+ icmpType := "unknown"
+ if fragmentOffset == 0 {
+ switch icmp.Type() {
+ case header.ICMPv4EchoReply:
+ icmpType = "echo reply"
+ case header.ICMPv4DstUnreachable:
+ icmpType = "destination unreachable"
+ case header.ICMPv4SrcQuench:
+ icmpType = "source quench"
+ case header.ICMPv4Redirect:
+ icmpType = "redirect"
+ case header.ICMPv4Echo:
+ icmpType = "echo"
+ case header.ICMPv4TimeExceeded:
+ icmpType = "time exceeded"
+ case header.ICMPv4ParamProblem:
+ icmpType = "param problem"
+ case header.ICMPv4Timestamp:
+ icmpType = "timestamp"
+ case header.ICMPv4TimestampReply:
+ icmpType = "timestamp reply"
+ case header.ICMPv4InfoRequest:
+ icmpType = "info request"
+ case header.ICMPv4InfoReply:
+ icmpType = "info reply"
+ }
+ }
+ log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ return
+
+ case header.ICMPv6ProtocolNumber:
+ transName = "icmp"
+ icmp := header.ICMPv6(b)
+ icmpType := "unknown"
+ switch icmp.Type() {
+ case header.ICMPv6DstUnreachable:
+ icmpType = "destination unreachable"
+ case header.ICMPv6PacketTooBig:
+ icmpType = "packet too big"
+ case header.ICMPv6TimeExceeded:
+ icmpType = "time exceeded"
+ case header.ICMPv6ParamProblem:
+ icmpType = "param problem"
+ case header.ICMPv6EchoRequest:
+ icmpType = "echo request"
+ case header.ICMPv6EchoReply:
+ icmpType = "echo reply"
+ case header.ICMPv6RouterSolicit:
+ icmpType = "router solicit"
+ case header.ICMPv6RouterAdvert:
+ icmpType = "router advert"
+ case header.ICMPv6NeighborSolicit:
+ icmpType = "neighbor solicit"
+ case header.ICMPv6NeighborAdvert:
+ icmpType = "neighbor advert"
+ case header.ICMPv6RedirectMsg:
+ icmpType = "redirect message"
+ }
+ log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ return
+
+ case header.UDPProtocolNumber:
+ transName = "udp"
+ udp := header.UDP(b)
+ if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
+ srcPort = udp.SourcePort()
+ dstPort = udp.DestinationPort()
+ }
+ size -= header.UDPMinimumSize
+
+ details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
+
+ case header.TCPProtocolNumber:
+ transName = "tcp"
+ tcp := header.TCP(b)
+ if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize {
+ offset := int(tcp.DataOffset())
+ if offset < header.TCPMinimumSize {
+ details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
+ break
+ }
+ if offset > len(tcp) && !moreFragments {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp))
+ break
+ }
+
+ srcPort = tcp.SourcePort()
+ dstPort = tcp.DestinationPort()
+ size -= uint16(offset)
+
+ // Initialize the TCP flags.
+ flags := tcp.Flags()
+ flagsStr := []byte("FSRPAU")
+ for i := range flagsStr {
+ if flags&(1<<uint(i)) == 0 {
+ flagsStr[i] = ' '
+ }
+ }
+ details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+ if flags&header.TCPFlagSyn != 0 {
+ details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
+ } else {
+ details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions())
+ }
+ }
+
+ default:
+ log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto)
+ return
+ }
+
+ log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+}
diff --git a/pkg/tcpip/link/sniffer/sniffer_state_autogen.go b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go
new file mode 100755
index 000000000..cfd84a739
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package sniffer
+
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
new file mode 100644
index 000000000..a3f2bce3e
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp.go
@@ -0,0 +1,203 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package arp implements the ARP network protocol. It is used to resolve
+// IPv4 addresses into link-local MAC addresses, and advertises IPv4
+// addresses of its stack with the local network.
+//
+// To use it in the networking stack, pass arp.ProtocolName as one of the
+// network protocols when calling stack.New. Then add an "arp" address to
+// every NIC on the stack that should respond to ARP requests. That is:
+//
+// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
+// // handle err
+// }
+package arp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ARP protocol name.
+ ProtocolName = "arp"
+
+ // ProtocolNumber is the ARP protocol number.
+ ProtocolNumber = header.ARPProtocolNumber
+
+ // ProtocolAddress is the address expected by the ARP endpoint.
+ ProtocolAddress = tcpip.Address("arp")
+)
+
+// endpoint implements stack.NetworkEndpoint.
+type endpoint struct {
+ nicid tcpip.NICID
+ addr tcpip.Address
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+}
+
+// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+func (e *endpoint) MTU() uint32 {
+ lmtu := e.linkEP.MTU()
+ return lmtu - uint32(e.MaxHeaderLength())
+}
+
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &stack.NetworkEndpointID{ProtocolAddress}
+}
+
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.ARPSize
+}
+
+func (e *endpoint) Close() {}
+
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+ v := vv.First()
+ h := header.ARP(v)
+ if !h.IsValid() {
+ return
+ }
+
+ switch h.Op() {
+ case header.ARPRequest:
+ localAddr := tcpip.Address(h.ProtocolAddressTarget())
+ if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+ hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
+ pkt := header.ARP(hdr.Prepend(header.ARPSize))
+ pkt.SetIPv4OverEthernet()
+ pkt.SetOp(header.ARPReply)
+ copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
+ copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
+ e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ fallthrough // also fill the cache from requests
+ case header.ARPReply:
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
+ }
+}
+
+// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
+type protocol struct {
+}
+
+func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
+func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
+
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.ARP(v)
+ return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+}
+
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ if addr != ProtocolAddress {
+ return nil, tcpip.ErrBadLocalAddress
+ }
+ return &endpoint{
+ nicid: nicid,
+ addr: addr,
+ linkEP: sender,
+ linkAddrCache: linkAddrCache,
+ }, nil
+}
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv4ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ r := &stack.Route{
+ RemoteLinkAddress: broadcastMAC,
+ }
+
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
+ h := header.ARP(hdr.Prepend(header.ARPSize))
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), linkEP.LinkAddress())
+ copy(h.ProtocolAddressSender(), localAddr)
+ copy(h.ProtocolAddressTarget(), addr)
+
+ return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == header.IPv4Broadcast {
+ return broadcastMAC, true
+ }
+ if header.IsV4MulticastAddress(addr) {
+ // RFC 1112 Host Extensions for IP Multicasting
+ //
+ // 6.4. Extensions to an Ethernet Local Network Module:
+ //
+ // An IP host group address is mapped to an Ethernet multicast
+ // address by placing the low-order 23-bits of the IP address
+ // into the low-order 23 bits of the Ethernet multicast address
+ // 01-00-5E-00-00-00 (hex).
+ return tcpip.LinkAddress([]byte{
+ 0x01,
+ 0x00,
+ 0x5e,
+ addr[header.IPv4AddressSize-3] & 0x7f,
+ addr[header.IPv4AddressSize-2],
+ addr[header.IPv4AddressSize-1],
+ }), true
+ }
+ return "", false
+}
+
+// SetOption implements NetworkProtocol.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+func init() {
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go
new file mode 100755
index 000000000..14a21baff
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package arp
+
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
new file mode 100644
index 000000000..9ad3e5a8a
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+type fragment struct {
+ offset uint16
+ vv buffer.VectorisedView
+}
+
+type fragHeap []fragment
+
+func (h *fragHeap) Len() int {
+ return len(*h)
+}
+
+func (h *fragHeap) Less(i, j int) bool {
+ return (*h)[i].offset < (*h)[j].offset
+}
+
+func (h *fragHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
+}
+
+func (h *fragHeap) Push(x interface{}) {
+ *h = append(*h, x.(fragment))
+}
+
+func (h *fragHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[:n-1]
+ return x
+}
+
+// reassamble empties the heap and returns a VectorisedView
+// containing a reassambled version of the fragments inside the heap.
+func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
+ curr := heap.Pop(h).(fragment)
+ views := curr.vv.Views()
+ size := curr.vv.Size()
+
+ if curr.offset != 0 {
+ return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
+ }
+
+ for h.Len() > 0 {
+ curr := heap.Pop(h).(fragment)
+ if int(curr.offset) < size {
+ curr.vv.TrimFront(size - int(curr.offset))
+ } else if int(curr.offset) > size {
+ return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
+ }
+ size += curr.vv.Size()
+ views = append(views, curr.vv.Views()...)
+ }
+ return buffer.NewVectorisedView(size, views), nil
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
new file mode 100644
index 000000000..e90edb375
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fragmentation contains the implementation of IP fragmentation.
+// It is based on RFC 791 and RFC 815.
+package fragmentation
+
+import (
+ "log"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
+const DefaultReassembleTimeout = 30 * time.Second
+
+// HighFragThreshold is the threshold at which we start trimming old
+// fragmented packets. Linux uses a default value of 4 MB. See
+// net.ipv4.ipfrag_high_thresh for more information.
+const HighFragThreshold = 4 << 20 // 4MB
+
+// LowFragThreshold is the threshold we reach to when we start dropping
+// older fragmented packets. It's important that we keep enough room for newer
+// packets to be re-assembled. Hence, this needs to be lower than
+// HighFragThreshold enough. Linux uses a default value of 3 MB. See
+// net.ipv4.ipfrag_low_thresh for more information.
+const LowFragThreshold = 3 << 20 // 3MB
+
+// Fragmentation is the main structure that other modules
+// of the stack should use to implement IP Fragmentation.
+type Fragmentation struct {
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[uint32]*reassembler
+ rList reassemblerList
+ size int
+ timeout time.Duration
+}
+
+// NewFragmentation creates a new Fragmentation.
+//
+// highMemoryLimit specifies the limit on the memory consumed
+// by the fragments stored by Fragmentation (overhead of internal data-structures
+// is not accounted). Fragments are dropped when the limit is reached.
+//
+// lowMemoryLimit specifies the limit on which we will reach by dropping
+// fragments after reaching highMemoryLimit.
+//
+// reassemblingTimeout specifes the maximum time allowed to reassemble a packet.
+// Fragments are lazily evicted only when a new a packet with an
+// already existing fragmentation-id arrives after the timeout.
+func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+ if lowMemoryLimit >= highMemoryLimit {
+ lowMemoryLimit = highMemoryLimit
+ }
+
+ if lowMemoryLimit < 0 {
+ lowMemoryLimit = 0
+ }
+
+ return &Fragmentation{
+ reassemblers: make(map[uint32]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ }
+}
+
+// Process processes an incoming fragment beloning to an ID
+// and returns a complete packet when all the packets belonging to that ID have been received.
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
+ f.mu.Lock()
+ r, ok := f.reassemblers[id]
+ if ok && r.tooOld(f.timeout) {
+ // This is very likely to be an id-collision or someone performing a slow-rate attack.
+ f.release(r)
+ ok = false
+ }
+ if !ok {
+ r = newReassembler(id)
+ f.reassemblers[id] = r
+ f.rList.PushFront(r)
+ }
+ f.mu.Unlock()
+
+ res, done, consumed := r.process(first, last, more, vv)
+
+ f.mu.Lock()
+ f.size += consumed
+ if done {
+ f.release(r)
+ }
+ // Evict reassemblers if we are consuming more memory than highLimit until
+ // we reach lowLimit.
+ if f.size > f.highLimit {
+ tail := f.rList.Back()
+ for f.size > f.lowLimit && tail != nil {
+ f.release(tail)
+ tail = tail.Prev()
+ }
+ }
+ f.mu.Unlock()
+ return res, done
+}
+
+func (f *Fragmentation) release(r *reassembler) {
+ // Before releasing a fragment we need to check if r is already marked as done.
+ // Otherwise, we would delete it twice.
+ if r.checkDoneOrMark() {
+ return
+ }
+
+ delete(f.reassemblers, r.id)
+ f.rList.Remove(r)
+ f.size -= r.size
+ if f.size < 0 {
+ log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
+ f.size = 0
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go
new file mode 100755
index 000000000..c012e8012
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go
@@ -0,0 +1,38 @@
+// automatically generated by stateify.
+
+package fragmentation
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *reassemblerList) beforeSave() {}
+func (x *reassemblerList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *reassemblerList) afterLoad() {}
+func (x *reassemblerList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *reassemblerEntry) beforeSave() {}
+func (x *reassemblerEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *reassemblerEntry) afterLoad() {}
+func (x *reassemblerEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("fragmentation.reassemblerList", (*reassemblerList)(nil), state.Fns{Save: (*reassemblerList).save, Load: (*reassemblerList).load})
+ state.Register("fragmentation.reassemblerEntry", (*reassemblerEntry)(nil), state.Fns{Save: (*reassemblerEntry).save, Load: (*reassemblerEntry).load})
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
new file mode 100644
index 000000000..04f9ab964
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+ "math"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+type hole struct {
+ first uint16
+ last uint16
+ deleted bool
+}
+
+type reassembler struct {
+ reassemblerEntry
+ id uint32
+ size int
+ mu sync.Mutex
+ holes []hole
+ deleted int
+ heap fragHeap
+ done bool
+ creationTime time.Time
+}
+
+func newReassembler(id uint32) *reassembler {
+ r := &reassembler{
+ id: id,
+ holes: make([]hole, 0, 16),
+ deleted: 0,
+ heap: make(fragHeap, 0, 8),
+ creationTime: time.Now(),
+ }
+ r.holes = append(r.holes, hole{
+ first: 0,
+ last: math.MaxUint16,
+ deleted: false})
+ return r
+}
+
+// updateHoles updates the list of holes for an incoming fragment and
+// returns true iff the fragment filled at least part of an existing hole.
+func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
+ used := false
+ for i := range r.holes {
+ if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
+ continue
+ }
+ used = true
+ r.deleted++
+ r.holes[i].deleted = true
+ if first > r.holes[i].first {
+ r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
+ }
+ if last < r.holes[i].last && more {
+ r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
+ }
+ }
+ return used
+}
+
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ consumed := 0
+ if r.done {
+ // A concurrent goroutine might have already reassembled
+ // the packet and emptied the heap while this goroutine
+ // was waiting on the mutex. We don't have to do anything in this case.
+ return buffer.VectorisedView{}, false, consumed
+ }
+ if r.updateHoles(first, last, more) {
+ // We store the incoming packet only if it filled some holes.
+ heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
+ consumed = vv.Size()
+ r.size += consumed
+ }
+ // Check if all the holes have been deleted and we are ready to reassamble.
+ if r.deleted < len(r.holes) {
+ return buffer.VectorisedView{}, false, consumed
+ }
+ res, err := r.heap.reassemble()
+ if err != nil {
+ panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
+ }
+ return res, true, consumed
+}
+
+func (r *reassembler) tooOld(timeout time.Duration) bool {
+ return time.Now().Sub(r.creationTime) > timeout
+}
+
+func (r *reassembler) checkDoneOrMark() bool {
+ r.mu.Lock()
+ prev := r.done
+ r.done = true
+ r.mu.Unlock()
+ return prev
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_list.go b/pkg/tcpip/network/fragmentation/reassembler_list.go
new file mode 100755
index 000000000..3189cae29
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler_list.go
@@ -0,0 +1,173 @@
+package fragmentation
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type reassemblerElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type reassemblerList struct {
+ head *reassembler
+ tail *reassembler
+}
+
+// Reset resets list l to the empty state.
+func (l *reassemblerList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *reassemblerList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *reassemblerList) Front() *reassembler {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *reassemblerList) Back() *reassembler {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *reassemblerList) PushFront(e *reassembler) {
+ reassemblerElementMapper{}.linkerFor(e).SetNext(l.head)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *reassemblerList) PushBack(e *reassembler) {
+ reassemblerElementMapper{}.linkerFor(e).SetNext(nil)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *reassemblerList) PushBackList(m *reassemblerList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *reassemblerList) InsertAfter(b, e *reassembler) {
+ a := reassemblerElementMapper{}.linkerFor(b).Next()
+ reassemblerElementMapper{}.linkerFor(e).SetNext(a)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(b)
+ reassemblerElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *reassemblerList) InsertBefore(a, e *reassembler) {
+ b := reassemblerElementMapper{}.linkerFor(a).Prev()
+ reassemblerElementMapper{}.linkerFor(e).SetNext(a)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(b)
+ reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ reassemblerElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *reassemblerList) Remove(e *reassembler) {
+ prev := reassemblerElementMapper{}.linkerFor(e).Prev()
+ next := reassemblerElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ reassemblerElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ reassemblerElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type reassemblerEntry struct {
+ next *reassembler
+ prev *reassembler
+}
+
+// Next returns the entry that follows e in the list.
+func (e *reassemblerEntry) Next() *reassembler {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *reassemblerEntry) Prev() *reassembler {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *reassemblerEntry) SetNext(elem *reassembler) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *reassemblerEntry) SetPrev(elem *reassembler) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go
new file mode 100644
index 000000000..0c91905dc
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hash contains utility functions for hashing.
+package hash
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+var hashIV = RandN32(1)[0]
+
+// RandN32 generates a slice of n cryptographic random 32-bit numbers.
+func RandN32(n int) []uint32 {
+ b := make([]byte, 4*n)
+ if _, err := rand.Read(b); err != nil {
+ panic("unable to get random numbers: " + err.Error())
+ }
+ r := make([]uint32, n)
+ for i := range r {
+ r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)])
+ }
+ return r
+}
+
+// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted
+// from linux.
+func Hash3Words(a, b, c, initval uint32) uint32 {
+ const iv = 0xdeadbeef + (3 << 2)
+ initval += iv
+
+ a += initval
+ b += initval
+ c += initval
+
+ c ^= b
+ c -= rol32(b, 14)
+ a ^= c
+ a -= rol32(c, 11)
+ b ^= a
+ b -= rol32(a, 25)
+ c ^= b
+ c -= rol32(b, 16)
+ a ^= c
+ a -= rol32(c, 4)
+ b ^= a
+ b -= rol32(a, 14)
+ c ^= b
+ c -= rol32(b, 24)
+
+ return c
+}
+
+// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791.
+func IPv4FragmentHash(h header.IPv4) uint32 {
+ x := uint32(h.ID())<<16 | uint32(h.Protocol())
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(x, y, z, hashIV)
+}
+
+// IPv6FragmentHash computes the hash of the ipv6 fragment.
+// Unlike IPv4, the protocol is not used to compute the hash.
+// RFC 2640 (sec 4.5) is not very sharp on this aspect.
+// As a reference, also Linux ignores the protocol to compute
+// the hash (inet6_hash_frag).
+func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 {
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(f.ID(), y, z, hashIV)
+}
+
+func rol32(v, shift uint32) uint32 {
+ return (v << shift) | (v >> ((-shift) & 31))
+}
diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go
new file mode 100755
index 000000000..a3bcd4b69
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package hash
+
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
new file mode 100644
index 000000000..770f56c3d
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -0,0 +1,160 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ h := header.IPv4(vv.First())
+
+ // We don't use IsValid() here because ICMP only requires that the IP
+ // header plus 8 bytes of the transport header be included. So it's
+ // likely that it is truncated, which would cause IsValid to return
+ // false.
+ //
+ // Drop packet if it doesn't have the basic IPv4 header or if the
+ // original source address doesn't match the endpoint's address.
+ if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ hlen := int(h.HeaderLength())
+ if vv.Size() < hlen || h.FragmentOffset() != 0 {
+ // We won't be able to handle this if it doesn't contain the
+ // full IPv4 header, or if it's a fragment not at offset 0
+ // (because it won't have the transport header).
+ return
+ }
+
+ // Skip the ip header, then deliver control message.
+ vv.TrimFront(hlen)
+ p := h.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ stats := r.Stats()
+ received := stats.ICMP.V4PacketsReceived
+ v := vv.First()
+ if len(v) < header.ICMPv4MinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.ICMPv4(v)
+
+ // TODO(b/112892170): Meaningfully handle all ICMP types.
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ received.Echo.Increment()
+ if len(v) < header.ICMPv4EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ // Only send a reply if the checksum is valid.
+ wantChecksum := h.Checksum()
+ // Reset the checksum field to 0 to can calculate the proper
+ // checksum. We'll have to reset this before we hand the packet
+ // off.
+ h.SetChecksum(0)
+ gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */)
+ if gotChecksum != wantChecksum {
+ // It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ received.Invalid.Increment()
+ return
+ }
+
+ // It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+
+ vv := vv.Clone(nil)
+ vv.TrimFront(header.ICMPv4EchoMinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4EchoMinimumSize)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
+ copy(pkt, h)
+ pkt.SetType(header.ICMPv4EchoReply)
+ pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
+ sent := stats.ICMP.V4PacketsSent
+ if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.EchoReply.Increment()
+
+ case header.ICMPv4EchoReply:
+ received.EchoReply.Increment()
+ if len(v) < header.ICMPv4EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+
+ case header.ICMPv4DstUnreachable:
+ received.DstUnreachable.Increment()
+ if len(v) < header.ICMPv4DstUnreachableMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize)
+ switch h.Code() {
+ case header.ICMPv4PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, vv)
+
+ case header.ICMPv4FragmentationNeeded:
+ mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:]))
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ }
+
+ case header.ICMPv4SrcQuench:
+ received.SrcQuench.Increment()
+
+ case header.ICMPv4Redirect:
+ received.Redirect.Increment()
+
+ case header.ICMPv4TimeExceeded:
+ received.TimeExceeded.Increment()
+
+ case header.ICMPv4ParamProblem:
+ received.ParamProblem.Increment()
+
+ case header.ICMPv4Timestamp:
+ received.Timestamp.Increment()
+
+ case header.ICMPv4TimestampReply:
+ received.TimestampReply.Increment()
+
+ case header.ICMPv4InfoRequest:
+ received.InfoRequest.Increment()
+
+ case header.ICMPv4InfoReply:
+ received.InfoReply.Increment()
+
+ default:
+ received.Invalid.Increment()
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
new file mode 100644
index 000000000..da07a39e5
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -0,0 +1,344 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ipv4 contains the implementation of the ipv4 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
+// network protocols when calling stack.New(). Then endpoints can be created
+// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv4
+
+import (
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/hash"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ipv4 protocol name.
+ ProtocolName = "ipv4"
+
+ // ProtocolNumber is the ipv4 protocol number.
+ ProtocolNumber = header.IPv4ProtocolNumber
+
+ // MaxTotalSize is maximum size that can be encoded in the 16-bit
+ // TotalLength field of the ipv4 header.
+ MaxTotalSize = 0xffff
+
+ // buckets is the number of identifier buckets.
+ buckets = 2048
+)
+
+type endpoint struct {
+ nicid tcpip.NICID
+ id stack.NetworkEndpointID
+ linkEP stack.LinkEndpoint
+ dispatcher stack.TransportDispatcher
+ fragmentation *fragmentation.Fragmentation
+}
+
+// NewEndpoint creates a new ipv4 endpoint.
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ e := &endpoint{
+ nicid: nicid,
+ id: stack.NetworkEndpointID{LocalAddress: addr},
+ linkEP: linkEP,
+ dispatcher: dispatcher,
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ }
+
+ return e, nil
+}
+
+// DefaultTTL is the default time-to-live value for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 255
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+// ID returns the ipv4 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
+// write. It assumes that the IP header is entirely in hdr but does not assume
+// that only the IP header is in hdr. It assumes that the input packet's stated
+// length matches the length of the hdr+payload. mtu includes the IP header and
+// options. This does not support the DontFragment IP flag.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error {
+ // This packet is too big, it needs to be fragmented.
+ ip := header.IPv4(hdr.View())
+ flags := ip.Flags()
+
+ // Update mtu to take into account the header, which will exist in all
+ // fragments anyway.
+ innerMTU := mtu - int(ip.HeaderLength())
+
+ // Round the MTU down to align to 8 bytes. Then calculate the number of
+ // fragments. Calculate fragment sizes as in RFC791.
+ innerMTU &^= 7
+ n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
+
+ outerMTU := innerMTU + int(ip.HeaderLength())
+ offset := ip.FragmentOffset()
+ originalAvailableLength := hdr.AvailableLength()
+ for i := 0; i < n; i++ {
+ // Where possible, the first fragment that is sent has the same
+ // hdr.UsedLength() as the input packet. The link-layer endpoint may depends
+ // on this for looking at, eg, L4 headers.
+ h := ip
+ if i > 0 {
+ hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
+ h = header.IPv4(hdr.Prepend(int(ip.HeaderLength())))
+ copy(h, ip[:ip.HeaderLength()])
+ }
+ if i != n-1 {
+ h.SetTotalLength(uint16(outerMTU))
+ h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
+ } else {
+ h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size()))
+ h.SetFlagsFragmentOffset(flags, offset)
+ }
+ h.SetChecksum(0)
+ h.SetChecksum(^h.CalculateChecksum())
+ offset += uint16(innerMTU)
+ if i > 0 {
+ newPayload := payload.Clone([]buffer.View{})
+ newPayload.CapLength(innerMTU)
+ if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ payload.TrimFront(newPayload.Size())
+ continue
+ }
+ // Special handling for the first fragment because it comes from the hdr.
+ if outerMTU >= hdr.UsedLength() {
+ // This fragment can fit all of hdr and possibly some of payload, too.
+ newPayload := payload.Clone([]buffer.View{})
+ newPayloadLength := outerMTU - hdr.UsedLength()
+ newPayload.CapLength(newPayloadLength)
+ if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ payload.TrimFront(newPayloadLength)
+ } else {
+ // The fragment is too small to fit all of hdr.
+ startOfHdr := hdr
+ startOfHdr.TrimBack(hdr.UsedLength() - outerMTU)
+ emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
+ if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ // Add the unused bytes of hdr into the payload that remains to be sent.
+ restOfHdr := hdr.View()[outerMTU:]
+ tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
+ tmp.Append(payload)
+ payload = tmp
+ }
+ }
+ return nil
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ length := uint16(hdr.UsedLength() + payload.Size())
+ id := uint32(0)
+ if length > header.IPv4MaximumHeaderSize+8 {
+ // Packets of 68 bytes or less are required by RFC 791 to not be
+ // fragmented, so we only assign ids to larger packets.
+ id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: length,
+ ID: uint16(id),
+ TTL: ttl,
+ Protocol: uint8(protocol),
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ if loop&stack.PacketLoop != 0 {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ e.HandlePacket(r, vv)
+ }
+ if loop&stack.PacketOut == 0 {
+ return nil
+ }
+ if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU()))
+ }
+ if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ return nil
+}
+
+// HandlePacket is called by the link layer when new ipv4 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+ headerView := vv.First()
+ h := header.IPv4(headerView)
+ if !h.IsValid(vv.Size()) {
+ return
+ }
+
+ hlen := int(h.HeaderLength())
+ tlen := int(h.TotalLength())
+ vv.TrimFront(hlen)
+ vv.CapLength(tlen - hlen)
+
+ more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
+ if more || h.FragmentOffset() != 0 {
+ // The packet is a fragment, let's try to reassemble it.
+ last := h.FragmentOffset() + uint16(vv.Size()) - 1
+ var ready bool
+ vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ if !ready {
+ return
+ }
+ }
+ p := h.TransportProtocol()
+ if p == header.ICMPv4ProtocolNumber {
+ headerView.CapLength(hlen)
+ e.handleICMP(r, headerView, vv)
+ return
+ }
+ r.Stats().IP.PacketsDelivered.Increment()
+ e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
+}
+
+// Close cleans up resources associated with the endpoint.
+func (e *endpoint) Close() {}
+
+type protocol struct{}
+
+// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
+// only for tests that short-circuit the stack. Regular use of the protocol is
+// done via the stack, which gets a protocol descriptor from the init() function
+// below.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
+}
+
+// Number returns the ipv4 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv4 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv4MinimumSize
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv4(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ if mtu > MaxTotalSize {
+ mtu = MaxTotalSize
+ }
+ return mtu - header.IPv4MinimumSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address, the transport protocol number, and a random initial
+// value (generated once on initialization) to generate the hash.
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+ t := r.LocalAddress
+ a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = r.RemoteAddress
+ b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return hash.Hash3Words(a, b, uint32(protocol), hashIV)
+}
+
+var (
+ ids []uint32
+ hashIV uint32
+)
+
+func init() {
+ ids = make([]uint32, buckets)
+
+ // Randomly initialize hashIV and the ids.
+ r := hash.RandN32(1 + buckets)
+ for i := range ids {
+ ids[i] = r[i]
+ }
+ hashIV = r[buckets]
+
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
new file mode 100755
index 000000000..6b2cc0142
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package ipv4
+
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
new file mode 100644
index 000000000..9c011e107
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -0,0 +1,297 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ h := header.IPv6(vv.First())
+
+ // We don't use IsValid() here because ICMP only requires that up to
+ // 1280 bytes of the original packet be included. So it's likely that it
+ // is truncated, which would cause IsValid to return false.
+ //
+ // Drop packet if it doesn't have the basic IPv6 header or if the
+ // original source address doesn't match the endpoint's address.
+ if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ // Skip the IP header, then handle the fragmentation header if there
+ // is one.
+ vv.TrimFront(header.IPv6MinimumSize)
+ p := h.TransportProtocol()
+ if p == header.IPv6FragmentHeader {
+ f := header.IPv6Fragment(vv.First())
+ if !f.IsValid() || f.FragmentOffset() != 0 {
+ // We can't handle fragments that aren't at offset 0
+ // because they don't have the transport headers.
+ return
+ }
+
+ // Skip fragmentation header and find out the actual protocol
+ // number.
+ vv.TrimFront(header.IPv6FragmentHeaderSize)
+ p = f.TransportProtocol()
+ }
+
+ // Deliver the control packet to the transport endpoint.
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ stats := r.Stats().ICMP
+ sent := stats.V6PacketsSent
+ received := stats.V6PacketsReceived
+ v := vv.First()
+ if len(v) < header.ICMPv6MinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.ICMPv6(v)
+
+ // TODO(b/112892170): Meaningfully handle all ICMP types.
+ switch h.Type() {
+ case header.ICMPv6PacketTooBig:
+ received.PacketTooBig.Increment()
+ if len(v) < header.ICMPv6PacketTooBigMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+
+ case header.ICMPv6DstUnreachable:
+ received.DstUnreachable.Increment()
+ if len(v) < header.ICMPv6DstUnreachableMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch h.Code() {
+ case header.ICMPv6PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ }
+
+ case header.ICMPv6NeighborSolicit:
+ received.NeighborSolicit.Increment()
+
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+
+ if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ targetAddr := tcpip.Address(v[8:][:16])
+ if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
+ // We don't have a useful answer; the best we can do is ignore the request.
+ return
+ }
+
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag
+ copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr)
+ pkt[icmpV6OptOffset] = ndpOptDstLinkAddr
+ pkt[icmpV6LengthOffset] = 1
+ copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:])
+
+ // ICMPv6 Neighbor Solicit messages are always sent to
+ // specially crafted IPv6 multicast addresses. As a result, the
+ // route we end up with here has as its LocalAddress such a
+ // multicast address. It would be nonsense to claim that our
+ // source address is a multicast address, so we manually set
+ // the source address to the target address requested in the
+ // solicit message. Since that requires mutating the route, we
+ // must first clone it.
+ r := r.Clone()
+ defer r.Release()
+ r.LocalAddress = targetAddr
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.NeighborAdvert.Increment()
+
+ case header.ICMPv6NeighborAdvert:
+ received.NeighborAdvert.Increment()
+ if len(v) < header.ICMPv6NeighborAdvertSize {
+ received.Invalid.Increment()
+ return
+ }
+ targetAddr := tcpip.Address(v[8:][:16])
+ e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
+ if targetAddr != r.RemoteAddress {
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ }
+
+ case header.ICMPv6EchoRequest:
+ received.EchoRequest.Increment()
+ if len(v) < header.ICMPv6EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ vv.TrimFront(header.ICMPv6EchoMinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(pkt, h)
+ pkt.SetType(header.ICMPv6EchoReply)
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
+ if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.EchoReply.Increment()
+
+ case header.ICMPv6EchoReply:
+ received.EchoReply.Increment()
+ if len(v) < header.ICMPv6EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv)
+
+ case header.ICMPv6TimeExceeded:
+ received.TimeExceeded.Increment()
+
+ case header.ICMPv6ParamProblem:
+ received.ParamProblem.Increment()
+
+ case header.ICMPv6RouterSolicit:
+ received.RouterSolicit.Increment()
+
+ case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
+
+ case header.ICMPv6RedirectMsg:
+ received.RedirectMsg.Increment()
+
+ default:
+ received.Invalid.Increment()
+ }
+}
+
+const (
+ ndpSolicitedFlag = 1 << 6
+ ndpOverrideFlag = 1 << 5
+
+ ndpOptSrcLinkAddr = 1
+ ndpOptDstLinkAddr = 2
+
+ icmpV6FlagOffset = 4
+ icmpV6OptOffset = 24
+ icmpV6LengthOffset = 25
+)
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+var _ stack.LinkAddressResolver = (*protocol)(nil)
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ snaddr := header.SolicitedNodeAddr(addr)
+ r := &stack.Route{
+ LocalAddress: localAddr,
+ RemoteAddress: snaddr,
+ RemoteLinkAddress: broadcastMAC,
+ }
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ copy(pkt[icmpV6OptOffset-len(addr):], addr)
+ pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
+ pkt[icmpV6LengthOffset] = 1
+ copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ length := uint16(hdr.UsedLength())
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: defaultIPv6HopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+
+ // TODO(stijlist): count this in ICMP stats.
+ return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
+ //
+ // 7. Address Mapping -- Multicast
+ //
+ // An IPv6 packet with a multicast destination address DST,
+ // consisting of the sixteen octets DST[1] through DST[16], is
+ // transmitted to the Ethernet multicast address whose first
+ // two octets are the value 3333 hexadecimal and whose last
+ // four octets are the last four octets of DST.
+ return tcpip.LinkAddress([]byte{
+ 0x33,
+ 0x33,
+ addr[header.IPv6AddressSize-4],
+ addr[header.IPv6AddressSize-3],
+ addr[header.IPv6AddressSize-2],
+ addr[header.IPv6AddressSize-1],
+ }), true
+ }
+ return "", false
+}
+
+func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
+ // Calculate the IPv6 pseudo-header upper-layer checksum.
+ xsum := header.Checksum([]byte(src), 0)
+ xsum = header.Checksum([]byte(dst), xsum)
+ var upperLayerLength [4]byte
+ binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
+ xsum = header.Checksum(upperLayerLength[:], xsum)
+ xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum)
+ for _, v := range vv.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^header.Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
new file mode 100644
index 000000000..4b8cd496b
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -0,0 +1,207 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ipv6 contains the implementation of the ipv6 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
+// network protocols when calling stack.New(). Then endpoints can be created
+// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv6
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ipv6 protocol name.
+ ProtocolName = "ipv6"
+
+ // ProtocolNumber is the ipv6 protocol number.
+ ProtocolNumber = header.IPv6ProtocolNumber
+
+ // maxTotalSize is maximum size that can be encoded in the 16-bit
+ // PayloadLength field of the ipv6 header.
+ maxPayloadSize = 0xffff
+
+ // defaultIPv6HopLimit is the default hop limit for IPv6 Packets
+ // egressed by Netstack.
+ defaultIPv6HopLimit = 255
+)
+
+type endpoint struct {
+ nicid tcpip.NICID
+ id stack.NetworkEndpointID
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+ dispatcher stack.TransportDispatcher
+}
+
+// DefaultTTL is the default hop limit for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 255
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+// ID returns the ipv6 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
+ length := uint16(hdr.UsedLength() + payload.Size())
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(protocol),
+ HopLimit: ttl,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+
+ if loop&stack.PacketLoop != 0 {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ e.HandlePacket(r, vv)
+ }
+ if loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ r.Stats().IP.PacketsSent.Increment()
+ return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
+}
+
+// HandlePacket is called by the link layer when new ipv6 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+ headerView := vv.First()
+ h := header.IPv6(headerView)
+ if !h.IsValid(vv.Size()) {
+ return
+ }
+
+ vv.TrimFront(header.IPv6MinimumSize)
+ vv.CapLength(int(h.PayloadLength()))
+
+ p := h.TransportProtocol()
+ if p == header.ICMPv6ProtocolNumber {
+ e.handleICMP(r, headerView, vv)
+ return
+ }
+
+ r.Stats().IP.PacketsDelivered.Increment()
+ e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
+}
+
+// Close cleans up resources associated with the endpoint.
+func (*endpoint) Close() {}
+
+type protocol struct{}
+
+// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
+// only for tests that short-circuit the stack. Regular use of the protocol is
+// done via the stack, which gets a protocol descriptor from the init() function
+// below.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
+}
+
+// Number returns the ipv6 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv6 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint creates a new ipv6 endpoint.
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ return &endpoint{
+ nicid: nicid,
+ id: stack.NetworkEndpointID{LocalAddress: addr},
+ linkEP: linkEP,
+ linkAddrCache: linkAddrCache,
+ dispatcher: dispatcher,
+ }, nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ mtu -= header.IPv6MinimumSize
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
+func init() {
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
new file mode 100755
index 000000000..53319e0c4
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package ipv6
+
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
new file mode 100644
index 000000000..a1712b590
--- /dev/null
+++ b/pkg/tcpip/ports/ports.go
@@ -0,0 +1,209 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ports provides PortManager that manages allocating, reserving and releasing ports.
+package ports
+
+import (
+ "math"
+ "math/rand"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ // FirstEphemeral is the first ephemeral port.
+ FirstEphemeral = 16000
+
+ anyIPAddress tcpip.Address = ""
+)
+
+type portDescriptor struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+ port uint16
+}
+
+// PortManager manages allocating, reserving and releasing ports.
+type PortManager struct {
+ mu sync.RWMutex
+ allocatedPorts map[portDescriptor]bindAddresses
+}
+
+type portNode struct {
+ reuse bool
+ refs int
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]portNode
+
+// isAvailable checks whether an IP address is available to bind to.
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
+ if addr == anyIPAddress {
+ if len(b) == 0 {
+ return true
+ }
+ if !reuse {
+ return false
+ }
+ for _, n := range b {
+ if !n.reuse {
+ return false
+ }
+ }
+ return true
+ }
+
+ // If all addresses for this portDescriptor are already bound, no
+ // address is available.
+ if n, ok := b[anyIPAddress]; ok {
+ if !reuse {
+ return false
+ }
+ if !n.reuse {
+ return false
+ }
+ }
+
+ if n, ok := b[addr]; ok {
+ if !reuse {
+ return false
+ }
+ return n.reuse
+ }
+ return true
+}
+
+// NewPortManager creates new PortManager.
+func NewPortManager() *PortManager {
+ return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)}
+}
+
+// PickEphemeralPort randomly chooses a starting point and iterates over all
+// possible ephemeral ports, allowing the caller to decide whether a given port
+// is suitable for its needs, and stopping when a port is found or an error
+// occurs.
+func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ count := uint16(math.MaxUint16 - FirstEphemeral + 1)
+ offset := uint16(rand.Int31n(int32(count)))
+
+ for i := uint16(0); i < count; i++ {
+ port = FirstEphemeral + (offset+i)%count
+ ok, err := testPort(port)
+ if err != nil {
+ return 0, err
+ }
+
+ if ok {
+ return port, nil
+ }
+ }
+
+ return 0, tcpip.ErrNoPortAvailable
+}
+
+// IsPortAvailable tests if the given port is available on all given protocols.
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
+}
+
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ if addrs, ok := s.allocatedPorts[desc]; ok {
+ if !addrs.isAvailable(addr, reuse) {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+// ReservePort marks a port/IP combination as reserved so that it cannot be
+// reserved by another endpoint. If port is zero, ReservePort will search for
+// an unreserved ephemeral port and reserve it, returning its value in the
+// "port" return value.
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If a port is specified, just try to reserve it for all network
+ // protocols.
+ if port != 0 {
+ if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
+ return 0, tcpip.ErrPortInUse
+ }
+ return port, nil
+ }
+
+ // A port wasn't specified, so try to find one.
+ return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
+ })
+}
+
+// reserveSpecificPort tries to reserve the given port on all given protocols.
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
+ return false
+ }
+
+ // Reserve port on all network protocols.
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ m, ok := s.allocatedPorts[desc]
+ if !ok {
+ m = make(bindAddresses)
+ s.allocatedPorts[desc] = m
+ }
+ if n, ok := m[addr]; ok {
+ n.refs++
+ m[addr] = n
+ } else {
+ m[addr] = portNode{reuse: reuse, refs: 1}
+ }
+ }
+
+ return true
+}
+
+// ReleasePort releases the reservation on a port/IP combination so that it can
+// be reserved by other endpoints.
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ if m, ok := s.allocatedPorts[desc]; ok {
+ n, ok := m[addr]
+ if !ok {
+ continue
+ }
+ n.refs--
+ if n.refs == 0 {
+ delete(m, addr)
+ } else {
+ m[addr] = n
+ }
+ if len(m) == 0 {
+ delete(s.allocatedPorts, desc)
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/ports/ports_state_autogen.go b/pkg/tcpip/ports/ports_state_autogen.go
new file mode 100755
index 000000000..664cc3e71
--- /dev/null
+++ b/pkg/tcpip/ports/ports_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package ports
+
diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go
new file mode 100644
index 000000000..b40a3c212
--- /dev/null
+++ b/pkg/tcpip/seqnum/seqnum.go
@@ -0,0 +1,67 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package seqnum defines the types and methods for TCP sequence numbers such
+// that they fit in 32-bit words and work properly when overflows occur.
+package seqnum
+
+// Value represents the value of a sequence number.
+type Value uint32
+
+// Size represents the size (length) of a sequence number window.
+type Size uint32
+
+// LessThan checks if v is before w, i.e., v < w.
+func (v Value) LessThan(w Value) bool {
+ return int32(v-w) < 0
+}
+
+// LessThanEq returns true if v==w or v is before i.e., v < w.
+func (v Value) LessThanEq(w Value) bool {
+ if v == w {
+ return true
+ }
+ return v.LessThan(w)
+}
+
+// InRange checks if v is in the range [a,b), i.e., a <= v < b.
+func (v Value) InRange(a, b Value) bool {
+ return v-a < b-a
+}
+
+// InWindow checks if v is in the window that starts at 'first' and spans 'size'
+// sequence numbers.
+func (v Value) InWindow(first Value, size Size) bool {
+ return v.InRange(first, first.Add(size))
+}
+
+// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y).
+func Overlap(a Value, b Size, x Value, y Size) bool {
+ return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
+}
+
+// Add calculates the sequence number following the [v, v+s) window.
+func (v Value) Add(s Size) Value {
+ return v + Value(s)
+}
+
+// Size calculates the size of the window defined by [v, w).
+func (v Value) Size(w Value) Size {
+ return Size(w - v)
+}
+
+// UpdateForward updates v such that it becomes v + s.
+func (v *Value) UpdateForward(s Size) {
+ *v += Value(s)
+}
diff --git a/pkg/tcpip/seqnum/seqnum_state_autogen.go b/pkg/tcpip/seqnum/seqnum_state_autogen.go
new file mode 100755
index 000000000..bf76f6ac4
--- /dev/null
+++ b/pkg/tcpip/seqnum/seqnum_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package seqnum
+
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
new file mode 100644
index 000000000..b952ad20f
--- /dev/null
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const linkAddrCacheSize = 512 // max cache entries
+
+// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
+//
+// The entries are stored in a ring buffer, oldest entry replaced first.
+//
+// This struct is safe for concurrent use.
+type linkAddrCache struct {
+ // ageLimit is how long a cache entry is valid for.
+ ageLimit time.Duration
+
+ // resolutionTimeout is the amount of time to wait for a link request to
+ // resolve an address.
+ resolutionTimeout time.Duration
+
+ // resolutionAttempts is the number of times an address is attempted to be
+ // resolved before failing.
+ resolutionAttempts int
+
+ mu sync.Mutex
+ cache map[tcpip.FullAddress]*linkAddrEntry
+ next int // array index of next available entry
+ entries [linkAddrCacheSize]linkAddrEntry
+}
+
+// entryState controls the state of a single entry in the cache.
+type entryState int
+
+const (
+ // incomplete means that there is an outstanding request to resolve the
+ // address. This is the initial state.
+ incomplete entryState = iota
+ // ready means that the address has been resolved and can be used.
+ ready
+ // failed means that address resolution timed out and the address
+ // could not be resolved.
+ failed
+ // expired means that the cache entry has expired and the address must be
+ // resolved again.
+ expired
+)
+
+// String implements Stringer.
+func (s entryState) String() string {
+ switch s {
+ case incomplete:
+ return "incomplete"
+ case ready:
+ return "ready"
+ case failed:
+ return "failed"
+ case expired:
+ return "expired"
+ default:
+ return fmt.Sprintf("unknown(%d)", s)
+ }
+}
+
+// A linkAddrEntry is an entry in the linkAddrCache.
+// This struct is thread-compatible.
+type linkAddrEntry struct {
+ addr tcpip.FullAddress
+ linkAddr tcpip.LinkAddress
+ expiration time.Time
+ s entryState
+
+ // wakers is a set of waiters for address resolution result. Anytime
+ // state transitions out of 'incomplete' these waiters are notified.
+ wakers map[*sleep.Waker]struct{}
+
+ done chan struct{}
+}
+
+func (e *linkAddrEntry) state() entryState {
+ if e.s != expired && time.Now().After(e.expiration) {
+ // Force the transition to ensure waiters are notified.
+ e.changeState(expired)
+ }
+ return e.s
+}
+
+func (e *linkAddrEntry) changeState(ns entryState) {
+ if e.s == ns {
+ return
+ }
+
+ // Validate state transition.
+ switch e.s {
+ case incomplete:
+ // All transitions are valid.
+ case ready, failed:
+ if ns != expired {
+ panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
+ }
+ case expired:
+ // Terminal state.
+ panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
+ default:
+ panic(fmt.Sprintf("invalid state: %s", e.s))
+ }
+
+ // Notify whoever is waiting on address resolution when transitioning
+ // out of 'incomplete'.
+ if e.s == incomplete {
+ for w := range e.wakers {
+ w.Assert()
+ }
+ e.wakers = nil
+ if e.done != nil {
+ close(e.done)
+ }
+ }
+ e.s = ns
+}
+
+func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) {
+ if w != nil {
+ e.wakers[w] = struct{}{}
+ }
+}
+
+func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
+ delete(e.wakers, w)
+}
+
+// add adds a k -> v mapping to the cache.
+func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ entry, ok := c.cache[k]
+ if ok {
+ s := entry.state()
+ if s != expired && entry.linkAddr == v {
+ // Disregard repeated calls.
+ return
+ }
+ // Check if entry is waiting for address resolution.
+ if s == incomplete {
+ entry.linkAddr = v
+ } else {
+ // Otherwise create a new entry to replace it.
+ entry = c.makeAndAddEntry(k, v)
+ }
+ } else {
+ entry = c.makeAndAddEntry(k, v)
+ }
+
+ entry.changeState(ready)
+}
+
+// makeAndAddEntry is a helper function to create and add a new
+// entry to the cache map and evict older entry as needed.
+func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
+ // Take over the next entry.
+ entry := &c.entries[c.next]
+ if c.cache[entry.addr] == entry {
+ delete(c.cache, entry.addr)
+ }
+
+ // Mark the soon-to-be-replaced entry as expired, just in case there is
+ // someone waiting for address resolution on it.
+ entry.changeState(expired)
+
+ *entry = linkAddrEntry{
+ addr: k,
+ linkAddr: v,
+ expiration: time.Now().Add(c.ageLimit),
+ wakers: make(map[*sleep.Waker]struct{}),
+ done: make(chan struct{}),
+ }
+
+ c.cache[k] = entry
+ c.next = (c.next + 1) % len(c.entries)
+ return entry
+}
+
+// get reports any known link address for k.
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+ if linkRes != nil {
+ if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ return addr, nil, nil
+ }
+ }
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if entry, ok := c.cache[k]; ok {
+ switch s := entry.state(); s {
+ case expired:
+ case ready:
+ return entry.linkAddr, nil, nil
+ case failed:
+ return "", nil, tcpip.ErrNoLinkAddress
+ case incomplete:
+ // Address resolution is still in progress.
+ entry.maybeAddWaker(waker)
+ return "", entry.done, tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
+ }
+
+ if linkRes == nil {
+ return "", nil, tcpip.ErrNoLinkAddress
+ }
+
+ // Add 'incomplete' entry in the cache to mark that resolution is in progress.
+ e := c.makeAndAddEntry(k, "")
+ e.maybeAddWaker(waker)
+
+ go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+
+ return "", e.done, tcpip.ErrWouldBlock
+}
+
+// removeWaker removes a waker previously added through get().
+func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if entry, ok := c.cache[k]; ok {
+ entry.removeWaker(waker)
+ }
+}
+
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) {
+ for i := 0; ; i++ {
+ // Send link request, then wait for the timeout limit and check
+ // whether the request succeeded.
+ linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+
+ select {
+ case <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(k, i); stop {
+ return
+ }
+ case <-done:
+ return
+ }
+ }
+}
+
+// checkLinkRequest checks whether previous attempt to resolve address has succeeded
+// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
+// can stop, false if another request should be sent.
+func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ entry, ok := c.cache[k]
+ if !ok {
+ // Entry was evicted from the cache.
+ return true
+ }
+
+ switch s := entry.state(); s {
+ case ready, failed, expired:
+ // Entry was made ready by resolver or failed. Either way we're done.
+ return true
+ case incomplete:
+ if attempt+1 >= c.resolutionAttempts {
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed)
+ return true
+ }
+ // No response yet, need to send another ARP request.
+ return false
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
+}
+
+func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+ return &linkAddrCache{
+ ageLimit: ageLimit,
+ resolutionTimeout: resolutionTimeout,
+ resolutionAttempts: resolutionAttempts,
+ cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
new file mode 100644
index 000000000..50d35de88
--- /dev/null
+++ b/pkg/tcpip/stack/nic.go
@@ -0,0 +1,728 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/ilist"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+// NIC represents a "network interface card" to which the networking stack is
+// attached.
+type NIC struct {
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ loopback bool
+
+ demux *transportDemuxer
+
+ mu sync.RWMutex
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ subnets []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
+
+ stats NICStats
+}
+
+// NICStats includes transmitted and received stats.
+type NICStats struct {
+ Tx DirectionStats
+ Rx DirectionStats
+}
+
+// DirectionStats includes packet and byte counts.
+type DirectionStats struct {
+ Packets *tcpip.StatCounter
+ Bytes *tcpip.StatCounter
+}
+
+// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior.
+type PrimaryEndpointBehavior int
+
+const (
+ // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
+ // endpoint for new connections with no local address. This is the
+ // default when calling NIC.AddAddress.
+ CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
+
+ // FirstPrimaryEndpoint indicates the endpoint should be the first
+ // primary endpoint considered. If there are multiple endpoints with
+ // this behavior, the most recently-added one will be first.
+ FirstPrimaryEndpoint
+
+ // NeverPrimaryEndpoint indicates the endpoint should never be a
+ // primary endpoint.
+ NeverPrimaryEndpoint
+)
+
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
+ return &NIC{
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ loopback: loopback,
+ demux: newTransportDemuxer(stack),
+ primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ mcastJoins: make(map[NetworkEndpointID]int32),
+ stats: NICStats{
+ Tx: DirectionStats{
+ Packets: &tcpip.StatCounter{},
+ Bytes: &tcpip.StatCounter{},
+ },
+ Rx: DirectionStats{
+ Packets: &tcpip.StatCounter{},
+ Bytes: &tcpip.StatCounter{},
+ },
+ },
+ }
+}
+
+// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
+// to start delivering packets.
+func (n *NIC) attachLinkEndpoint() {
+ n.linkEP.Attach(n)
+}
+
+// setPromiscuousMode enables or disables promiscuous mode.
+func (n *NIC) setPromiscuousMode(enable bool) {
+ n.mu.Lock()
+ n.promiscuous = enable
+ n.mu.Unlock()
+}
+
+func (n *NIC) isPromiscuousMode() bool {
+ n.mu.RLock()
+ rv := n.promiscuous
+ n.mu.RUnlock()
+ return rv
+}
+
+// setSpoofing enables or disables address spoofing.
+func (n *NIC) setSpoofing(enable bool) {
+ n.mu.Lock()
+ n.spoofing = enable
+ n.mu.Unlock()
+}
+
+func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var r *referencedNetworkEndpoint
+
+ // Check for a primary endpoint.
+ if list, ok := n.primary[protocol]; ok {
+ for e := list.Front(); e != nil; e = e.Next() {
+ ref := e.(*referencedNetworkEndpoint)
+ if ref.holdsInsertRef && ref.tryIncRef() {
+ r = ref
+ break
+ }
+ }
+
+ }
+
+ if r == nil {
+ return "", tcpip.Subnet{}, tcpip.ErrNoLinkAddress
+ }
+
+ address := r.ep.ID().LocalAddress
+ r.decRef()
+
+ // Find the least-constrained matching subnet for the address, if one
+ // exists, and return it.
+ var subnet tcpip.Subnet
+ for _, s := range n.subnets {
+ if s.Contains(address) && !subnet.Contains(s.ID()) {
+ subnet = s
+ }
+ }
+ return address, subnet, nil
+}
+
+// primaryEndpoint returns the primary endpoint of n for the given network
+// protocol.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ list := n.primary[protocol]
+ if list == nil {
+ return nil
+ }
+
+ for e := list.Front(); e != nil; e = e.Next() {
+ r := e.(*referencedNetworkEndpoint)
+ // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set.
+ switch r.ep.ID().LocalAddress {
+ case header.IPv4Broadcast, header.IPv4Any:
+ continue
+ }
+ if r.tryIncRef() {
+ return r
+ }
+ }
+
+ return nil
+}
+
+// findEndpoint finds the endpoint, if any, with the given address.
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ id := NetworkEndpointID{address}
+
+ n.mu.RLock()
+ ref := n.endpoints[id]
+ if ref != nil && !ref.tryIncRef() {
+ ref = nil
+ }
+ spoofing := n.spoofing
+ n.mu.RUnlock()
+
+ if ref != nil || !spoofing {
+ return ref
+ }
+
+ // Try again with the lock in exclusive mode. If we still can't get the
+ // endpoint, create a new "temporary" endpoint. It will only exist while
+ // there's a route through it.
+ n.mu.Lock()
+ ref = n.endpoints[id]
+ if ref == nil || !ref.tryIncRef() {
+ ref, _ = n.addAddressLocked(protocol, address, peb, true)
+ if ref != nil {
+ ref.holdsInsertRef = false
+ }
+ }
+ n.mu.Unlock()
+ return ref
+}
+
+func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ // Create the new network endpoint.
+ ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP)
+ if err != nil {
+ return nil, err
+ }
+
+ id := *ep.ID()
+ if ref, ok := n.endpoints[id]; ok {
+ if !replace {
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
+ n.removeEndpointLocked(ref)
+ }
+
+ ref := &referencedNetworkEndpoint{
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocol,
+ holdsInsertRef: true,
+ }
+
+ // Set up cache if link address resolution exists for this protocol.
+ if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if _, ok := n.stack.linkAddrResolvers[protocol]; ok {
+ ref.linkCache = n.stack
+ }
+ }
+
+ n.endpoints[id] = ref
+
+ l, ok := n.primary[protocol]
+ if !ok {
+ l = &ilist.List{}
+ n.primary[protocol] = l
+ }
+
+ switch peb {
+ case CanBePrimaryEndpoint:
+ l.PushBack(ref)
+ case FirstPrimaryEndpoint:
+ l.PushFront(ref)
+ }
+
+ return ref, nil
+}
+
+// AddAddress adds a new address to n, so that it starts accepting packets
+// targeted at the given address (and network protocol).
+func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ return n.AddAddressWithOptions(protocol, addr, CanBePrimaryEndpoint)
+}
+
+// AddAddressWithOptions is the same as AddAddress, but allows you to specify
+// whether the new endpoint can be primary or not.
+func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
+ // Add the endpoint.
+ n.mu.Lock()
+ _, err := n.addAddressLocked(protocol, addr, peb, false)
+ n.mu.Unlock()
+
+ return err
+}
+
+// Addresses returns the addresses associated with this NIC.
+func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
+ for nid, ep := range n.endpoints {
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: ep.protocol,
+ Address: nid.LocalAddress,
+ })
+ }
+ return addrs
+}
+
+// AddSubnet adds a new subnet to n, so that it starts accepting packets
+// targeted at the given address and network protocol.
+func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
+ n.mu.Lock()
+ n.subnets = append(n.subnets, subnet)
+ n.mu.Unlock()
+}
+
+// RemoveSubnet removes the given subnet from n.
+func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
+ n.mu.Lock()
+
+ // Use the same underlying array.
+ tmp := n.subnets[:0]
+ for _, sub := range n.subnets {
+ if sub != subnet {
+ tmp = append(tmp, sub)
+ }
+ }
+ n.subnets = tmp
+
+ n.mu.Unlock()
+}
+
+// ContainsSubnet reports whether this NIC contains the given subnet.
+func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
+ for _, s := range n.Subnets() {
+ if s == subnet {
+ return true
+ }
+ }
+ return false
+}
+
+// Subnets returns the Subnets associated with this NIC.
+func (n *NIC) Subnets() []tcpip.Subnet {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
+ for nid := range n.endpoints {
+ sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
+ if err != nil {
+ // This should never happen as the mask has been carefully crafted to
+ // match the address.
+ panic("Invalid endpoint subnet: " + err.Error())
+ }
+ sns = append(sns, sn)
+ }
+ return append(sns, n.subnets...)
+}
+
+func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
+ id := *r.ep.ID()
+
+ // Nothing to do if the reference has already been replaced with a
+ // different one.
+ if n.endpoints[id] != r {
+ return
+ }
+
+ if r.holdsInsertRef {
+ panic("Reference count dropped to zero before being removed")
+ }
+
+ delete(n.endpoints, id)
+ wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front()
+ if wasInList {
+ n.primary[r.protocol].Remove(r)
+ }
+
+ r.ep.Close()
+}
+
+func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
+ n.mu.Lock()
+ n.removeEndpointLocked(r)
+ n.mu.Unlock()
+}
+
+func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
+ r := n.endpoints[NetworkEndpointID{addr}]
+ if r == nil || !r.holdsInsertRef {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ r.holdsInsertRef = false
+
+ r.decRefLocked()
+
+ return nil
+}
+
+// RemoveAddress removes an address from n.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ return n.removeAddressLocked(addr)
+}
+
+// joinGroup adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count.
+func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ if joins == 0 {
+ if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins + 1
+ return nil
+}
+
+// leaveGroup decrements the count for the given multicast address, and when it
+// reaches zero removes the endpoint for this address.
+func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ switch joins {
+ case 0:
+ // There are no joins with this address on this NIC.
+ return tcpip.ErrBadLocalAddress
+ case 1:
+ // This is the last one, clean up.
+ if err := n.removeAddressLocked(addr); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins - 1
+ return nil
+}
+
+// DeliverNetworkPacket finds the appropriate network protocol endpoint and
+// hands the packet over for further processing. This function is called when
+// the NIC receives a packet from the physical interface.
+// Note that the ownership of the slice backing vv is retained by the caller.
+// This rule applies only to the slice itself, not to the items of the slice;
+// the ownership of the items is not retained by the caller.
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ n.stats.Rx.Packets.Increment()
+ n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size()))
+
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ return
+ }
+
+ if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
+ n.stack.stats.IP.PacketsReceived.Increment()
+ }
+
+ if len(vv.First()) < netProto.MinimumPacketSize() {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+
+ src, dst := netProto.ParseAddresses(vv.First())
+
+ // If the packet is destined to the IPv4 Broadcast address, then make a
+ // route to each IPv4 network endpoint and let each endpoint handle the
+ // packet.
+ if dst == header.IPv4Broadcast {
+ // n.endpoints is mutex protected so acquire lock.
+ n.mu.RLock()
+ for _, ref := range n.endpoints {
+ if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
+ r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remote
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+ }
+ }
+ n.mu.RUnlock()
+ return
+ }
+
+ if ref := n.getRef(protocol, dst); ref != nil {
+ r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remote
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+ return
+ }
+
+ // This NIC doesn't care about the packet. Find a NIC that cares about the
+ // packet and forward it to the NIC.
+ //
+ // TODO: Should we be forwarding the packet even if promiscuous?
+ if n.stack.Forwarding() {
+ r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
+ if err != nil {
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ return
+ }
+ defer r.Release()
+
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.RemoteLinkAddress = remote
+
+ // Found a NIC.
+ n := r.ref.nic
+ n.mu.RLock()
+ ref, ok := n.endpoints[NetworkEndpointID{dst}]
+ n.mu.RUnlock()
+ if ok && ref.tryIncRef() {
+ r.RemoteAddress = src
+ // TODO(b/123449044): Update the source NIC as well.
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+ } else {
+ // n doesn't have a destination endpoint.
+ // Send the packet out of n.
+ hdr := buffer.NewPrependableFromView(vv.First())
+ vv.RemoveFirst()
+
+ // TODO(b/128629022): use route.WritePacket.
+ if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ } else {
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size()))
+ }
+ }
+ return
+ }
+
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+}
+
+func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
+ id := NetworkEndpointID{dst}
+
+ n.mu.RLock()
+ if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+
+ promiscuous := n.promiscuous
+ // Check if the packet is for a subnet this NIC cares about.
+ if !promiscuous {
+ for _, sn := range n.subnets {
+ if sn.Contains(dst) {
+ promiscuous = true
+ break
+ }
+ }
+ }
+ n.mu.RUnlock()
+ if promiscuous {
+ // Try again with the lock in exclusive mode. If we still can't
+ // get the endpoint, create a new "temporary" one. It will only
+ // exist while there's a route through it.
+ n.mu.Lock()
+ if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
+ n.mu.Unlock()
+ return ref
+ }
+ ref, err := n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true)
+ n.mu.Unlock()
+ if err == nil {
+ ref.holdsInsertRef = false
+ return ref
+ }
+ }
+
+ return nil
+}
+
+// DeliverTransportPacket delivers the packets to the appropriate transport
+// protocol endpoint.
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
+ state, ok := n.stack.transportProtocols[protocol]
+ if !ok {
+ n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ return
+ }
+
+ transProto := state.proto
+
+ // Raw socket packets are delivered based solely on the transport
+ // protocol number. We do not inspect the payload to ensure it's
+ // validly formed.
+ if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) {
+ n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
+ }
+
+ if len(vv.First()) < transProto.MinimumPacketSize() {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+
+ srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ if err != nil {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+
+ id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
+ if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
+ return
+ }
+ if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
+ return
+ }
+
+ // Try to deliver to per-stack default handler.
+ if state.defaultHandler != nil {
+ if state.defaultHandler(r, id, netHeader, vv) {
+ return
+ }
+ }
+
+ // We could not find an appropriate destination for this packet, so
+ // deliver it to the global handler.
+ if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ }
+}
+
+// DeliverTransportControlPacket delivers control packets to the appropriate
+// transport protocol endpoint.
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ state, ok := n.stack.transportProtocols[trans]
+ if !ok {
+ return
+ }
+
+ transProto := state.proto
+
+ // ICMPv4 only guarantees that 8 bytes of the transport protocol will
+ // be present in the payload. We know that the ports are within the
+ // first 8 bytes for all known transport protocols.
+ if len(vv.First()) < 8 {
+ return
+ }
+
+ srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ if err != nil {
+ return
+ }
+
+ id := TransportEndpointID{srcPort, local, dstPort, remote}
+ if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ return
+ }
+ if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ return
+ }
+}
+
+// ID returns the identifier of n.
+func (n *NIC) ID() tcpip.NICID {
+ return n.id
+}
+
+type referencedNetworkEndpoint struct {
+ ilist.Entry
+ refs int32
+ ep NetworkEndpoint
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+
+ // linkCache is set if link address resolution is enabled for this
+ // protocol. Set to nil otherwise.
+ linkCache LinkAddressCache
+
+ // holdsInsertRef is protected by the NIC's mutex. It indicates whether
+ // the reference count is biased by 1 due to the insertion of the
+ // endpoint. It is reset to false when RemoveAddress is called on the
+ // NIC.
+ holdsInsertRef bool
+}
+
+// decRef decrements the ref count and cleans up the endpoint once it reaches
+// zero.
+func (r *referencedNetworkEndpoint) decRef() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpoint(r)
+ }
+}
+
+// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpointLocked(r)
+ }
+}
+
+// incRef increments the ref count. It must only be called when the caller is
+// known to be holding a reference to the endpoint, otherwise tryIncRef should
+// be used.
+func (r *referencedNetworkEndpoint) incRef() {
+ atomic.AddInt32(&r.refs, 1)
+}
+
+// tryIncRef attempts to increment the ref count from n to n+1, but only if n is
+// not zero. That is, it will increment the count if the endpoint is still
+// alive, and do nothing if it has already been clean up.
+func (r *referencedNetworkEndpoint) tryIncRef() bool {
+ for {
+ v := atomic.LoadInt32(&r.refs)
+ if v == 0 {
+ return false
+ }
+
+ if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
+ return true
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
new file mode 100644
index 000000000..c70533a35
--- /dev/null
+++ b/pkg/tcpip/stack/registration.go
@@ -0,0 +1,441 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// NetworkEndpointID is the identifier of a network layer protocol endpoint.
+// Currently the local address is sufficient because all supported protocols
+// (i.e., IPv4 and IPv6) have different sizes for their addresses.
+type NetworkEndpointID struct {
+ LocalAddress tcpip.Address
+}
+
+// TransportEndpointID is the identifier of a transport layer protocol endpoint.
+//
+// +stateify savable
+type TransportEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress it the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// ControlType is the type of network control message.
+type ControlType int
+
+// The following are the allowed values for ControlType values.
+const (
+ ControlPacketTooBig ControlType = iota
+ ControlPortUnreachable
+ ControlUnknown
+)
+
+// TransportEndpoint is the interface that needs to be implemented by transport
+// protocol (e.g., tcp, udp) endpoints that can handle packets.
+type TransportEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive to
+ // this transport endpoint.
+ HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
+
+ // HandleControlPacket is called by the stack when new control (e.g.,
+ // ICMP) packets arrive to this transport endpoint.
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
+}
+
+// RawTransportEndpoint is the interface that needs to be implemented by raw
+// transport protocol endpoints. RawTransportEndpoints receive the entire
+// packet - including the link, network, and transport headers - as delivered
+// to netstack.
+type RawTransportEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive to
+ // this transport endpoint. The packet contains all data from the link
+ // layer up.
+ HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
+}
+
+// TransportProtocol is the interface that needs to be implemented by transport
+// protocols (e.g., tcp, udp) that want to be part of the networking stack.
+type TransportProtocol interface {
+ // Number returns the transport protocol number.
+ Number() tcpip.TransportProtocolNumber
+
+ // NewEndpoint creates a new endpoint of the transport protocol.
+ NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // NewRawEndpoint creates a new raw endpoint of the transport protocol.
+ NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // MinimumPacketSize returns the minimum valid packet size of this
+ // transport protocol. The stack automatically drops any packets smaller
+ // than this targeted at this protocol.
+ MinimumPacketSize() int
+
+ // ParsePorts returns the source and destination ports stored in a
+ // packet of this protocol.
+ ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
+
+ // HandleUnknownDestinationPacket handles packets targeted at this
+ // protocol but that don't match any existing endpoint. For example,
+ // it is targeted at a port that have no listeners.
+ //
+ // The return value indicates whether the packet was well-formed (for
+ // stats purposes only).
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool
+
+ // SetOption allows enabling/disabling protocol specific features.
+ // SetOption returns an error if the option is not supported or the
+ // provided option value is invalid.
+ SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
+}
+
+// TransportDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate transport endpoint after it has been handled by
+// the network layer.
+type TransportDispatcher interface {
+ // DeliverTransportPacket delivers packets to the appropriate
+ // transport protocol endpoint. It also returns the network layer
+ // header for the enpoint to inspect or pass up the stack.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
+
+ // DeliverTransportControlPacket delivers control packets to the
+ // appropriate transport protocol endpoint.
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
+}
+
+// PacketLooping specifies where an outbound packet should be sent.
+type PacketLooping byte
+
+const (
+ // PacketOut indicates that the packet should be passed to the link
+ // endpoint.
+ PacketOut PacketLooping = 1 << iota
+
+ // PacketLoop indicates that the packet should be handled locally.
+ PacketLoop
+)
+
+// NetworkEndpoint is the interface that needs to be implemented by endpoints
+// of network layer protocols (e.g., ipv4, ipv6).
+type NetworkEndpoint interface {
+ // DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
+ // for this endpoint.
+ DefaultTTL() uint8
+
+ // MTU is the maximum transmission unit for this endpoint. This is
+ // generally calculated as the MTU of the underlying data link endpoint
+ // minus the network endpoint max header length.
+ MTU() uint32
+
+ // Capabilities returns the set of capabilities supported by the
+ // underlying link-layer endpoint.
+ Capabilities() LinkEndpointCapabilities
+
+ // MaxHeaderLength returns the maximum size the network (and lower
+ // level layers combined) headers can have. Higher levels use this
+ // information to reserve space in the front of the packets they're
+ // building.
+ MaxHeaderLength() uint16
+
+ // WritePacket writes a packet to the given destination address and
+ // protocol.
+ WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
+
+ // ID returns the network protocol endpoint ID.
+ ID() *NetworkEndpointID
+
+ // NICID returns the id of the NIC this endpoint belongs to.
+ NICID() tcpip.NICID
+
+ // HandlePacket is called by the link layer when new packets arrive to
+ // this network endpoint.
+ HandlePacket(r *Route, vv buffer.VectorisedView)
+
+ // Close is called when the endpoint is reomved from a stack.
+ Close()
+}
+
+// NetworkProtocol is the interface that needs to be implemented by network
+// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack.
+type NetworkProtocol interface {
+ // Number returns the network protocol number.
+ Number() tcpip.NetworkProtocolNumber
+
+ // MinimumPacketSize returns the minimum valid packet size of this
+ // network protocol. The stack automatically drops any packets smaller
+ // than this targeted at this protocol.
+ MinimumPacketSize() int
+
+ // ParsePorts returns the source and destination addresses stored in a
+ // packet of this protocol.
+ ParseAddresses(v buffer.View) (src, dst tcpip.Address)
+
+ // NewEndpoint creates a new endpoint of this protocol.
+ NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
+
+ // SetOption allows enabling/disabling protocol specific features.
+ // SetOption returns an error if the option is not supported or the
+ // provided option value is invalid.
+ SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
+}
+
+// NetworkDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate network endpoint after it has been handled by
+// the data link layer.
+type NetworkDispatcher interface {
+ // DeliverNetworkPacket finds the appropriate network protocol
+ // endpoint and hands the packet over for further processing.
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+}
+
+// LinkEndpointCapabilities is the type associated with the capabilities
+// supported by a link-layer endpoint. It is a set of bitfields.
+type LinkEndpointCapabilities uint
+
+// The following are the supported link endpoint capabilities.
+const (
+ CapabilityNone LinkEndpointCapabilities = 0
+ // CapabilityTXChecksumOffload indicates that the link endpoint supports
+ // checksum computation for outgoing packets and the stack can skip
+ // computing checksums when sending packets.
+ CapabilityTXChecksumOffload LinkEndpointCapabilities = 1 << iota
+ // CapabilityRXChecksumOffload indicates that the link endpoint supports
+ // checksum verification on received packets and that it's safe for the
+ // stack to skip checksum verification.
+ CapabilityRXChecksumOffload
+ CapabilityResolutionRequired
+ CapabilitySaveRestore
+ CapabilityDisconnectOk
+ CapabilityLoopback
+ CapabilityGSO
+)
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint.
+type LinkEndpoint interface {
+ // MTU is the maximum transmission unit for this endpoint. This is
+ // usually dictated by the backing physical network; when such a
+ // physical network doesn't exist, the limit is generally 64k, which
+ // includes the maximum size of an IP packet.
+ MTU() uint32
+
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
+
+ // MaxHeaderLength returns the maximum size the data link (and
+ // lower level layers combined) headers can have. Higher levels use this
+ // information to reserve space in the front of the packets they're
+ // building.
+ MaxHeaderLength() uint16
+
+ // LinkAddress returns the link address (typically a MAC) of the
+ // link endpoint.
+ LinkAddress() tcpip.LinkAddress
+
+ // WritePacket writes a packet with the given protocol through the given
+ // route.
+ //
+ // To participate in transparent bridging, a LinkEndpoint implementation
+ // should call eth.Encode with header.EthernetFields.SrcAddr set to
+ // r.LocalLinkAddress if it is provided.
+ WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+
+ // Attach attaches the data link layer endpoint to the network-layer
+ // dispatcher of the stack.
+ Attach(dispatcher NetworkDispatcher)
+
+ // IsAttached returns whether a NetworkDispatcher is attached to the
+ // endpoint.
+ IsAttached() bool
+}
+
+// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
+// delivered via the Inject method.
+type InjectableLinkEndpoint interface {
+ LinkEndpoint
+
+ // Inject injects an inbound packet.
+ Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+
+ // WriteRawPacket writes a fully formed outbound packet directly to the link.
+ //
+ // dest is used by endpoints with multiple raw destinations.
+ WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error
+}
+
+// A LinkAddressResolver is an extension to a NetworkProtocol that
+// can resolve link addresses.
+type LinkAddressResolver interface {
+ // LinkAddressRequest sends a request for the LinkAddress of addr.
+ // The request is sent on linkEP with localAddr as the source.
+ //
+ // A valid response will cause the discovery protocol's network
+ // endpoint to call AddLinkAddress.
+ LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+
+ // ResolveStaticAddress attempts to resolve address without sending
+ // requests. It either resolves the name immediately or returns the
+ // empty LinkAddress.
+ //
+ // It can be used to resolve broadcast addresses for example.
+ ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
+
+ // LinkAddressProtocol returns the network protocol of the
+ // addresses this this resolver can resolve.
+ LinkAddressProtocol() tcpip.NetworkProtocolNumber
+}
+
+// A LinkAddressCache caches link addresses.
+type LinkAddressCache interface {
+ // CheckLocalAddress determines if the given local address exists, and if it
+ // does not exist.
+ CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
+
+ // AddLinkAddress adds a link address to the cache.
+ AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+
+ // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
+ // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
+ // registered with the network protocol, the cache attempts to resolve the address
+ // and returns ErrWouldBlock. Waker is notified when address resolution is
+ // complete (success or not).
+ //
+ // If address resolution is required, ErrNoLinkAddress and a notification channel is
+ // returned for the top level caller to block. Channel is closed once address resolution
+ // is complete (success or not).
+ GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
+
+ // RemoveWaker removes a waker that has been added in GetLinkAddress().
+ RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+}
+
+// TransportProtocolFactory functions are used by the stack to instantiate
+// transport protocols.
+type TransportProtocolFactory func() TransportProtocol
+
+// NetworkProtocolFactory provides methods to be used by the stack to
+// instantiate network protocols.
+type NetworkProtocolFactory func() NetworkProtocol
+
+var (
+ transportProtocols = make(map[string]TransportProtocolFactory)
+ networkProtocols = make(map[string]NetworkProtocolFactory)
+
+ linkEPMu sync.RWMutex
+ nextLinkEndpointID tcpip.LinkEndpointID = 1
+ linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
+)
+
+// RegisterTransportProtocolFactory registers a new transport protocol factory
+// with the stack so that it becomes available to users of the stack. This
+// function is intended to be called by init() functions of the protocols.
+func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
+ transportProtocols[name] = p
+}
+
+// RegisterNetworkProtocolFactory registers a new network protocol factory with
+// the stack so that it becomes available to users of the stack. This function
+// is intended to be called by init() functions of the protocols.
+func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
+ networkProtocols[name] = p
+}
+
+// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an
+// ID that can be used to refer to it.
+func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
+ linkEPMu.Lock()
+ defer linkEPMu.Unlock()
+
+ v := nextLinkEndpointID
+ nextLinkEndpointID++
+
+ linkEndpoints[v] = linkEP
+
+ return v
+}
+
+// FindLinkEndpoint finds the link endpoint associated with the given ID.
+func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint {
+ linkEPMu.RLock()
+ defer linkEPMu.RUnlock()
+
+ return linkEndpoints[id]
+}
+
+// GSOType is the type of GSO segments.
+//
+// +stateify savable
+type GSOType int
+
+// Types of gso segments.
+const (
+ GSONone GSOType = iota
+ GSOTCPv4
+ GSOTCPv6
+)
+
+// GSO contains generic segmentation offload properties.
+//
+// +stateify savable
+type GSO struct {
+ // Type is one of GSONone, GSOTCPv4, etc.
+ Type GSOType
+ // NeedsCsum is set if the checksum offload is enabled.
+ NeedsCsum bool
+ // CsumOffset is offset after that to place checksum.
+ CsumOffset uint16
+
+ // Mss is maximum segment size.
+ MSS uint16
+ // L3Len is L3 (IP) header length.
+ L3HdrLen uint16
+
+ // MaxSize is maximum GSO packet size.
+ MaxSize uint32
+}
+
+// GSOEndpoint provides access to GSO properties.
+type GSOEndpoint interface {
+ // GSOMaxSize returns the maximum GSO packet size.
+ GSOMaxSize() uint32
+}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
new file mode 100644
index 000000000..3d4c282a9
--- /dev/null
+++ b/pkg/tcpip/stack/route.go
@@ -0,0 +1,189 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+// Route represents a route through the networking stack to a given destination.
+type Route struct {
+ // RemoteAddress is the final destination of the route.
+ RemoteAddress tcpip.Address
+
+ // RemoteLinkAddress is the link-layer (MAC) address of the
+ // final destination of the route.
+ RemoteLinkAddress tcpip.LinkAddress
+
+ // LocalAddress is the local address where the route starts.
+ LocalAddress tcpip.Address
+
+ // LocalLinkAddress is the link-layer (MAC) address of the
+ // where the route starts.
+ LocalLinkAddress tcpip.LinkAddress
+
+ // NextHop is the next node in the path to the destination.
+ NextHop tcpip.Address
+
+ // NetProto is the network-layer protocol.
+ NetProto tcpip.NetworkProtocolNumber
+
+ // ref a reference to the network endpoint through which the route
+ // starts.
+ ref *referencedNetworkEndpoint
+
+ // loop controls where WritePacket should send packets.
+ loop PacketLooping
+}
+
+// makeRoute initializes a new route. It takes ownership of the provided
+// reference to a network endpoint.
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route {
+ loop := PacketOut
+ if handleLocal && localAddr != "" && remoteAddr == localAddr {
+ loop = PacketLoop
+ } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
+ loop |= PacketLoop
+ }
+
+ return Route{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: localLinkAddr,
+ RemoteAddress: remoteAddr,
+ ref: ref,
+ loop: loop,
+ }
+}
+
+// NICID returns the id of the NIC from which this route originates.
+func (r *Route) NICID() tcpip.NICID {
+ return r.ref.ep.NICID()
+}
+
+// MaxHeaderLength forwards the call to the network endpoint's implementation.
+func (r *Route) MaxHeaderLength() uint16 {
+ return r.ref.ep.MaxHeaderLength()
+}
+
+// Stats returns a mutable copy of current stats.
+func (r *Route) Stats() tcpip.Stats {
+ return r.ref.nic.stack.Stats()
+}
+
+// PseudoHeaderChecksum forwards the call to the network endpoint's
+// implementation.
+func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, totalLen uint16) uint16 {
+ return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen)
+}
+
+// Capabilities returns the link-layer capabilities of the route.
+func (r *Route) Capabilities() LinkEndpointCapabilities {
+ return r.ref.ep.Capabilities()
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (r *Route) GSOMaxSize() uint32 {
+ if gso, ok := r.ref.ep.(GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
+// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
+// notified when address resolution is complete (success or not).
+//
+// If address resolution is required, ErrNoLinkAddress and a notification channel is
+// returned for the top level caller to block. Channel is closed once address resolution
+// is complete (success or not).
+func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
+ if !r.IsResolutionRequired() {
+ // Nothing to do if there is no cache (which does the resolution on cache miss) or
+ // link address is already known.
+ return nil, nil
+ }
+
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ // Local link address is already known.
+ if r.RemoteAddress == r.LocalAddress {
+ r.RemoteLinkAddress = r.LocalLinkAddress
+ return nil, nil
+ }
+ nextAddr = r.RemoteAddress
+ }
+ linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ if err != nil {
+ return ch, err
+ }
+ r.RemoteLinkAddress = linkAddr
+ return nil, nil
+}
+
+// RemoveWaker removes a waker that has been added in Resolve().
+func (r *Route) RemoveWaker(waker *sleep.Waker) {
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ nextAddr = r.RemoteAddress
+ }
+ r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
+}
+
+// IsResolutionRequired returns true if Resolve() must be called to resolve
+// the link address before the this route can be written to.
+func (r *Route) IsResolutionRequired() bool {
+ return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+}
+
+// WritePacket writes the packet through the given route.
+func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ } else {
+ r.ref.nic.stats.Tx.Packets.Increment()
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size()))
+ }
+ return err
+}
+
+// DefaultTTL returns the default TTL of the underlying network endpoint.
+func (r *Route) DefaultTTL() uint8 {
+ return r.ref.ep.DefaultTTL()
+}
+
+// MTU returns the MTU of the underlying network endpoint.
+func (r *Route) MTU() uint32 {
+ return r.ref.ep.MTU()
+}
+
+// Release frees all resources associated with the route.
+func (r *Route) Release() {
+ if r.ref != nil {
+ r.ref.decRef()
+ r.ref = nil
+ }
+}
+
+// Clone Clone a route such that the original one can be released and the new
+// one will remain valid.
+func (r *Route) Clone() Route {
+ r.ref.incRef()
+ return *r
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
new file mode 100644
index 000000000..9d8e8cda5
--- /dev/null
+++ b/pkg/tcpip/stack/stack.go
@@ -0,0 +1,1095 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package stack provides the glue between networking protocols and the
+// consumers of the networking stack.
+//
+// For consumers, the only function of interest is New(), everything else is
+// provided by the tcpip/public package.
+//
+// For protocol implementers, RegisterTransportProtocolFactory() and
+// RegisterNetworkProtocolFactory() are used to register protocol factories with
+// the stack, which will then be used to instantiate protocol objects when
+// consumers interact with the stack.
+package stack
+
+import (
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/ports"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ageLimit is set to the same cache stale time used in Linux.
+ ageLimit = 1 * time.Minute
+ // resolutionTimeout is set to the same ARP timeout used in Linux.
+ resolutionTimeout = 1 * time.Second
+ // resolutionAttempts is set to the same ARP retries used in Linux.
+ resolutionAttempts = 3
+)
+
+type transportProtocolState struct {
+ proto TransportProtocol
+ defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
+}
+
+// TCPProbeFunc is the expected function type for a TCP probe function to be
+// passed to stack.AddTCPProbe.
+type TCPProbeFunc func(s TCPEndpointState)
+
+// TCPCubicState is used to hold a copy of the internal cubic state when the
+// TCPProbeFunc is invoked.
+type TCPCubicState struct {
+ WLastMax float64
+ WMax float64
+ T time.Time
+ TimeSinceLastCongestion time.Duration
+ C float64
+ K float64
+ Beta float64
+ WC float64
+ WEst float64
+}
+
+// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
+type TCPEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress it the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
+// TCP endpoint.
+type TCPFastRecoveryState struct {
+ // Active if true indicates the endpoint is in fast recovery.
+ Active bool
+
+ // First is the first unacknowledged sequence number being recovered.
+ First seqnum.Value
+
+ // Last is the 'recover' sequence number that indicates the point at
+ // which we should exit recovery barring any timeouts etc.
+ Last seqnum.Value
+
+ // MaxCwnd is the maximum value we are permitted to grow the congestion
+ // window during recovery. This is set at the time we enter recovery.
+ MaxCwnd int
+
+ // HighRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase.
+ // See: RFC 6675 Section 2 for details.
+ HighRxt seqnum.Value
+
+ // RescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission.
+ // See: RFC 6675 Section 2 for details.
+ RescueRxt seqnum.Value
+}
+
+// TCPReceiverState holds a copy of the internal state of the receiver for
+// a given TCP endpoint.
+type TCPReceiverState struct {
+ // RcvNxt is the TCP variable RCV.NXT.
+ RcvNxt seqnum.Value
+
+ // RcvAcc is the TCP variable RCV.ACC.
+ RcvAcc seqnum.Value
+
+ // RcvWndScale is the window scaling to use for inbound segments.
+ RcvWndScale uint8
+
+ // PendingBufUsed is the number of bytes pending in the receive
+ // queue.
+ PendingBufUsed seqnum.Size
+
+ // PendingBufSize is the size of the socket receive buffer.
+ PendingBufSize seqnum.Size
+}
+
+// TCPSenderState holds a copy of the internal state of the sender for
+// a given TCP Endpoint.
+type TCPSenderState struct {
+ // LastSendTime is the time at which we sent the last segment.
+ LastSendTime time.Time
+
+ // DupAckCount is the number of Duplicate ACK's received.
+ DupAckCount int
+
+ // SndCwnd is the size of the sending congestion window in packets.
+ SndCwnd int
+
+ // Ssthresh is the slow start threshold in packets.
+ Ssthresh int
+
+ // SndCAAckCount is the number of packets consumed in congestion
+ // avoidance mode.
+ SndCAAckCount int
+
+ // Outstanding is the number of packets in flight.
+ Outstanding int
+
+ // SndWnd is the send window size in bytes.
+ SndWnd seqnum.Size
+
+ // SndUna is the next unacknowledged sequence number.
+ SndUna seqnum.Value
+
+ // SndNxt is the sequence number of the next segment to be sent.
+ SndNxt seqnum.Value
+
+ // RTTMeasureSeqNum is the sequence number being used for the latest RTT
+ // measurement.
+ RTTMeasureSeqNum seqnum.Value
+
+ // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
+ RTTMeasureTime time.Time
+
+ // Closed indicates that the caller has closed the endpoint for sending.
+ Closed bool
+
+ // SRTT is the smoothed round-trip time as defined in section 2 of
+ // RFC 6298.
+ SRTT time.Duration
+
+ // RTO is the retransmit timeout as defined in section of 2 of RFC 6298.
+ RTO time.Duration
+
+ // RTTVar is the round-trip time variation as defined in section 2 of
+ // RFC 6298.
+ RTTVar time.Duration
+
+ // SRTTInited if true indicates take a valid RTT measurement has been
+ // completed.
+ SRTTInited bool
+
+ // MaxPayloadSize is the maximum size of the payload of a given segment.
+ // It is initialized on demand.
+ MaxPayloadSize int
+
+ // SndWndScale is the number of bits to shift left when reading the send
+ // window size from a segment.
+ SndWndScale uint8
+
+ // MaxSentAck is the highest acknowledgement number sent till now.
+ MaxSentAck seqnum.Value
+
+ // FastRecovery holds the fast recovery state for the endpoint.
+ FastRecovery TCPFastRecoveryState
+
+ // Cubic holds the state related to CUBIC congestion control.
+ Cubic TCPCubicState
+}
+
+// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
+type TCPSACKInfo struct {
+ // Blocks is the list of SACK Blocks that identify the out of order segments
+ // held by a given TCP endpoint.
+ Blocks []header.SACKBlock
+
+ // ReceivedBlocks are the SACK blocks received by this endpoint
+ // from the peer endpoint.
+ ReceivedBlocks []header.SACKBlock
+
+ // MaxSACKED is the highest sequence number that has been SACKED
+ // by the peer.
+ MaxSACKED seqnum.Value
+}
+
+// TCPEndpointState is a copy of the internal state of a TCP endpoint.
+type TCPEndpointState struct {
+ // ID is a copy of the TransportEndpointID for the endpoint.
+ ID TCPEndpointID
+
+ // SegTime denotes the absolute time when this segment was received.
+ SegTime time.Time
+
+ // RcvBufSize is the size of the receive socket buffer for the endpoint.
+ RcvBufSize int
+
+ // RcvBufUsed is the amount of bytes actually held in the receive socket
+ // buffer for the endpoint.
+ RcvBufUsed int
+
+ // RcvClosed if true, indicates the endpoint has been closed for reading.
+ RcvClosed bool
+
+ // SendTSOk is used to indicate when the TS Option has been negotiated.
+ // When sendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1.
+ SendTSOk bool
+
+ // RecentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field is
+ // updated if required when a new segment is received by this endpoint.
+ RecentTS uint32
+
+ // TSOffset is a randomized offset added to the value of the TSVal field
+ // in the timestamp option.
+ TSOffset uint32
+
+ // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ SACKPermitted bool
+
+ // SACK holds TCP SACK related information for this endpoint.
+ SACK TCPSACKInfo
+
+ // SndBufSize is the size of the socket send buffer.
+ SndBufSize int
+
+ // SndBufUsed is the number of bytes held in the socket send buffer.
+ SndBufUsed int
+
+ // SndClosed indicates that the endpoint has been closed for sends.
+ SndClosed bool
+
+ // SndBufInQueue is the number of bytes in the send queue.
+ SndBufInQueue seqnum.Size
+
+ // PacketTooBigCount is used to notify the main protocol routine how
+ // many times a "packet too big" control packet is received.
+ PacketTooBigCount int
+
+ // SndMTU is the smallest MTU seen in the control packets received.
+ SndMTU int
+
+ // Receiver holds variables related to the TCP receiver for the endpoint.
+ Receiver TCPReceiverState
+
+ // Sender holds state related to the TCP Sender for the endpoint.
+ Sender TCPSenderState
+}
+
+// Stack is a networking stack, with all supported protocols, NICs, and route
+// table.
+type Stack struct {
+ transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState
+ networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
+ linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+
+ demux *transportDemuxer
+
+ stats tcpip.Stats
+
+ linkAddrCache *linkAddrCache
+
+ // raw indicates whether raw sockets may be created. It is set during
+ // Stack creation and is immutable.
+ raw bool
+
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+ forwarding bool
+
+ // route is the route table passed in by the user via SetRouteTable(),
+ // it is used by FindRoute() to build a route for a specific
+ // destination.
+ routeTable []tcpip.Route
+
+ *ports.PortManager
+
+ // If not nil, then any new endpoints will have this probe function
+ // invoked everytime they receive a TCP segment.
+ tcpProbeFunc TCPProbeFunc
+
+ // clock is used to generate user-visible times.
+ clock tcpip.Clock
+
+ // handleLocal allows non-loopback interfaces to loop packets.
+ handleLocal bool
+}
+
+// Options contains optional Stack configuration.
+type Options struct {
+ // Clock is an optional clock source used for timestampping packets.
+ //
+ // If no Clock is specified, the clock source will be time.Now.
+ Clock tcpip.Clock
+
+ // Stats are optional statistic counters.
+ Stats tcpip.Stats
+
+ // HandleLocal indicates whether packets destined to their source
+ // should be handled by the stack internally (true) or outside the
+ // stack (false).
+ HandleLocal bool
+
+ // Raw indicates whether raw sockets may be created.
+ Raw bool
+}
+
+// New allocates a new networking stack with only the requested networking and
+// transport protocols configured with default options.
+//
+// Protocol options can be changed by calling the
+// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
+// stack. Please refer to individual protocol implementations as to what options
+// are supported.
+func New(network []string, transport []string, opts Options) *Stack {
+ clock := opts.Clock
+ if clock == nil {
+ clock = &tcpip.StdClock{}
+ }
+
+ s := &Stack{
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ raw: opts.Raw,
+ }
+
+ // Add specified network protocols.
+ for _, name := range network {
+ netProtoFactory, ok := networkProtocols[name]
+ if !ok {
+ continue
+ }
+ netProto := netProtoFactory()
+ s.networkProtocols[netProto.Number()] = netProto
+ if r, ok := netProto.(LinkAddressResolver); ok {
+ s.linkAddrResolvers[r.LinkAddressProtocol()] = r
+ }
+ }
+
+ // Add specified transport protocols.
+ for _, name := range transport {
+ transProtoFactory, ok := transportProtocols[name]
+ if !ok {
+ continue
+ }
+ transProto := transProtoFactory()
+ s.transportProtocols[transProto.Number()] = &transportProtocolState{
+ proto: transProto,
+ }
+ }
+
+ // Create the global transport demuxer.
+ s.demux = newTransportDemuxer(s)
+
+ return s
+}
+
+// SetNetworkProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation or the provided value
+// is incorrect.
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+ netProto, ok := s.networkProtocols[network]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return netProto.SetOption(option)
+}
+
+// NetworkProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation.
+// e.g.
+// var v ipv4.MyOption
+// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
+// if err != nil {
+// ...
+// }
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+ netProto, ok := s.networkProtocols[network]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return netProto.Option(option)
+}
+
+// SetTransportProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation or the provided value
+// is incorrect.
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+ transProtoState, ok := s.transportProtocols[transport]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return transProtoState.proto.SetOption(option)
+}
+
+// TransportProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation.
+// var v tcp.SACKEnabled
+// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
+// ...
+// }
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+ transProtoState, ok := s.transportProtocols[transport]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return transProtoState.proto.Option(option)
+}
+
+// SetTransportProtocolHandler sets the per-stack default handler for the given
+// protocol.
+//
+// It must be called only during initialization of the stack. Changing it as the
+// stack is operating is not supported.
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) {
+ state := s.transportProtocols[p]
+ if state != nil {
+ state.defaultHandler = h
+ }
+}
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (s *Stack) NowNanoseconds() int64 {
+ return s.clock.NowNanoseconds()
+}
+
+// Stats returns a mutable copy of the current stats.
+//
+// This is not generally exported via the public interface, but is available
+// internally.
+func (s *Stack) Stats() tcpip.Stats {
+ return s.stats
+}
+
+// SetForwarding enables or disables the packet forwarding between NICs.
+func (s *Stack) SetForwarding(enable bool) {
+ // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
+ s.mu.Lock()
+ s.forwarding = enable
+ s.mu.Unlock()
+}
+
+// Forwarding returns if the packet forwarding between NICs is enabled.
+func (s *Stack) Forwarding() bool {
+ // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.forwarding
+}
+
+// SetRouteTable assigns the route table to be used by this stack. It
+// specifies which NIC to use for given destination address ranges.
+func (s *Stack) SetRouteTable(table []tcpip.Route) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.routeTable = table
+}
+
+// GetRouteTable returns the route table which is currently in use.
+func (s *Stack) GetRouteTable() []tcpip.Route {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return append([]tcpip.Route(nil), s.routeTable...)
+}
+
+// NewEndpoint creates a new transport layer endpoint of the given protocol.
+func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ t, ok := s.transportProtocols[transport]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ return t.proto.NewEndpoint(s, network, waiterQueue)
+}
+
+// NewRawEndpoint creates a new raw transport layer endpoint of the given
+// protocol. Raw endpoints receive all traffic for a given protocol regardless
+// of address.
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if !s.raw {
+ return nil, tcpip.ErrNotPermitted
+ }
+
+ t, ok := s.transportProtocols[transport]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ return t.proto.NewRawEndpoint(s, network, waiterQueue)
+}
+
+// createNIC creates a NIC with the provided id and link-layer endpoint, and
+// optionally enable it.
+func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error {
+ ep := FindLinkEndpoint(linkEP)
+ if ep == nil {
+ return tcpip.ErrBadLinkEndpoint
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Make sure id is unique.
+ if _, ok := s.nics[id]; ok {
+ return tcpip.ErrDuplicateNICID
+ }
+
+ n := newNIC(s, id, name, ep, loopback)
+
+ s.nics[id] = n
+ if enabled {
+ n.attachLinkEndpoint()
+ }
+
+ return nil
+}
+
+// CreateNIC creates a NIC with the provided id and link-layer endpoint.
+func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, "", linkEP, true, false)
+}
+
+// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
+// and a human-readable name.
+func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, true, false)
+}
+
+// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
+// endpoint, and a human-readable name.
+func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, true, true)
+}
+
+// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
+// but leave it disable. Stack.EnableNIC must be called before the link-layer
+// endpoint starts delivering packets to it.
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, "", linkEP, false, false)
+}
+
+// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
+// CreateDisabledNIC.
+func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, false, false)
+}
+
+// EnableNIC enables the given NIC so that the link-layer endpoint can start
+// delivering packets to it.
+func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[id]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.attachLinkEndpoint()
+
+ return nil
+}
+
+// CheckNIC checks if a NIC is usable.
+func (s *Stack) CheckNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ nic, ok := s.nics[id]
+ s.mu.RUnlock()
+ if ok {
+ return nic.linkEP.IsAttached()
+ }
+ return false
+}
+
+// NICSubnets returns a map of NICIDs to their associated subnets.
+func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := map[tcpip.NICID][]tcpip.Subnet{}
+
+ for id, nic := range s.nics {
+ nics[id] = append(nics[id], nic.Subnets()...)
+ }
+ return nics
+}
+
+// NICInfo captures the name and addresses assigned to a NIC.
+type NICInfo struct {
+ Name string
+ LinkAddress tcpip.LinkAddress
+ ProtocolAddresses []tcpip.ProtocolAddress
+
+ // Flags indicate the state of the NIC.
+ Flags NICStateFlags
+
+ // MTU is the maximum transmission unit.
+ MTU uint32
+
+ Stats NICStats
+}
+
+// NICInfo returns a map of NICIDs to their associated information.
+func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := make(map[tcpip.NICID]NICInfo)
+ for id, nic := range s.nics {
+ flags := NICStateFlags{
+ Up: true, // Netstack interfaces are always up.
+ Running: nic.linkEP.IsAttached(),
+ Promiscuous: nic.isPromiscuousMode(),
+ Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0,
+ }
+ nics[id] = NICInfo{
+ Name: nic.name,
+ LinkAddress: nic.linkEP.LinkAddress(),
+ ProtocolAddresses: nic.Addresses(),
+ Flags: flags,
+ MTU: nic.linkEP.MTU(),
+ Stats: nic.stats,
+ }
+ }
+ return nics
+}
+
+// NICStateFlags holds information about the state of an NIC.
+type NICStateFlags struct {
+ // Up indicates whether the interface is running.
+ Up bool
+
+ // Running indicates whether resources are allocated.
+ Running bool
+
+ // Promiscuous indicates whether the interface is in promiscuous mode.
+ Promiscuous bool
+
+ // Loopback indicates whether the interface is a loopback.
+ Loopback bool
+}
+
+// AddAddress adds a new network-layer address to the specified NIC.
+func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
+}
+
+// AddAddressWithOptions is the same as AddAddress, but allows you to specify
+// whether the new endpoint can be primary or not.
+func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[id]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.AddAddressWithOptions(protocol, addr, peb)
+}
+
+// AddSubnet adds a subnet range to the specified NIC.
+func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ nic.AddSubnet(protocol, subnet)
+ return nil
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// RemoveSubnet removes the subnet range from the specified NIC.
+func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ nic.RemoveSubnet(subnet)
+ return nil
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// ContainsSubnet reports whether the specified NIC contains the specified
+// subnet.
+func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ return nic.ContainsSubnet(subnet), nil
+ }
+
+ return false, tcpip.ErrUnknownNICID
+}
+
+// RemoveAddress removes an existing network-layer address from the specified
+// NIC.
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ return nic.RemoveAddress(addr)
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// GetMainNICAddress returns the first primary address (and the subnet that
+// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
+// address if no primary addresses exist. Returns an error if the NIC doesn't
+// exist or has no endpoints.
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ return nic.getMainNICAddress(protocol)
+ }
+
+ return "", tcpip.Subnet{}, tcpip.ErrUnknownNICID
+}
+
+func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
+ if len(localAddr) == 0 {
+ return nic.primaryEndpoint(netProto)
+ }
+ return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
+}
+
+// FindRoute creates a route to the given destination address, leaving through
+// the given nic and local address (if provided).
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ isBroadcast := remoteAddr == header.IPv4Broadcast
+ isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
+ needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ if id != 0 && !needRoute {
+ if nic, ok := s.nics[id]; ok {
+ if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
+ return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil
+ }
+ }
+ } else {
+ for _, route := range s.routeTable {
+ if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Match(remoteAddr)) {
+ continue
+ }
+ if nic, ok := s.nics[route.NIC]; ok {
+ if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
+ if len(remoteAddr) == 0 {
+ // If no remote address was provided, then the route
+ // provided will refer to the link local address.
+ remoteAddr = ref.ep.ID().LocalAddress
+ }
+
+ r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback)
+ if needRoute {
+ r.NextHop = route.Gateway
+ }
+ return r, nil
+ }
+ }
+ }
+ }
+
+ if !needRoute {
+ return Route{}, tcpip.ErrNetworkUnreachable
+ }
+
+ return Route{}, tcpip.ErrNoRoute
+}
+
+// CheckNetworkProtocol checks if a given network protocol is enabled in the
+// stack.
+func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool {
+ _, ok := s.networkProtocols[protocol]
+ return ok
+}
+
+// CheckLocalAddress determines if the given local address exists, and if it
+// does, returns the id of the NIC it's bound to. Returns 0 if the address
+// does not exist.
+func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ // If a NIC is specified, we try to find the address there only.
+ if nicid != 0 {
+ nic := s.nics[nicid]
+ if nic == nil {
+ return 0
+ }
+
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
+ if ref == nil {
+ return 0
+ }
+
+ ref.decRef()
+
+ return nic.id
+ }
+
+ // Go through all the NICs.
+ for _, nic := range s.nics {
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
+ if ref != nil {
+ ref.decRef()
+ return nic.id
+ }
+ }
+
+ return 0
+}
+
+// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
+func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setPromiscuousMode(enable)
+
+ return nil
+}
+
+// SetSpoofing enables or disables address spoofing in the given NIC, allowing
+// endpoints to bind to any address in the NIC.
+func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setSpoofing(enable)
+
+ return nil
+}
+
+// AddLinkAddress adds a link address to the stack link cache.
+func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ s.linkAddrCache.add(fullAddr, linkAddr)
+ // TODO: provide a way for a transport endpoint to receive a signal
+ // that AddLinkAddress for a particular address has been called.
+}
+
+// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
+func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+ s.mu.RLock()
+ nic := s.nics[nicid]
+ if nic == nil {
+ s.mu.RUnlock()
+ return "", nil, tcpip.ErrUnknownNICID
+ }
+ s.mu.RUnlock()
+
+ fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ linkRes := s.linkAddrResolvers[protocol]
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+}
+
+// RemoveWaker implements LinkAddressCache.RemoveWaker.
+func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic := s.nics[nicid]; nic == nil {
+ fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ s.linkAddrCache.removeWaker(fullAddr, waker)
+ }
+}
+
+// RegisterTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided id will be
+// delivered to the given endpoint; specifying a nic is optional, but
+// nic-specific IDs have precedence over global ones.
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if nicID == 0 {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+}
+
+// UnregisterTransportEndpoint removes the endpoint with the given id from the
+// stack transport dispatcher.
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+ if nicID == 0 {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
+ return
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic != nil {
+ nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
+ }
+}
+
+// RegisterRawTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided transport
+// protocol will be delivered to the given endpoint.
+func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+ if nicID == 0 {
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.demux.registerRawEndpoint(netProto, transProto, ep)
+}
+
+// UnregisterRawTransportEndpoint removes the endpoint for the transport
+// protocol from the stack transport dispatcher.
+func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+ if nicID == 0 {
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
+ return
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic != nil {
+ nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
+ }
+}
+
+// NetworkProtocolInstance returns the protocol instance in the stack for the
+// specified network protocol. This method is public for protocol implementers
+// and tests to use.
+func (s *Stack) NetworkProtocolInstance(num tcpip.NetworkProtocolNumber) NetworkProtocol {
+ if p, ok := s.networkProtocols[num]; ok {
+ return p
+ }
+ return nil
+}
+
+// TransportProtocolInstance returns the protocol instance in the stack for the
+// specified transport protocol. This method is public for protocol implementers
+// and tests to use.
+func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) TransportProtocol {
+ if pState, ok := s.transportProtocols[num]; ok {
+ return pState.proto
+ }
+ return nil
+}
+
+// AddTCPProbe installs a probe function that will be invoked on every segment
+// received by a given TCP endpoint. The probe function is passed a copy of the
+// TCP endpoint state before and after processing of the segment.
+//
+// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints
+// created prior to this call will not call the probe function.
+//
+// Further, installing two different probes back to back can result in some
+// endpoints calling the first one and some the second one. There is no
+// guarantee provided on which probe will be invoked. Ideally this should only
+// be called once per stack.
+func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
+ s.mu.Lock()
+ s.tcpProbeFunc = probe
+ s.mu.Unlock()
+}
+
+// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
+// otherwise.
+func (s *Stack) GetTCPProbe() TCPProbeFunc {
+ s.mu.Lock()
+ p := s.tcpProbeFunc
+ s.mu.Unlock()
+ return p
+}
+
+// RemoveTCPProbe removes an installed TCP probe.
+//
+// NOTE: This only ensures that endpoints created after this call do not
+// have a probe attached. Endpoints already created will continue to invoke
+// TCP probe.
+func (s *Stack) RemoveTCPProbe() {
+ s.mu.Lock()
+ s.tcpProbeFunc = nil
+ s.mu.Unlock()
+}
+
+// JoinGroup joins the given multicast group on the given NIC.
+func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ // TODO: notify network of subscription via igmp protocol.
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.joinGroup(protocol, multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
+}
+
+// LeaveGroup leaves the given multicast group on the given NIC.
+func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.leaveGroup(multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
+}
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
new file mode 100644
index 000000000..dfec4258a
--- /dev/null
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -0,0 +1,19 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+// StackFromEnv is the global stack created in restore run.
+// FIXME(b/36201077)
+var StackFromEnv *Stack
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
new file mode 100755
index 000000000..bb05ff7c1
--- /dev/null
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -0,0 +1,59 @@
+// automatically generated by stateify.
+
+package stack
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *TransportEndpointID) beforeSave() {}
+func (x *TransportEndpointID) save(m state.Map) {
+ x.beforeSave()
+ m.Save("LocalPort", &x.LocalPort)
+ m.Save("LocalAddress", &x.LocalAddress)
+ m.Save("RemotePort", &x.RemotePort)
+ m.Save("RemoteAddress", &x.RemoteAddress)
+}
+
+func (x *TransportEndpointID) afterLoad() {}
+func (x *TransportEndpointID) load(m state.Map) {
+ m.Load("LocalPort", &x.LocalPort)
+ m.Load("LocalAddress", &x.LocalAddress)
+ m.Load("RemotePort", &x.RemotePort)
+ m.Load("RemoteAddress", &x.RemoteAddress)
+}
+
+func (x *GSOType) save(m state.Map) {
+ m.SaveValue("", (int)(*x))
+}
+
+func (x *GSOType) load(m state.Map) {
+ m.LoadValue("", new(int), func(y interface{}) { *x = (GSOType)(y.(int)) })
+}
+
+func (x *GSO) beforeSave() {}
+func (x *GSO) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Type", &x.Type)
+ m.Save("NeedsCsum", &x.NeedsCsum)
+ m.Save("CsumOffset", &x.CsumOffset)
+ m.Save("MSS", &x.MSS)
+ m.Save("L3HdrLen", &x.L3HdrLen)
+ m.Save("MaxSize", &x.MaxSize)
+}
+
+func (x *GSO) afterLoad() {}
+func (x *GSO) load(m state.Map) {
+ m.Load("Type", &x.Type)
+ m.Load("NeedsCsum", &x.NeedsCsum)
+ m.Load("CsumOffset", &x.CsumOffset)
+ m.Load("MSS", &x.MSS)
+ m.Load("L3HdrLen", &x.L3HdrLen)
+ m.Load("MaxSize", &x.MaxSize)
+}
+
+func init() {
+ state.Register("stack.TransportEndpointID", (*TransportEndpointID)(nil), state.Fns{Save: (*TransportEndpointID).save, Load: (*TransportEndpointID).load})
+ state.Register("stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load})
+ state.Register("stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load})
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
new file mode 100644
index 000000000..605bfadeb
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -0,0 +1,420 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "math/rand"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+type protocolIDs struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+}
+
+// transportEndpoints manages all endpoints of a given protocol. It has its own
+// mutex so as to reduce interference between protocols.
+type transportEndpoints struct {
+ // mu protects all fields of the transportEndpoints.
+ mu sync.RWMutex
+ endpoints map[TransportEndpointID]TransportEndpoint
+ // rawEndpoints contains endpoints for raw sockets, which receive all
+ // traffic of a given protocol regardless of port.
+ rawEndpoints []RawTransportEndpoint
+}
+
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ e, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if multiPortEp, ok := e.(*multiPortEndpoint); ok {
+ if !multiPortEp.unregisterEndpoint(ep) {
+ return
+ }
+ }
+ delete(eps.endpoints, id)
+}
+
+// transportDemuxer demultiplexes packets targeted at a transport endpoint
+// (i.e., after they've been parsed by the network layer). It does two levels
+// of demultiplexing: first based on the network and transport protocols, then
+// based on endpoints IDs. It should only be instantiated via
+// newTransportDemuxer.
+type transportDemuxer struct {
+ // protocol is immutable.
+ protocol map[protocolIDs]*transportEndpoints
+}
+
+func newTransportDemuxer(stack *Stack) *transportDemuxer {
+ d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
+
+ // Add each network and transport pair to the demuxer.
+ for netProto := range stack.networkProtocols {
+ for proto := range stack.transportProtocols {
+ d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ }
+ }
+ }
+
+ return d
+}
+
+// registerEndpoint registers the given endpoint with the dispatcher such that
+// packets that match the endpoint ID are delivered to it.
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ for i, n := range netProtos {
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ return err
+ }
+ }
+
+ return nil
+}
+
+// multiPortEndpoint is a container for TransportEndpoints which are bound to
+// the same pair of address and port.
+type multiPortEndpoint struct {
+ mu sync.RWMutex
+ endpointsArr []TransportEndpoint
+ endpointsMap map[TransportEndpoint]int
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// reciprocalScale scales a value into range [0, n).
+//
+// This is similar to val % n, but faster.
+// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
+func reciprocalScale(val, n uint32) uint32 {
+ return uint32((uint64(val) * uint64(n)) >> 32)
+}
+
+// selectEndpoint calculates a hash of destination and source addresses and
+// ports then uses it to select a socket. In this case, all packets from one
+// address will be sent to same endpoint.
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8),
+ }
+
+ h := jenkins.Sum32(ep.seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+ hash := h.Sum32()
+
+ idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
+ return ep.endpointsArr[idx]
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints managed by ep.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket modifies vv, so each endpoint needs its own copy.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, vv)
+ break
+ }
+ vvCopy := buffer.NewView(vv.Size())
+ copy(vvCopy, vv.ToView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+ }
+ } else {
+ ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
+}
+
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ // A new endpoint is added into endpointsArr and its index there is
+ // saved in endpointsMap. This will allows to remove endpoint from
+ // the array fast.
+ ep.endpointsMap[t] = len(ep.endpointsArr)
+ ep.endpointsArr = append(ep.endpointsArr, t)
+}
+
+// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ idx, ok := ep.endpointsMap[t]
+ if !ok {
+ return false
+ }
+ delete(ep.endpointsMap, t)
+ l := len(ep.endpointsArr)
+ if l > 1 {
+ // The last endpoint in endpointsArr is moved instead of the deleted one.
+ lastEp := ep.endpointsArr[l-1]
+ ep.endpointsArr[idx] = lastEp
+ ep.endpointsMap[lastEp] = idx
+ ep.endpointsArr = ep.endpointsArr[0 : l-1]
+ return false
+ }
+ return true
+}
+
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if id.RemotePort != 0 {
+ reusePort = false
+ }
+
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+
+ var multiPortEp *multiPortEndpoint
+ if _, ok := eps.endpoints[id]; ok {
+ if !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
+ if !ok {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if reusePort {
+ if multiPortEp == nil {
+ multiPortEp = &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.seed = rand.Uint32()
+ eps.endpoints[id] = multiPortEp
+ }
+
+ multiPortEp.singleRegisterEndpoint(ep)
+
+ return nil
+ }
+ eps.endpoints[id] = ep
+
+ return nil
+}
+
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+ for _, n := range netProtos {
+ if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
+ eps.unregisterEndpoint(id, ep)
+ }
+ }
+}
+
+var loopbackSubnet = func() tcpip.Subnet {
+ sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
+ if err != nil {
+ panic(err)
+ }
+ return sn
+}()
+
+// deliverPacket attempts to find one or more matching transport endpoints, and
+// then, if matches are found, delivers the packet to them. Returns true if it
+// found one or more endpoints, false otherwise.
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+ if !ok {
+ return false
+ }
+
+ // If a sender bound to the Loopback interface sends a broadcast,
+ // that broadcast must not be delivered to the sender.
+ if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort {
+ return false
+ }
+
+ // If the packet is a broadcast, then find all matching transport endpoints.
+ // Otherwise, try to find a single matching transport endpoint.
+ destEps := make([]TransportEndpoint, 0, 1)
+ eps.mu.RLock()
+
+ if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
+ for epID, endpoint := range eps.endpoints {
+ if epID.LocalPort == id.LocalPort {
+ destEps = append(destEps, endpoint)
+ }
+ }
+ } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
+ destEps = append(destEps, ep)
+ }
+
+ eps.mu.RUnlock()
+
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEps) == 0 {
+ // UDP packet could not be delivered to an unknown destination port.
+ if protocol == header.UDPProtocolNumber {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ }
+ return false
+ }
+
+ // Deliver the packet.
+ for _, ep := range destEps {
+ ep.HandlePacket(r, id, vv)
+ }
+
+ return true
+}
+
+// deliverRawPacket attempts to deliver the given packet and returns whether it
+// was delivered successfully.
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+ if !ok {
+ return false
+ }
+
+ // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
+ // raw endpoint first. If there are multiple raw endpoints, they all
+ // receive the packet.
+ foundRaw := false
+ eps.mu.RLock()
+ for _, rawEP := range eps.rawEndpoints {
+ // Each endpoint gets its own copy of the packet for the sake
+ // of save/restore.
+ rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ foundRaw = true
+ }
+ eps.mu.RUnlock()
+
+ return foundRaw
+}
+
+// deliverControlPacket attempts to deliver the given control packet. Returns
+// true if it found an endpoint, false otherwise.
+func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{net, trans}]
+ if !ok {
+ return false
+ }
+
+ // Try to find the endpoint.
+ eps.mu.RLock()
+ ep := d.findEndpointLocked(eps, vv, id)
+ eps.mu.RUnlock()
+
+ // Fail if we didn't find one.
+ if ep == nil {
+ return false
+ }
+
+ // Deliver the packet.
+ ep.HandleControlPacket(id, typ, extra, vv)
+
+ return true
+}
+
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+ // Try to find a match with the id as provided.
+ if ep, ok := eps.endpoints[id]; ok {
+ return ep
+ }
+
+ // Try to find a match with the id minus the local address.
+ nid := id
+
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ return ep
+ }
+
+ // Try to find a match with the id minus the remote part.
+ nid.LocalAddress = id.LocalAddress
+ nid.RemoteAddress = ""
+ nid.RemotePort = 0
+ if ep, ok := eps.endpoints[nid]; ok {
+ return ep
+ }
+
+ // Try to find a match with only the local port.
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ return ep
+ }
+
+ return nil
+}
+
+// registerRawEndpoint registers the given endpoint with the dispatcher such
+// that packets of the appropriate protocol are delivered to it. A single
+// packet can be sent to one or more raw endpoints along with a non-raw
+// endpoint.
+func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ eps.rawEndpoints = append(eps.rawEndpoints, ep)
+
+ return nil
+}
+
+// unregisterRawEndpoint unregisters the raw endpoint for the given transport
+// protocol such that it won't receive any more packets.
+func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ for i, rawEP := range eps.rawEndpoints {
+ if rawEP == ep {
+ eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
+ return
+ }
+ }
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
new file mode 100644
index 000000000..f9886c6e4
--- /dev/null
+++ b/pkg/tcpip/tcpip.go
@@ -0,0 +1,1055 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcpip provides the interfaces and related types that users of the
+// tcpip stack will use in order to create endpoints used to send and receive
+// data over the network stack.
+//
+// The starting point is the creation and configuration of a stack. A stack can
+// be created by calling the New() function of the tcpip/stack/stack package;
+// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
+// adding network addresses (via calls to Stack.AddAddress()), and
+// setting a route table (via a call to Stack.SetRouteTable()).
+//
+// Once a stack is configured, endpoints can be created by calling
+// Stack.NewEndpoint(). Such endpoints can be used to send/receive data, connect
+// to peers, listen for connections, accept connections, etc., depending on the
+// transport protocol selected.
+package tcpip
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Error represents an error in the netstack error space. Using a special type
+// ensures that errors outside of this space are not accidentally introduced.
+//
+// Note: to support save / restore, it is important that all tcpip errors have
+// distinct error messages.
+type Error struct {
+ msg string
+
+ ignoreStats bool
+}
+
+// String implements fmt.Stringer.String.
+func (e *Error) String() string {
+ return e.msg
+}
+
+// IgnoreStats indicates whether this error type should be included in failure
+// counts in tcpip.Stats structs.
+func (e *Error) IgnoreStats() bool {
+ return e.ignoreStats
+}
+
+// Errors that can be returned by the network stack.
+var (
+ ErrUnknownProtocol = &Error{msg: "unknown protocol"}
+ ErrUnknownNICID = &Error{msg: "unknown nic id"}
+ ErrUnknownDevice = &Error{msg: "unknown device"}
+ ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
+ ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
+ ErrDuplicateAddress = &Error{msg: "duplicate address"}
+ ErrNoRoute = &Error{msg: "no route"}
+ ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
+ ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
+ ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
+ ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
+ ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
+ ErrNoPortAvailable = &Error{msg: "no ports are available"}
+ ErrPortInUse = &Error{msg: "port is in use"}
+ ErrBadLocalAddress = &Error{msg: "bad local address"}
+ ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
+ ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
+ ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
+ ErrConnectionRefused = &Error{msg: "connection was refused"}
+ ErrTimeout = &Error{msg: "operation timed out"}
+ ErrAborted = &Error{msg: "operation aborted"}
+ ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
+ ErrDestinationRequired = &Error{msg: "destination address is required"}
+ ErrNotSupported = &Error{msg: "operation not supported"}
+ ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
+ ErrNotConnected = &Error{msg: "endpoint not connected"}
+ ErrConnectionReset = &Error{msg: "connection reset by peer"}
+ ErrConnectionAborted = &Error{msg: "connection aborted"}
+ ErrNoSuchFile = &Error{msg: "no such file"}
+ ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
+ ErrNoLinkAddress = &Error{msg: "no remote link address"}
+ ErrBadAddress = &Error{msg: "bad address"}
+ ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
+ ErrMessageTooLong = &Error{msg: "message too long"}
+ ErrNoBufferSpace = &Error{msg: "no buffer space available"}
+ ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
+ ErrNotPermitted = &Error{msg: "operation not permitted"}
+)
+
+// Errors related to Subnet
+var (
+ errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
+ errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask")
+)
+
+// ErrSaveRejection indicates a failed save due to unsupported networking state.
+// This type of errors is only used for save logic.
+type ErrSaveRejection struct {
+ Err error
+}
+
+// Error returns a sensible description of the save rejection error.
+func (e ErrSaveRejection) Error() string {
+ return "save rejected due to unsupported networking state: " + e.Err.Error()
+}
+
+// A Clock provides the current time.
+//
+// Times returned by a Clock should always be used for application-visible
+// time. Only monotonic times should be used for netstack internal timekeeping.
+type Clock interface {
+ // NowNanoseconds returns the current real time as a number of
+ // nanoseconds since the Unix epoch.
+ NowNanoseconds() int64
+
+ // NowMonotonic returns a monotonic time value.
+ NowMonotonic() int64
+}
+
+// Address is a byte slice cast as a string that represents the address of a
+// network node. Or, in the case of unix endpoints, it may represent a path.
+type Address string
+
+// AddressMask is a bitmask for an address.
+type AddressMask string
+
+// String implements Stringer.
+func (a AddressMask) String() string {
+ return Address(a).String()
+}
+
+// Subnet is a subnet defined by its address and mask.
+type Subnet struct {
+ address Address
+ mask AddressMask
+}
+
+// NewSubnet creates a new Subnet, checking that the address and mask are the same length.
+func NewSubnet(a Address, m AddressMask) (Subnet, error) {
+ if len(a) != len(m) {
+ return Subnet{}, errSubnetLengthMismatch
+ }
+ for i := 0; i < len(a); i++ {
+ if a[i]&^m[i] != 0 {
+ return Subnet{}, errSubnetAddressMasked
+ }
+ }
+ return Subnet{a, m}, nil
+}
+
+// Contains returns true iff the address is of the same length and matches the
+// subnet address and mask.
+func (s *Subnet) Contains(a Address) bool {
+ if len(a) != len(s.address) {
+ return false
+ }
+ for i := 0; i < len(a); i++ {
+ if a[i]&s.mask[i] != s.address[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// ID returns the subnet ID.
+func (s *Subnet) ID() Address {
+ return s.address
+}
+
+// Bits returns the number of ones (network bits) and zeros (host bits) in the
+// subnet mask.
+func (s *Subnet) Bits() (ones int, zeros int) {
+ for _, b := range []byte(s.mask) {
+ for i := uint(0); i < 8; i++ {
+ if b&(1<<i) == 0 {
+ zeros++
+ } else {
+ ones++
+ }
+ }
+ }
+ return
+}
+
+// Prefix returns the number of bits before the first host bit.
+func (s *Subnet) Prefix() int {
+ for i, b := range []byte(s.mask) {
+ for j := 7; j >= 0; j-- {
+ if b&(1<<uint(j)) == 0 {
+ return i*8 + 7 - j
+ }
+ }
+ }
+ return len(s.mask) * 8
+}
+
+// Mask returns the subnet mask.
+func (s *Subnet) Mask() AddressMask {
+ return s.mask
+}
+
+// NICID is a number that uniquely identifies a NIC.
+type NICID int32
+
+// ShutdownFlags represents flags that can be passed to the Shutdown() method
+// of the Endpoint interface.
+type ShutdownFlags int
+
+// Values of the flags that can be passed to the Shutdown() method. They can
+// be OR'ed together.
+const (
+ ShutdownRead ShutdownFlags = 1 << iota
+ ShutdownWrite
+)
+
+// FullAddress represents a full transport node address, as required by the
+// Connect() and Bind() methods.
+//
+// +stateify savable
+type FullAddress struct {
+ // NIC is the ID of the NIC this address refers to.
+ //
+ // This may not be used by all endpoint types.
+ NIC NICID
+
+ // Addr is the network address.
+ Addr Address
+
+ // Port is the transport port.
+ //
+ // This may not be used by all endpoint types.
+ Port uint16
+}
+
+// Payload provides an interface around data that is being sent to an endpoint.
+// This allows the endpoint to request the amount of data it needs based on
+// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data.
+type Payload interface {
+ // Get returns a slice containing exactly 'min(size, p.Size())' bytes.
+ Get(size int) ([]byte, *Error)
+
+ // Size returns the payload size.
+ Size() int
+}
+
+// SlicePayload implements Payload on top of slices for convenience.
+type SlicePayload []byte
+
+// Get implements Payload.
+func (s SlicePayload) Get(size int) ([]byte, *Error) {
+ if size > s.Size() {
+ size = s.Size()
+ }
+ return s[:size], nil
+}
+
+// Size implements Payload.
+func (s SlicePayload) Size() int {
+ return len(s)
+}
+
+// A ControlMessages contains socket control messages for IP sockets.
+//
+// +stateify savable
+type ControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packed used to create
+ // the read data was received.
+ Timestamp int64
+}
+
+// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
+// that exposes functionality like read, write, connect, etc. to users of the
+// networking stack.
+type Endpoint interface {
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it.
+ Close()
+
+ // Read reads data from the endpoint and optionally returns the sender.
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ Read(*FullAddress) (buffer.View, ControlMessages, *Error)
+
+ // Write writes data to the endpoint's peer. This method does not block if
+ // the data cannot be written.
+ //
+ // Unlike io.Writer.Write, Endpoint.Write transfers ownership of any bytes
+ // successfully written to the Endpoint. That is, if a call to
+ // Write(SlicePayload{data}) returns (n, err), it may retain data[:n], and
+ // the caller should not use data[:n] after Write returns.
+ //
+ // Note that unlike io.Writer.Write, it is not an error for Write to
+ // perform a partial write (if n > 0, no error may be returned). Only
+ // stream (TCP) Endpoints may return partial writes, and even then only
+ // in the case where writing additional data would block. Other Endpoints
+ // will either write the entire message or return an error.
+ //
+ // For UDP and Ping sockets if address resolution is required,
+ // ErrNoLinkAddress and a notification channel is returned for the caller to
+ // block. Channel is closed once address resolution is complete (success or
+ // not). The channel is only non-nil in this case.
+ Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error)
+
+ // Peek reads data without consuming it from the endpoint.
+ //
+ // This method does not block if there is no data pending.
+ Peek([][]byte) (uintptr, ControlMessages, *Error)
+
+ // Connect connects the endpoint to its peer. Specifying a NIC is
+ // optional.
+ //
+ // There are three classes of return values:
+ // nil -- the attempt to connect succeeded.
+ // ErrConnectStarted/ErrAlreadyConnecting -- the connect attempt started
+ // but hasn't completed yet. In this case, the caller must call Connect
+ // or GetSockOpt(ErrorOption) when the endpoint becomes writable to
+ // get the actual result. The first call to Connect after the socket has
+ // connected returns nil. Calling connect again results in ErrAlreadyConnected.
+ // Anything else -- the attempt to connect failed.
+ Connect(address FullAddress) *Error
+
+ // Shutdown closes the read and/or write end of the endpoint connection
+ // to its peer.
+ Shutdown(flags ShutdownFlags) *Error
+
+ // Listen puts the endpoint in "listen" mode, which allows it to accept
+ // new connections.
+ Listen(backlog int) *Error
+
+ // Accept returns a new endpoint if a peer has established a connection
+ // to an endpoint previously set to listen mode. This method does not
+ // block if no new connections are available.
+ //
+ // The returned Queue is the wait queue for the newly created endpoint.
+ Accept() (Endpoint, *waiter.Queue, *Error)
+
+ // Bind binds the endpoint to a specific local address and port.
+ // Specifying a NIC is optional.
+ Bind(address FullAddress) *Error
+
+ // GetLocalAddress returns the address to which the endpoint is bound.
+ GetLocalAddress() (FullAddress, *Error)
+
+ // GetRemoteAddress returns the address to which the endpoint is
+ // connected.
+ GetRemoteAddress() (FullAddress, *Error)
+
+ // Readiness returns the current readiness of the endpoint. For example,
+ // if waiter.EventIn is set, the endpoint is immediately readable.
+ Readiness(mask waiter.EventMask) waiter.EventMask
+
+ // SetSockOpt sets a socket option. opt should be one of the *Option types.
+ SetSockOpt(opt interface{}) *Error
+
+ // GetSockOpt gets a socket option. opt should be a pointer to one of the
+ // *Option types.
+ GetSockOpt(opt interface{}) *Error
+}
+
+// WriteOptions contains options for Endpoint.Write.
+type WriteOptions struct {
+ // If To is not nil, write to the given address instead of the endpoint's
+ // peer.
+ To *FullAddress
+
+ // More has the same semantics as Linux's MSG_MORE.
+ More bool
+
+ // EndOfRecord has the same semantics as Linux's MSG_EOR.
+ EndOfRecord bool
+}
+
+// ErrorOption is used in GetSockOpt to specify that the last error reported by
+// the endpoint should be cleared and returned.
+type ErrorOption struct{}
+
+// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
+// buffer size option.
+type SendBufferSizeOption int
+
+// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
+// receive buffer size option.
+type ReceiveBufferSizeOption int
+
+// SendQueueSizeOption is used in GetSockOpt to specify that the number of
+// unread bytes in the output buffer should be returned.
+type SendQueueSizeOption int
+
+// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
+// unread bytes in the input buffer should be returned.
+type ReceiveQueueSizeOption int
+
+// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
+// socket is to be restricted to sending and receiving IPv6 packets only.
+type V6OnlyOption int
+
+// DelayOption is used by SetSockOpt/GetSockOpt to specify if data should be
+// sent out immediately by the transport protocol. For TCP, it determines if the
+// Nagle algorithm is on or off.
+type DelayOption int
+
+// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
+// held until segments are full by the TCP transport protocol.
+type CorkOption int
+
+// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
+// should allow reuse of local address.
+type ReuseAddressOption int
+
+// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
+// to be bound to an identical socket address.
+type ReusePortOption int
+
+// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
+type QuickAckOption int
+
+// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
+// SCM_CREDENTIALS socket control messages are enabled.
+//
+// Only supported on Unix sockets.
+type PasscredOption int
+
+// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
+//
+// TODO(b/64800844): Add and populate stat fields.
+type TCPInfoOption struct {
+ RTT time.Duration
+ RTTVar time.Duration
+}
+
+// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
+// TCP keepalive is enabled for this socket.
+type KeepaliveEnabledOption int
+
+// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a
+// connection must remain idle before the first TCP keepalive packet is sent.
+// Once this time is reached, KeepaliveIntervalOption is used instead.
+type KeepaliveIdleOption time.Duration
+
+// KeepaliveIntervalOption is used by SetSockOpt/GetSockOpt to specify the
+// interval between sending TCP keepalive packets.
+type KeepaliveIntervalOption time.Duration
+
+// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
+// of un-ACKed TCP keepalives that will be sent before the connection is
+// closed.
+type KeepaliveCountOption int
+
+// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
+// TTL value for multicast messages. The default is 1.
+type MulticastTTLOption uint8
+
+// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
+// default interface for multicast.
+type MulticastInterfaceOption struct {
+ NIC NICID
+ InterfaceAddr Address
+}
+
+// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
+// multicast packets sent over a non-loopback interface will be looped back.
+type MulticastLoopOption bool
+
+// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
+// AddMembershipOption and RemoveMembershipOption.
+type MembershipOption struct {
+ NIC NICID
+ InterfaceAddr Address
+ MulticastAddr Address
+}
+
+// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type AddMembershipOption MembershipOption
+
+// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type RemoveMembershipOption MembershipOption
+
+// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether
+// TCP out-of-band data is delivered along with the normal in-band data.
+type OutOfBandInlineOption int
+
+// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
+// datagram sockets are allowed to send packets to a broadcast address.
+type BroadcastOption int
+
+// Route is a row in the routing table. It specifies through which NIC (and
+// gateway) sets of packets should be routed. A row is considered viable if the
+// masked target address matches the destination adddress in the row.
+type Route struct {
+ // Destination is the address that must be matched against the masked
+ // target address to check if this row is viable.
+ Destination Address
+
+ // Mask specifies which bits of the Destination and the target address
+ // must match for this row to be viable.
+ Mask AddressMask
+
+ // Gateway is the gateway to be used if this row is viable.
+ Gateway Address
+
+ // NIC is the id of the nic to be used if this row is viable.
+ NIC NICID
+}
+
+// Match determines if r is viable for the given destination address.
+func (r *Route) Match(addr Address) bool {
+ if len(addr) != len(r.Destination) {
+ return false
+ }
+
+ // Using header.Ipv4Broadcast would introduce an import cycle, so
+ // we'll use a literal instead.
+ if addr == "\xff\xff\xff\xff" {
+ return true
+ }
+
+ for i := 0; i < len(r.Destination); i++ {
+ if (addr[i] & r.Mask[i]) != r.Destination[i] {
+ return false
+ }
+ }
+
+ return true
+}
+
+// LinkEndpointID represents a data link layer endpoint.
+type LinkEndpointID uint64
+
+// TransportProtocolNumber is the number of a transport protocol.
+type TransportProtocolNumber uint32
+
+// NetworkProtocolNumber is the number of a network protocol.
+type NetworkProtocolNumber uint32
+
+// A StatCounter keeps track of a statistic.
+type StatCounter struct {
+ count uint64
+}
+
+// Increment adds one to the counter.
+func (s *StatCounter) Increment() {
+ s.IncrementBy(1)
+}
+
+// Value returns the current value of the counter.
+func (s *StatCounter) Value() uint64 {
+ return atomic.LoadUint64(&s.count)
+}
+
+// IncrementBy increments the counter by v.
+func (s *StatCounter) IncrementBy(v uint64) {
+ atomic.AddUint64(&s.count, v)
+}
+
+func (s *StatCounter) String() string {
+ return strconv.FormatUint(s.Value(), 10)
+}
+
+// ICMPv4PacketStats enumerates counts for all ICMPv4 packet types.
+type ICMPv4PacketStats struct {
+ // Echo is the total number of ICMPv4 echo packets counted.
+ Echo *StatCounter
+
+ // EchoReply is the total number of ICMPv4 echo reply packets counted.
+ EchoReply *StatCounter
+
+ // DstUnreachable is the total number of ICMPv4 destination unreachable
+ // packets counted.
+ DstUnreachable *StatCounter
+
+ // SrcQuench is the total number of ICMPv4 source quench packets
+ // counted.
+ SrcQuench *StatCounter
+
+ // Redirect is the total number of ICMPv4 redirect packets counted.
+ Redirect *StatCounter
+
+ // TimeExceeded is the total number of ICMPv4 time exceeded packets
+ // counted.
+ TimeExceeded *StatCounter
+
+ // ParamProblem is the total number of ICMPv4 parameter problem packets
+ // counted.
+ ParamProblem *StatCounter
+
+ // Timestamp is the total number of ICMPv4 timestamp packets counted.
+ Timestamp *StatCounter
+
+ // TimestampReply is the total number of ICMPv4 timestamp reply packets
+ // counted.
+ TimestampReply *StatCounter
+
+ // InfoRequest is the total number of ICMPv4 information request
+ // packets counted.
+ InfoRequest *StatCounter
+
+ // InfoReply is the total number of ICMPv4 information reply packets
+ // counted.
+ InfoReply *StatCounter
+}
+
+// ICMPv6PacketStats enumerates counts for all ICMPv6 packet types.
+type ICMPv6PacketStats struct {
+ // EchoRequest is the total number of ICMPv6 echo request packets
+ // counted.
+ EchoRequest *StatCounter
+
+ // EchoReply is the total number of ICMPv6 echo reply packets counted.
+ EchoReply *StatCounter
+
+ // DstUnreachable is the total number of ICMPv6 destination unreachable
+ // packets counted.
+ DstUnreachable *StatCounter
+
+ // PacketTooBig is the total number of ICMPv6 packet too big packets
+ // counted.
+ PacketTooBig *StatCounter
+
+ // TimeExceeded is the total number of ICMPv6 time exceeded packets
+ // counted.
+ TimeExceeded *StatCounter
+
+ // ParamProblem is the total number of ICMPv6 parameter problem packets
+ // counted.
+ ParamProblem *StatCounter
+
+ // RouterSolicit is the total number of ICMPv6 router solicit packets
+ // counted.
+ RouterSolicit *StatCounter
+
+ // RouterAdvert is the total number of ICMPv6 router advert packets
+ // counted.
+ RouterAdvert *StatCounter
+
+ // NeighborSolicit is the total number of ICMPv6 neighbor solicit
+ // packets counted.
+ NeighborSolicit *StatCounter
+
+ // NeighborAdvert is the total number of ICMPv6 neighbor advert packets
+ // counted.
+ NeighborAdvert *StatCounter
+
+ // RedirectMsg is the total number of ICMPv6 redirect message packets
+ // counted.
+ RedirectMsg *StatCounter
+}
+
+// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
+type ICMPv4SentPacketStats struct {
+ ICMPv4PacketStats
+
+ // Dropped is the total number of ICMPv4 packets dropped due to link
+ // layer errors.
+ Dropped *StatCounter
+}
+
+// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
+type ICMPv4ReceivedPacketStats struct {
+ ICMPv4PacketStats
+
+ // Invalid is the total number of ICMPv4 packets received that the
+ // transport layer could not parse.
+ Invalid *StatCounter
+}
+
+// ICMPv6SentPacketStats collects outbound ICMPv6-specific stats.
+type ICMPv6SentPacketStats struct {
+ ICMPv6PacketStats
+
+ // Dropped is the total number of ICMPv6 packets dropped due to link
+ // layer errors.
+ Dropped *StatCounter
+}
+
+// ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats.
+type ICMPv6ReceivedPacketStats struct {
+ ICMPv6PacketStats
+
+ // Invalid is the total number of ICMPv6 packets received that the
+ // transport layer could not parse.
+ Invalid *StatCounter
+}
+
+// ICMPStats collects ICMP-specific stats (both v4 and v6).
+type ICMPStats struct {
+ // ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type
+ // and a single count of packets which failed to write to the link
+ // layer.
+ V4PacketsSent ICMPv4SentPacketStats
+
+ // ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4
+ // packet type and a single count of invalid packets received.
+ V4PacketsReceived ICMPv4ReceivedPacketStats
+
+ // ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type
+ // and a single count of packets which failed to write to the link
+ // layer.
+ V6PacketsSent ICMPv6SentPacketStats
+
+ // ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6
+ // packet type and a single count of invalid packets received.
+ V6PacketsReceived ICMPv6ReceivedPacketStats
+}
+
+// IPStats collects IP-specific stats (both v4 and v6).
+type IPStats struct {
+ // PacketsReceived is the total number of IP packets received from the
+ // link layer in nic.DeliverNetworkPacket.
+ PacketsReceived *StatCounter
+
+ // InvalidAddressesReceived is the total number of IP packets received
+ // with an unknown or invalid destination address.
+ InvalidAddressesReceived *StatCounter
+
+ // PacketsDelivered is the total number of incoming IP packets that
+ // are successfully delivered to the transport layer via HandlePacket.
+ PacketsDelivered *StatCounter
+
+ // PacketsSent is the total number of IP packets sent via WritePacket.
+ PacketsSent *StatCounter
+
+ // OutgoingPacketErrors is the total number of IP packets which failed
+ // to write to a link-layer endpoint.
+ OutgoingPacketErrors *StatCounter
+}
+
+// TCPStats collects TCP-specific stats.
+type TCPStats struct {
+ // ActiveConnectionOpenings is the number of connections opened
+ // successfully via Connect.
+ ActiveConnectionOpenings *StatCounter
+
+ // PassiveConnectionOpenings is the number of connections opened
+ // successfully via Listen.
+ PassiveConnectionOpenings *StatCounter
+
+ // ListenOverflowSynDrop is the number of times the listen queue overflowed
+ // and a SYN was dropped.
+ ListenOverflowSynDrop *StatCounter
+
+ // ListenOverflowAckDrop is the number of times the final ACK
+ // in the handshake was dropped due to overflow.
+ ListenOverflowAckDrop *StatCounter
+
+ // ListenOverflowCookieSent is the number of times a SYN cookie was sent.
+ ListenOverflowSynCookieSent *StatCounter
+
+ // ListenOverflowSynCookieRcvd is the number of times a valid SYN
+ // cookie was received.
+ ListenOverflowSynCookieRcvd *StatCounter
+
+ // ListenOverflowInvalidSynCookieRcvd is the number of times an invalid SYN cookie
+ // was received.
+ ListenOverflowInvalidSynCookieRcvd *StatCounter
+
+ // FailedConnectionAttempts is the number of calls to Connect or Listen
+ // (active and passive openings, respectively) that end in an error.
+ FailedConnectionAttempts *StatCounter
+
+ // ValidSegmentsReceived is the number of TCP segments received that
+ // the transport layer successfully parsed.
+ ValidSegmentsReceived *StatCounter
+
+ // InvalidSegmentsReceived is the number of TCP segments received that
+ // the transport layer could not parse.
+ InvalidSegmentsReceived *StatCounter
+
+ // SegmentsSent is the number of TCP segments sent.
+ SegmentsSent *StatCounter
+
+ // ResetsSent is the number of TCP resets sent.
+ ResetsSent *StatCounter
+
+ // ResetsReceived is the number of TCP resets received.
+ ResetsReceived *StatCounter
+
+ // Retransmits is the number of TCP segments retransmitted.
+ Retransmits *StatCounter
+
+ // FastRecovery is the number of times Fast Recovery was used to
+ // recover from packet loss.
+ FastRecovery *StatCounter
+
+ // SACKRecovery is the number of times SACK Recovery was used to
+ // recover from packet loss.
+ SACKRecovery *StatCounter
+
+ // SlowStartRetransmits is the number of segments retransmitted in slow
+ // start.
+ SlowStartRetransmits *StatCounter
+
+ // FastRetransmit is the number of segments retransmitted in fast
+ // recovery.
+ FastRetransmit *StatCounter
+
+ // Timeouts is the number of times the RTO expired.
+ Timeouts *StatCounter
+
+ // ChecksumErrors is the number of segments dropped due to bad checksums.
+ ChecksumErrors *StatCounter
+}
+
+// UDPStats collects UDP-specific stats.
+type UDPStats struct {
+ // PacketsReceived is the number of UDP datagrams received via
+ // HandlePacket.
+ PacketsReceived *StatCounter
+
+ // UnknownPortErrors is the number of incoming UDP datagrams dropped
+ // because they did not have a known destination port.
+ UnknownPortErrors *StatCounter
+
+ // ReceiveBufferErrors is the number of incoming UDP datagrams dropped
+ // due to the receiving buffer being in an invalid state.
+ ReceiveBufferErrors *StatCounter
+
+ // MalformedPacketsReceived is the number of incoming UDP datagrams
+ // dropped due to the UDP header being in a malformed state.
+ MalformedPacketsReceived *StatCounter
+
+ // PacketsSent is the number of UDP datagrams sent via sendUDP.
+ PacketsSent *StatCounter
+}
+
+// Stats holds statistics about the networking stack.
+//
+// All fields are optional.
+type Stats struct {
+ // UnknownProtocolRcvdPackets is the number of packets received by the
+ // stack that were for an unknown or unsupported protocol.
+ UnknownProtocolRcvdPackets *StatCounter
+
+ // MalformedRcvPackets is the number of packets received by the stack
+ // that were deemed malformed.
+ MalformedRcvdPackets *StatCounter
+
+ // DroppedPackets is the number of packets dropped due to full queues.
+ DroppedPackets *StatCounter
+
+ // ICMP breaks out ICMP-specific stats (both v4 and v6).
+ ICMP ICMPStats
+
+ // IP breaks out IP-specific stats (both v4 and v6).
+ IP IPStats
+
+ // TCP breaks out TCP-specific stats.
+ TCP TCPStats
+
+ // UDP breaks out UDP-specific stats.
+ UDP UDPStats
+}
+
+func fillIn(v reflect.Value) {
+ for i := 0; i < v.NumField(); i++ {
+ v := v.Field(i)
+ switch v.Kind() {
+ case reflect.Ptr:
+ if s := v.Addr().Interface().(**StatCounter); *s == nil {
+ *s = &StatCounter{}
+ }
+ case reflect.Struct:
+ fillIn(v)
+ default:
+ panic(fmt.Sprintf("unexpected type %s", v.Type()))
+ }
+ }
+}
+
+// FillIn returns a copy of s with nil fields initialized to new StatCounters.
+func (s Stats) FillIn() Stats {
+ fillIn(reflect.ValueOf(&s).Elem())
+ return s
+}
+
+// String implements the fmt.Stringer interface.
+func (a Address) String() string {
+ switch len(a) {
+ case 4:
+ return fmt.Sprintf("%d.%d.%d.%d", int(a[0]), int(a[1]), int(a[2]), int(a[3]))
+ case 16:
+ // Find the longest subsequence of hexadecimal zeros.
+ start, end := -1, -1
+ for i := 0; i < len(a); i += 2 {
+ j := i
+ for j < len(a) && a[j] == 0 && a[j+1] == 0 {
+ j += 2
+ }
+ if j > i+2 && j-i > end-start {
+ start, end = i, j
+ }
+ }
+
+ var b strings.Builder
+ for i := 0; i < len(a); i += 2 {
+ if i == start {
+ b.WriteString("::")
+ i = end
+ if end >= len(a) {
+ break
+ }
+ } else if i > 0 {
+ b.WriteByte(':')
+ }
+ v := uint16(a[i+0])<<8 | uint16(a[i+1])
+ if v == 0 {
+ b.WriteByte('0')
+ } else {
+ const digits = "0123456789abcdef"
+ for i := uint(3); i < 4; i-- {
+ if v := v >> (i * 4); v != 0 {
+ b.WriteByte(digits[v&0xf])
+ }
+ }
+ }
+ }
+ return b.String()
+ default:
+ return fmt.Sprintf("%x", []byte(a))
+ }
+}
+
+// To4 converts the IPv4 address to a 4-byte representation.
+// If the address is not an IPv4 address, To4 returns "".
+func (a Address) To4() Address {
+ const (
+ ipv4len = 4
+ ipv6len = 16
+ )
+ if len(a) == ipv4len {
+ return a
+ }
+ if len(a) == ipv6len &&
+ isZeros(a[0:10]) &&
+ a[10] == 0xff &&
+ a[11] == 0xff {
+ return a[12:16]
+ }
+ return ""
+}
+
+// isZeros reports whether a is all zeros.
+func isZeros(a Address) bool {
+ for i := 0; i < len(a); i++ {
+ if a[i] != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// LinkAddress is a byte slice cast as a string that represents a link address.
+// It is typically a 6-byte MAC address.
+type LinkAddress string
+
+// String implements the fmt.Stringer interface.
+func (a LinkAddress) String() string {
+ switch len(a) {
+ case 6:
+ return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5])
+ default:
+ return fmt.Sprintf("%x", []byte(a))
+ }
+}
+
+// ParseMACAddress parses an IEEE 802 address.
+//
+// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
+func ParseMACAddress(s string) (LinkAddress, error) {
+ parts := strings.FieldsFunc(s, func(c rune) bool {
+ return c == ':' || c == '-'
+ })
+ if len(parts) != 6 {
+ return "", fmt.Errorf("inconsistent parts: %s", s)
+ }
+ addr := make([]byte, 0, len(parts))
+ for _, part := range parts {
+ u, err := strconv.ParseUint(part, 16, 8)
+ if err != nil {
+ return "", fmt.Errorf("invalid hex digits: %s", s)
+ }
+ addr = append(addr, byte(u))
+ }
+ return LinkAddress(addr), nil
+}
+
+// ProtocolAddress is an address and the network protocol it is associated
+// with.
+type ProtocolAddress struct {
+ // Protocol is the protocol of the address.
+ Protocol NetworkProtocolNumber
+
+ // Address is a network address.
+ Address Address
+}
+
+// danglingEndpointsMu protects access to danglingEndpoints.
+var danglingEndpointsMu sync.Mutex
+
+// danglingEndpoints tracks all dangling endpoints no longer owned by the app.
+var danglingEndpoints = make(map[Endpoint]struct{})
+
+// GetDanglingEndpoints returns all dangling endpoints.
+func GetDanglingEndpoints() []Endpoint {
+ es := make([]Endpoint, 0, len(danglingEndpoints))
+ danglingEndpointsMu.Lock()
+ for e := range danglingEndpoints {
+ es = append(es, e)
+ }
+ danglingEndpointsMu.Unlock()
+ return es
+}
+
+// AddDanglingEndpoint adds a dangling endpoint.
+func AddDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ danglingEndpoints[e] = struct{}{}
+ danglingEndpointsMu.Unlock()
+}
+
+// DeleteDanglingEndpoint removes a dangling endpoint.
+func DeleteDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ delete(danglingEndpoints, e)
+ danglingEndpointsMu.Unlock()
+}
+
+// AsyncLoading is the global barrier for asynchronous endpoint loading
+// activities.
+var AsyncLoading sync.WaitGroup
diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go
new file mode 100755
index 000000000..3ed2e29f4
--- /dev/null
+++ b/pkg/tcpip/tcpip_state_autogen.go
@@ -0,0 +1,40 @@
+// automatically generated by stateify.
+
+package tcpip
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *FullAddress) beforeSave() {}
+func (x *FullAddress) save(m state.Map) {
+ x.beforeSave()
+ m.Save("NIC", &x.NIC)
+ m.Save("Addr", &x.Addr)
+ m.Save("Port", &x.Port)
+}
+
+func (x *FullAddress) afterLoad() {}
+func (x *FullAddress) load(m state.Map) {
+ m.Load("NIC", &x.NIC)
+ m.Load("Addr", &x.Addr)
+ m.Load("Port", &x.Port)
+}
+
+func (x *ControlMessages) beforeSave() {}
+func (x *ControlMessages) save(m state.Map) {
+ x.beforeSave()
+ m.Save("HasTimestamp", &x.HasTimestamp)
+ m.Save("Timestamp", &x.Timestamp)
+}
+
+func (x *ControlMessages) afterLoad() {}
+func (x *ControlMessages) load(m state.Map) {
+ m.Load("HasTimestamp", &x.HasTimestamp)
+ m.Load("Timestamp", &x.Timestamp)
+}
+
+func init() {
+ state.Register("tcpip.FullAddress", (*FullAddress)(nil), state.Fns{Save: (*FullAddress).save, Load: (*FullAddress).load})
+ state.Register("tcpip.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load})
+}
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
new file mode 100644
index 000000000..a52262e87
--- /dev/null
+++ b/pkg/tcpip/time_unsafe.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.9
+// +build !go1.14
+
+// Check go:linkname function signatures when updating Go version.
+
+package tcpip
+
+import (
+ _ "time" // Used with go:linkname.
+ _ "unsafe" // Required for go:linkname.
+)
+
+// StdClock implements Clock with the time package.
+type StdClock struct{}
+
+var _ Clock = (*StdClock)(nil)
+
+//go:linkname now time.now
+func now() (sec int64, nsec int32, mono int64)
+
+// NowNanoseconds implements Clock.NowNanoseconds.
+func (*StdClock) NowNanoseconds() int64 {
+ sec, nsec, _ := now()
+ return sec*1e9 + int64(nsec)
+}
+
+// NowMonotonic implements Clock.NowMonotonic.
+func (*StdClock) NowMonotonic() int64 {
+ _, _, mono := now()
+ return mono
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
new file mode 100644
index 000000000..e2b90ef10
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -0,0 +1,710 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp
+
+import (
+ "encoding/binary"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type icmpPacket struct {
+ icmpPacketEntry
+ senderAddress tcpip.FullAddress
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
+ // views is used as buffer for data when its length is large
+ // enough to store a VectorisedView.
+ views [8]buffer.View `state:"nosave"`
+}
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateConnected
+ stateClosed
+)
+
+// endpoint represents an ICMP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList icmpPacketList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+ id stack.TransportEndpointID
+ state endpointState
+ // bindNICID and bindAddr are set via calls to Bind(). They are used to
+ // reject attempts to send data or connect via a different NIC or
+ // address
+ bindNICID tcpip.NICID
+ bindAddr tcpip.Address
+ // regNICID is the default NIC to be used when callers don't specify a
+ // NIC.
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return &endpoint{
+ stack: stack,
+ netProto: netProto,
+ transProto: transProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }, nil
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
+ switch e.state {
+ case stateBound, stateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ }
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = stateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case stateInitial:
+ case stateConnected:
+ return false, nil
+
+ case stateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // The state changed when we released the shared locked and re-acquired
+ // it in exclusive mode. Try again.
+ if e.state != stateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ if to == nil {
+ route = &e.route
+
+ if route.IsResolutionRequired() {
+ // Promote lock to exclusive if using a shared route,
+ // given that it may need to change in Route.Resolve()
+ // call below.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicid := to.NIC
+ if e.bindNICID != 0 {
+ if nicid != 0 && nicid != e.bindNICID {
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ nicid = e.bindNICID
+ }
+
+ toCopy := *to
+ to = &toCopy
+ netProto, err := e.checkV4Mapped(to, true)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Find the enpoint.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer r.Release()
+
+ route = &r
+ }
+
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ v, err := p.Get(p.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch e.netProto {
+ case header.IPv4ProtocolNumber:
+ err = e.send4(route, v)
+
+ case header.IPv6ProtocolNumber:
+ err = send6(route, e.id.LocalPort, v)
+ }
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(len(v)), nil, nil
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.rcvMu.Lock()
+ if e.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := e.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error {
+ if len(data) < header.ICMPv4EchoMinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort)
+
+ hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
+ copy(icmpv4, data)
+ data = data[header.ICMPv4EchoMinimumSize:]
+
+ // Linux performs these basic checks.
+ if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ icmpv4.SetChecksum(0)
+ icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
+}
+
+func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+ if len(data) < header.ICMPv6EchoMinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Set the ident. Sequence number is provided by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident)
+
+ hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(icmpv6, data)
+ data = data[header.ICMPv6EchoMinimumSize:]
+
+ if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ icmpv6.SetChecksum(0)
+ icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
+
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL())
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicid := addr.NIC
+ localPort := uint16(0)
+ switch e.state {
+ case stateBound, stateConnected:
+ localPort = e.id.LocalPort
+ if e.bindNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.bindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicid = e.bindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ LocalPort: localPort,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.route = r.Clone()
+ e.regNICID = nicid
+
+ e.state = stateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.shutdownFlags |= flags
+
+ if e.state != stateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by UDP, it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by UDP, it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if id.LocalPort != 0 {
+ // The endpoint already has a local port, just attempt to
+ // register it.
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ return id, err
+ }
+
+ // We need to find a port for the endpoint.
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ id.LocalPort = p
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ switch err {
+ case nil:
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ })
+
+ return id, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ if len(addr.Addr) != 0 {
+ // A local address was specified, verify that it's valid.
+ if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.regNICID = addr.NIC
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ e.bindNICID = addr.NIC
+ e.bindAddr = addr.Addr
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ // Only accept echo replies.
+ switch e.netProto {
+ case header.IPv4ProtocolNumber:
+ h := header.ICMPv4(vv.First())
+ if h.Type() != header.ICMPv4EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ case header.IPv6ProtocolNumber:
+ h := header.ICMPv6(vv.First())
+ if h.Type() != header.ICMPv6EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ }
+
+ e.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ pkt := &icmpPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ },
+ }
+
+ pkt.data = vv.Clone(pkt.views[:])
+
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += pkt.data.Size()
+
+ pkt.timestamp = e.stack.NowNanoseconds()
+
+ e.rcvMu.Unlock()
+
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
new file mode 100644
index 000000000..332b3cd33
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves icmpPacket.data field.
+func (p *icmpPacket) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads icmpPacket.data field.
+func (p *icmpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after savercvBufSizeMax(), which would have
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ e.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ e.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
+ if err != nil {
+ panic(*err)
+ }
+
+ e.id.LocalAddress = e.route.LocalAddress
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
+ if err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go
new file mode 100755
index 000000000..1b35e5b4a
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go
@@ -0,0 +1,173 @@
+package icmp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type icmpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type icmpPacketList struct {
+ head *icmpPacket
+ tail *icmpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *icmpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *icmpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *icmpPacketList) Front() *icmpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *icmpPacketList) Back() *icmpPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *icmpPacketList) PushFront(e *icmpPacket) {
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *icmpPacketList) PushBack(e *icmpPacket) {
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(nil)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *icmpPacketList) PushBackList(m *icmpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) {
+ a := icmpPacketElementMapper{}.linkerFor(b).Next()
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) {
+ b := icmpPacketElementMapper{}.linkerFor(a).Prev()
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *icmpPacketList) Remove(e *icmpPacket) {
+ prev := icmpPacketElementMapper{}.linkerFor(e).Prev()
+ next := icmpPacketElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ icmpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type icmpPacketEntry struct {
+ next *icmpPacket
+ prev *icmpPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *icmpPacketEntry) Next() *icmpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *icmpPacketEntry) Prev() *icmpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *icmpPacketEntry) SetNext(elem *icmpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
new file mode 100755
index 000000000..b66857348
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
@@ -0,0 +1,98 @@
+// automatically generated by stateify.
+
+package icmp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *icmpPacket) beforeSave() {}
+func (x *icmpPacket) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("icmpPacketEntry", &x.icmpPacketEntry)
+ m.Save("senderAddress", &x.senderAddress)
+ m.Save("timestamp", &x.timestamp)
+}
+
+func (x *icmpPacket) afterLoad() {}
+func (x *icmpPacket) load(m state.Map) {
+ m.Load("icmpPacketEntry", &x.icmpPacketEntry)
+ m.Load("senderAddress", &x.senderAddress)
+ m.Load("timestamp", &x.timestamp)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("netProto", &x.netProto)
+ m.Save("transProto", &x.transProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvReady", &x.rcvReady)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("id", &x.id)
+ m.Save("state", &x.state)
+ m.Save("bindNICID", &x.bindNICID)
+ m.Save("bindAddr", &x.bindAddr)
+ m.Save("regNICID", &x.regNICID)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.Load("transProto", &x.transProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("rcvReady", &x.rcvReady)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("id", &x.id)
+ m.Load("state", &x.state)
+ m.Load("bindNICID", &x.bindNICID)
+ m.Load("bindAddr", &x.bindAddr)
+ m.Load("regNICID", &x.regNICID)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *icmpPacketList) beforeSave() {}
+func (x *icmpPacketList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *icmpPacketList) afterLoad() {}
+func (x *icmpPacketList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *icmpPacketEntry) beforeSave() {}
+func (x *icmpPacketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *icmpPacketEntry) afterLoad() {}
+func (x *icmpPacketEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("icmp.icmpPacket", (*icmpPacket)(nil), state.Fns{Save: (*icmpPacket).save, Load: (*icmpPacket).load})
+ state.Register("icmp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("icmp.icmpPacketList", (*icmpPacketList)(nil), state.Fns{Save: (*icmpPacketList).save, Load: (*icmpPacketList).load})
+ state.Register("icmp.icmpPacketEntry", (*icmpPacketEntry)(nil), state.Fns{Save: (*icmpPacketEntry).save, Load: (*icmpPacketEntry).load})
+}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
new file mode 100644
index 000000000..954fde9d8
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
+// protocols for use in ping. To use it in the networking stack, this package
+// must be added to the project, and
+// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or
+// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when
+// calling stack.New(). Then endpoints can be created by passing
+// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
+// when calling Stack.NewEndpoint().
+package icmp
+
+import (
+ "encoding/binary"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName4 is the string representation of the icmp protocol name.
+ ProtocolName4 = "icmp4"
+
+ // ProtocolNumber4 is the ICMP protocol number.
+ ProtocolNumber4 = header.ICMPv4ProtocolNumber
+
+ // ProtocolName6 is the string representation of the icmp protocol name.
+ ProtocolName6 = "icmp6"
+
+ // ProtocolNumber6 is the IPv6-ICMP protocol number.
+ ProtocolNumber6 = header.ICMPv6ProtocolNumber
+)
+
+// protocol implements stack.TransportProtocol.
+type protocol struct {
+ number tcpip.TransportProtocolNumber
+}
+
+// Number returns the ICMP protocol number.
+func (p *protocol) Number() tcpip.TransportProtocolNumber {
+ return p.number
+}
+
+func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.IPv4ProtocolNumber
+ case ProtocolNumber6:
+ return header.IPv6ProtocolNumber
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// NewEndpoint creates a new icmp endpoint. It implements
+// stack.TransportProtocol.NewEndpoint.
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return newEndpoint(stack, netProto, p.number, waiterQueue)
+}
+
+// NewRawEndpoint creates a new raw icmp endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid icmp packet size.
+func (p *protocol) MinimumPacketSize() int {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.ICMPv4EchoMinimumSize
+ case ProtocolNumber6:
+ return header.ICMPv6EchoMinimumSize
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// ParsePorts returns the source and destination ports stored in the given icmp
+// packet.
+func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ switch p.number {
+ case ProtocolNumber4:
+ return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil
+ case ProtocolNumber6:
+ return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+ })
+
+ stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
+ })
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
new file mode 100644
index 000000000..1daf5823f
--- /dev/null
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -0,0 +1,521 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package raw provides the implementation of raw sockets (see raw(7)). Raw
+// sockets allow applications to:
+//
+// * manually write and inspect transport layer headers and payloads
+// * receive all traffic of a given transport protcol (e.g. ICMP or UDP)
+// * optionally write and inspect network layer and link layer headers for
+// packets
+//
+// Raw sockets don't have any notion of ports, and incoming packets are
+// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
+// receive every UDP packet received by netstack. bind(2) and connect(2) can be
+// used to filter incoming packets by source and destination.
+package raw
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type packet struct {
+ packetEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // views is pre-allocated space to back data. As long as the packet is
+ // made up of fewer than 8 buffer.Views, no extra allocation is
+ // necessary to store packet data.
+ views [8]buffer.View `state:"nosave"`
+ // timestampNS is the unix time at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
+// have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ closed bool
+ connected bool
+ bound bool
+ // registeredNIC is the NIC to which th endpoint is explicitly
+ // registered. Is set when Connect or Bind are used to specify a NIC.
+ registeredNIC tcpip.NICID
+ // boundNIC and boundAddr are set on calls to Bind(). When callers
+ // attempt actions that would invalidate the binding data (e.g. sending
+ // data via a NIC other than boundNIC), the endpoint will return an
+ // error.
+ boundNIC tcpip.NICID
+ boundAddr tcpip.Address
+ // route is the route to a remote network endpoint. It is set via
+ // Connect(), and is valid only when conneted is true.
+ route stack.Route `state:"manual"`
+}
+
+// NewEndpoint returns a raw endpoint for the given protocols.
+// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET.
+func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != header.IPv4ProtocolNumber {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ ep := &endpoint{
+ stack: stack,
+ netProto: netProto,
+ transProto: transProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ return nil, err
+ }
+
+ return ep, nil
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (ep *endpoint) Close() {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return
+ }
+
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+
+ // Clear the receive list.
+ ep.rcvClosed = true
+ ep.rcvBufSize = 0
+ for !ep.rcvList.Empty() {
+ ep.rcvList.Remove(ep.rcvList.Front())
+ }
+
+ if ep.connected {
+ ep.route.Release()
+ }
+
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ ep.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if ep.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if ep.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ ep.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ packet := ep.rcvList.Front()
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = packet.senderAddr
+ }
+
+ return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+}
+
+// Write implements tcpip.Endpoint.Write.
+func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ ep.mu.RLock()
+
+ if ep.closed {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Did the user caller provide a destination? If not, use the connected
+ // destination.
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if !ep.connected {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrDestinationRequired
+ }
+
+ if ep.route.IsResolutionRequired() {
+ savedRoute := &ep.route
+ // Promote lock to exclusive if using a shared route,
+ // given that it may need to change in finishWrite.
+ ep.mu.RUnlock()
+ ep.mu.Lock()
+
+ // Make sure that the route didn't change during the
+ // time we didn't hold the lock.
+ if !ep.connected || savedRoute != &ep.route {
+ ep.mu.Unlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ n, ch, err := ep.finishWrite(payload, savedRoute)
+ ep.mu.Unlock()
+ return n, ch, err
+ }
+
+ n, ch, err := ep.finishWrite(payload, &ep.route)
+ ep.mu.RUnlock()
+ return n, ch, err
+ }
+
+ // The caller provided a destination. Reject destination address if it
+ // goes through a different NIC than the endpoint was bound to.
+ nic := opts.To.NIC
+ if ep.bound && nic != 0 && nic != ep.boundNIC {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ // We don't support IPv6 yet, so this has to be an IPv4 address.
+ if len(opts.To.Addr) != header.IPv4AddressSize {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Find the route to the destination. If boundAddress is 0,
+ // FindRoute will choose an appropriate source address.
+ route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false)
+ if err != nil {
+ ep.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ n, ch, err := ep.finishWrite(payload, &route)
+ route.Release()
+ ep.mu.RUnlock()
+ return n, ch, err
+}
+
+// finishWrite writes the payload to a route. It resolves the route if
+// necessary. It's really just a helper to make defer unnecessary in Write.
+func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // We may need to resolve the route (match a link layer address to the
+ // network address). If that requires blocking (e.g. to use ARP),
+ // return a channel on which the caller can wait.
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ payloadBytes, err := payload.Get(payload.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch ep.netProto {
+ case header.IPv4ProtocolNumber:
+ hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
+ if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil {
+ return 0, nil, err
+ }
+
+ default:
+ return 0, nil, tcpip.ErrUnknownProtocol
+ }
+
+ return uintptr(len(payloadBytes)), nil, nil
+}
+
+// Peek implements tcpip.Endpoint.Peek.
+func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Connect implements tcpip.Endpoint.Connect.
+func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // We don't support IPv6 yet.
+ if len(addr.Addr) != header.IPv4AddressSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nic := addr.NIC
+ if ep.bound {
+ if ep.boundNIC == 0 {
+ // If we're bound, but not to a specific NIC, the NIC
+ // in addr will be used. Nothing to do here.
+ } else if addr.NIC == 0 {
+ // If we're bound to a specific NIC, but addr doesn't
+ // specify a NIC, use the bound NIC.
+ nic = ep.boundNIC
+ } else if addr.NIC != ep.boundNIC {
+ // We're bound and addr specifies a NIC. They must be
+ // the same.
+ return tcpip.ErrInvalidEndpointState
+ }
+ }
+
+ // Find a route to the destination.
+ route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ // Save the route and NIC we've connected via.
+ ep.route = route.Clone()
+ ep.registeredNIC = nic
+ ep.connected = true
+
+ return nil
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
+func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if !ep.connected {
+ return tcpip.ErrNotConnected
+ }
+ return nil
+}
+
+// Listen implements tcpip.Endpoint.Listen.
+func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept.
+func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ // Callers must provide an IPv4 address or no network address (for
+ // binding to a NIC, but not an address).
+ if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If a local address was specified, verify that it's valid.
+ if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ ep.registeredNIC = addr.NIC
+ ep.boundNIC = addr.NIC
+ ep.boundAddr = addr.Addr
+ ep.bound = true
+
+ return nil
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() || ep.rcvClosed {
+ result |= waiter.EventIn
+ }
+ ep.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
+ ep.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
+ ep.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ ep.rcvMu.Lock()
+ if ep.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := ep.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ ep.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
+func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ ep.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax {
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.rcvMu.Unlock()
+ return
+ }
+
+ if ep.bound {
+ // If bound to a NIC, only accept data for that NIC.
+ if ep.boundNIC != 0 && ep.boundNIC != route.NICID() {
+ ep.rcvMu.Unlock()
+ return
+ }
+ // If bound to an address, only accept data for that address.
+ if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress {
+ ep.rcvMu.Unlock()
+ return
+ }
+ }
+
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if ep.connected && ep.route.RemoteAddress != route.RemoteAddress {
+ ep.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := ep.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ packet := &packet{
+ senderAddr: tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ },
+ }
+
+ combinedVV := netHeader.ToVectorisedView()
+ combinedVV.Append(vv)
+ packet.data = combinedVV.Clone(packet.views[:])
+ packet.timestampNS = ep.stack.NowNanoseconds()
+
+ ep.rcvList.PushBack(packet)
+ ep.rcvBufSize += packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
new file mode 100644
index 000000000..e8907ebb1
--- /dev/null
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -0,0 +1,88 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raw
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves packet.data field.
+func (p *packet) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads packet.data field.
+func (p *packet) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ // StackFromEnv is a stack used specifically for save/restore.
+ ep.stack = stack.StackFromEnv
+
+ // If the endpoint is connected, re-connect via the save/restore stack.
+ if ep.connected {
+ var err *tcpip.Error
+ ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
+ if err != nil {
+ panic(*err)
+ }
+ }
+
+ // If the endpoint is bound, re-bind via the save/restore stack.
+ if ep.bound {
+ if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/packet_list.go b/pkg/tcpip/transport/raw/packet_list.go
new file mode 100755
index 000000000..2e9074934
--- /dev/null
+++ b/pkg/tcpip/transport/raw/packet_list.go
@@ -0,0 +1,173 @@
+package raw
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type packetElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (packetElementMapper) linkerFor(elem *packet) *packet { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type packetList struct {
+ head *packet
+ tail *packet
+}
+
+// Reset resets list l to the empty state.
+func (l *packetList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *packetList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *packetList) Front() *packet {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *packetList) Back() *packet {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *packetList) PushFront(e *packet) {
+ packetElementMapper{}.linkerFor(e).SetNext(l.head)
+ packetElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ packetElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *packetList) PushBack(e *packet) {
+ packetElementMapper{}.linkerFor(e).SetNext(nil)
+ packetElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *packetList) PushBackList(m *packetList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *packetList) InsertAfter(b, e *packet) {
+ a := packetElementMapper{}.linkerFor(b).Next()
+ packetElementMapper{}.linkerFor(e).SetNext(a)
+ packetElementMapper{}.linkerFor(e).SetPrev(b)
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *packetList) InsertBefore(a, e *packet) {
+ b := packetElementMapper{}.linkerFor(a).Prev()
+ packetElementMapper{}.linkerFor(e).SetNext(a)
+ packetElementMapper{}.linkerFor(e).SetPrev(b)
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *packetList) Remove(e *packet) {
+ prev := packetElementMapper{}.linkerFor(e).Prev()
+ next := packetElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ packetElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ packetElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type packetEntry struct {
+ next *packet
+ prev *packet
+}
+
+// Next returns the entry that follows e in the list.
+func (e *packetEntry) Next() *packet {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *packetEntry) Prev() *packet {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *packetEntry) SetNext(elem *packet) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *packetEntry) SetPrev(elem *packet) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go
new file mode 100755
index 000000000..3327811b4
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw_state_autogen.go
@@ -0,0 +1,96 @@
+// automatically generated by stateify.
+
+package raw
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *packet) beforeSave() {}
+func (x *packet) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("packetEntry", &x.packetEntry)
+ m.Save("timestampNS", &x.timestampNS)
+ m.Save("senderAddr", &x.senderAddr)
+}
+
+func (x *packet) afterLoad() {}
+func (x *packet) load(m state.Map) {
+ m.Load("packetEntry", &x.packetEntry)
+ m.Load("timestampNS", &x.timestampNS)
+ m.Load("senderAddr", &x.senderAddr)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("netProto", &x.netProto)
+ m.Save("transProto", &x.transProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("closed", &x.closed)
+ m.Save("connected", &x.connected)
+ m.Save("bound", &x.bound)
+ m.Save("registeredNIC", &x.registeredNIC)
+ m.Save("boundNIC", &x.boundNIC)
+ m.Save("boundAddr", &x.boundAddr)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.Load("transProto", &x.transProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("closed", &x.closed)
+ m.Load("connected", &x.connected)
+ m.Load("bound", &x.bound)
+ m.Load("registeredNIC", &x.registeredNIC)
+ m.Load("boundNIC", &x.boundNIC)
+ m.Load("boundAddr", &x.boundAddr)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *packetList) beforeSave() {}
+func (x *packetList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *packetList) afterLoad() {}
+func (x *packetList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *packetEntry) beforeSave() {}
+func (x *packetEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *packetEntry) afterLoad() {}
+func (x *packetEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("raw.packet", (*packet)(nil), state.Fns{Save: (*packet).save, Load: (*packet).load})
+ state.Register("raw.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("raw.packetList", (*packetList)(nil), state.Fns{Save: (*packetList).save, Load: (*packetList).load})
+ state.Register("raw.packetEntry", (*packetEntry)(nil), state.Fns{Save: (*packetEntry).save, Load: (*packetEntry).load})
+}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
new file mode 100644
index 000000000..d4b860975
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -0,0 +1,499 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "crypto/sha1"
+ "encoding/binary"
+ "hash"
+ "io"
+ "log"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // tsLen is the length, in bits, of the timestamp in the SYN cookie.
+ tsLen = 8
+
+ // tsMask is a mask for timestamp values (i.e., tsLen bits).
+ tsMask = (1 << tsLen) - 1
+
+ // tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
+ tsOffset = 24
+
+ // hashMask is the mask for hash values (i.e., tsOffset bits).
+ hashMask = (1 << tsOffset) - 1
+
+ // maxTSDiff is the maximum allowed difference between a received cookie
+ // timestamp and the current timestamp. If the difference is greater
+ // than maxTSDiff, the cookie is expired.
+ maxTSDiff = 2
+)
+
+var (
+ // SynRcvdCountThreshold is the global maximum number of connections
+ // that are allowed to be in SYN-RCVD state before TCP starts using SYN
+ // cookies to accept connections.
+ //
+ // It is an exported variable only for testing, and should not otherwise
+ // be used by importers of this package.
+ SynRcvdCountThreshold uint64 = 1000
+
+ // mssTable is a slice containing the possible MSS values that we
+ // encode in the SYN cookie with two bits.
+ mssTable = []uint16{536, 1300, 1440, 1460}
+)
+
+func encodeMSS(mss uint16) uint32 {
+ for i := len(mssTable) - 1; i > 0; i-- {
+ if mss >= mssTable[i] {
+ return uint32(i)
+ }
+ }
+ return 0
+}
+
+// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
+// protected by a mutex so that we can increment only when it's guaranteed not
+// to go above a threshold.
+var synRcvdCount struct {
+ sync.Mutex
+ value uint64
+ pending sync.WaitGroup
+}
+
+// listenContext is used by a listening endpoint to store state used while
+// listening for connections. This struct is allocated by the listen goroutine
+// and must not be accessed or have its methods called concurrently as they
+// may mutate the stored objects.
+type listenContext struct {
+ stack *stack.Stack
+ rcvWnd seqnum.Size
+ nonce [2][sha1.BlockSize]byte
+ listenEP *endpoint
+
+ hasherMu sync.Mutex
+ hasher hash.Hash
+ v6only bool
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
+func timeStamp() uint32 {
+ return uint32(time.Now().Unix()>>6) & tsMask
+}
+
+// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD
+// state. It succeeds if the increment doesn't make the count go beyond the
+// threshold, and fails otherwise.
+func incSynRcvdCount() bool {
+ synRcvdCount.Lock()
+
+ if synRcvdCount.value >= SynRcvdCountThreshold {
+ synRcvdCount.Unlock()
+ return false
+ }
+
+ synRcvdCount.pending.Add(1)
+ synRcvdCount.value++
+
+ synRcvdCount.Unlock()
+ return true
+}
+
+// decSynRcvdCount atomically decrements the global number of endpoints in
+// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
+// succeeded.
+func decSynRcvdCount() {
+ synRcvdCount.Lock()
+
+ synRcvdCount.value--
+ synRcvdCount.pending.Done()
+ synRcvdCount.Unlock()
+}
+
+// newListenContext creates a new listen context.
+func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+ l := &listenContext{
+ stack: stack,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6only: v6only,
+ netProto: netProto,
+ listenEP: listenEP,
+ }
+
+ rand.Read(l.nonce[0][:])
+ rand.Read(l.nonce[1][:])
+
+ return l
+}
+
+// cookieHash calculates the cookieHash for the given id, timestamp and nonce
+// index. The hash is used to create and validate cookies.
+func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
+
+ // Initialize block with fixed-size data: local ports and v.
+ var payload [8]byte
+ binary.BigEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.BigEndian.PutUint16(payload[2:], id.RemotePort)
+ binary.BigEndian.PutUint32(payload[4:], ts)
+
+ // Feed everything to the hasher.
+ l.hasherMu.Lock()
+ l.hasher.Reset()
+ l.hasher.Write(payload[:])
+ l.hasher.Write(l.nonce[nonceIndex][:])
+ io.WriteString(l.hasher, string(id.LocalAddress))
+ io.WriteString(l.hasher, string(id.RemoteAddress))
+
+ // Finalize the calculation of the hash and return the first 4 bytes.
+ h := make([]byte, 0, sha1.Size)
+ h = l.hasher.Sum(h)
+ l.hasherMu.Unlock()
+
+ return binary.BigEndian.Uint32(h[:])
+}
+
+// createCookie creates a SYN cookie for the given id and incoming sequence
+// number.
+func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
+ ts := timeStamp()
+ v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
+ v += (l.cookieHash(id, ts, 1) + data) & hashMask
+ return seqnum.Value(v)
+}
+
+// isCookieValid checks if the supplied cookie is valid for the given id and
+// sequence number. If it is, it also returns the data originally encoded in the
+// cookie when createCookie was called.
+func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
+ ts := timeStamp()
+ v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
+ cookieTS := v >> tsOffset
+ if ((ts - cookieTS) & tsMask) > maxTSDiff {
+ return 0, false
+ }
+
+ return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
+}
+
+// createConnectingEndpoint creates a new endpoint in a connecting state, with
+// the connection parameters given by the arguments.
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+ // Create a new endpoint.
+ netProto := l.netProto
+ if netProto == 0 {
+ netProto = s.route.NetProto
+ }
+ n := newEndpoint(l.stack, netProto, nil)
+ n.v6only = l.v6only
+ n.id = s.id
+ n.boundNICID = s.route.NICID()
+ n.route = s.route.Clone()
+ n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+ n.rcvBufSize = int(l.rcvWnd)
+
+ n.maybeEnableTimestamp(rcvdSynOpts)
+ n.maybeEnableSACKPermitted(rcvdSynOpts)
+
+ n.initGSO()
+
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ n.Close()
+ return nil, err
+ }
+
+ n.isRegistered = true
+ n.state = stateConnecting
+
+ // Create sender and receiver.
+ //
+ // The receiver at least temporarily has a zero receive window scale,
+ // but the caller may change it (before starting the protocol loop).
+ n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
+ n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
+
+ return n, nil
+}
+
+// createEndpoint creates a new endpoint in connected state and then performs
+// the TCP 3-way handshake.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+ // Create new endpoint.
+ irs := s.sequenceNumber
+ cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
+ ep, err := l.createConnectingEndpoint(s, cookie, irs, opts)
+ if err != nil {
+ return nil, err
+ }
+
+ // Perform the 3-way handshake.
+ h := newHandshake(ep, l.rcvWnd)
+
+ h.resetToSynRcvd(cookie, irs, opts, l.listenEP)
+ if err := h.execute(); err != nil {
+ ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ ep.Close()
+ return nil, err
+ }
+
+ ep.state = stateConnected
+
+ // Update the receive window scaling. We can't do it before the
+ // handshake because it's possible that the peer doesn't support window
+ // scaling.
+ ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+
+ return ep, nil
+}
+
+// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
+// endpoint has transitioned out of the listen state, the new endpoint is closed
+// instead.
+func (e *endpoint) deliverAccepted(n *endpoint) {
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == stateListen {
+ e.acceptedChan <- n
+ e.waiterQueue.Notify(waiter.EventIn)
+ } else {
+ n.Close()
+ }
+}
+
+// handleSynSegment is called in its own goroutine once the listening endpoint
+// receives a SYN segment. It is responsible for completing the handshake and
+// queueing the new endpoint for acceptance.
+//
+// A limited number of these goroutines are allowed before TCP starts using SYN
+// cookies to accept connections.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
+ defer decSynRcvdCount()
+ defer e.decSynRcvdCount()
+ defer s.decRef()
+
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ e.deliverAccepted(n)
+}
+
+func (e *endpoint) incSynRcvdCount() bool {
+ e.mu.Lock()
+ log.Printf("l: %d, c: %d, e.synRcvdCount: %d", len(e.acceptedChan), cap(e.acceptedChan), e.synRcvdCount)
+ if l, c := len(e.acceptedChan), cap(e.acceptedChan); l == c && e.synRcvdCount >= c {
+ e.mu.Unlock()
+ return false
+ }
+ e.synRcvdCount++
+ e.mu.Unlock()
+ return true
+}
+
+func (e *endpoint) decSynRcvdCount() {
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+}
+
+// handleListenSegment is called when a listening endpoint receives a segment
+// and needs to handle it.
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+ switch s.flags {
+ case header.TCPFlagSyn:
+ opts := parseSynSegmentOptions(s)
+ if incSynRcvdCount() {
+ // Drop the SYN if the listen endpoint's accept queue is
+ // overflowing.
+ if e.incSynRcvdCount() {
+ log.Printf("processing syn packet")
+ s.incRef()
+ go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+ return
+ }
+ log.Printf("dropping syn packet")
+ e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ } else {
+ // TODO(bhaskerh): Increment syncookie sent stat.
+ cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ // Send SYN with window scaling because we currently
+ // dont't encode this information in the cookie.
+ //
+ // Enable Timestamp option if the original syn did have
+ // the timestamp option specified.
+ synOpts := header.TCPSynOptions{
+ WS: -1,
+ TS: opts.TS,
+ TSVal: tcpTimeStamp(timeStampOffset()),
+ TSEcr: opts.TSVal,
+ }
+ sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+ e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ }
+
+ case header.TCPFlagAck:
+ if len(e.acceptedChan) == cap(e.acceptedChan) {
+ // Silently drop the ack as the application can't accept
+ // the connection at this point. The ack will be
+ // retransmitted by the sender anyway and we can
+ // complete the connection at the time of retransmit if
+ // the backlog has space.
+ e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+
+ // Validate the cookie.
+ data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
+ if !ok || int(data) >= len(mssTable) {
+ e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
+ // Create newly accepted endpoint and deliver it.
+ rcvdSynOptions := &header.TCPSynOptions{
+ MSS: mssTable[data],
+ // Disable Window scaling as original SYN is
+ // lost.
+ WS: -1,
+ }
+
+ // When syn cookies are in use we enable timestamp only
+ // if the ack specifies the timestamp option assuming
+ // that the other end did in fact negotiate the
+ // timestamp option in the original SYN.
+ if s.parsedOptions.TS {
+ rcvdSynOptions.TS = true
+ rcvdSynOptions.TSVal = s.parsedOptions.TSVal
+ rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
+ }
+
+ n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ // clear the tsOffset for the newly created
+ // endpoint as the Timestamp was already
+ // randomly offset when the original SYN-ACK was
+ // sent above.
+ n.tsOffset = 0
+
+ // Switch state to connected.
+ n.state = stateConnected
+
+ // Do the delivery in a separate goroutine so
+ // that we don't block the listen loop in case
+ // the application is slow to accept or stops
+ // accepting.
+ //
+ // NOTE: This won't result in an unbounded
+ // number of goroutines as we do check before
+ // entering here that there was at least some
+ // space available in the backlog.
+ go e.deliverAccepted(n)
+ }
+}
+
+// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
+// its own goroutine and is responsible for handling connection requests.
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+ defer func() {
+ // Mark endpoint as closed. This will prevent goroutines running
+ // handleSynSegment() from attempting to queue new connections
+ // to the endpoint.
+ e.mu.Lock()
+ e.state = stateClosed
+
+ // Do cleanup if needed.
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+ e.mu.Unlock()
+
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ }()
+
+ e.mu.Lock()
+ v6only := e.v6only
+ e.mu.Unlock()
+
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+
+ s := sleep.Sleeper{}
+ s.AddWaker(&e.notificationWaker, wakerForNotification)
+ s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
+ for {
+ switch index, _ := s.Fetch(true); index {
+ case wakerForNotification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ s := e.segmentQueue.dequeue()
+ e.handleListenSegment(ctx, s)
+ s.decRef()
+ }
+ synRcvdCount.pending.Wait()
+ close(e.drainDone)
+ <-e.undrain
+ }
+
+ case wakerForNewSegment:
+ // Process at most maxSegmentsPerWake segments.
+ mayRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ mayRequeue = false
+ break
+ }
+
+ e.handleListenSegment(ctx, s)
+ s.decRef()
+ }
+
+ // If the queue is not empty, make sure we'll wake up
+ // in the next iteration.
+ if mayRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
new file mode 100644
index 000000000..2aed6f286
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -0,0 +1,1066 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// maxSegmentsPerWake is the maximum number of segments to process in the main
+// protocol goroutine per wake-up. Yielding [after this number of segments are
+// processed] allows other events to be processed as well (e.g., timeouts,
+// resets, etc.).
+const maxSegmentsPerWake = 100
+
+type handshakeState int
+
+// The following are the possible states of the TCP connection during a 3-way
+// handshake. A depiction of the states and transitions can be found in RFC 793,
+// page 23.
+const (
+ handshakeSynSent handshakeState = iota
+ handshakeSynRcvd
+ handshakeCompleted
+)
+
+// The following are used to set up sleepers.
+const (
+ wakerForNotification = iota
+ wakerForNewSegment
+ wakerForResend
+ wakerForResolution
+)
+
+const (
+ // Maximum space available for options.
+ maxOptionSize = 40
+)
+
+// handshake holds the state used during a TCP 3-way handshake.
+type handshake struct {
+ ep *endpoint
+ listenEP *endpoint // only non nil when doing passive connects.
+ state handshakeState
+ active bool
+ flags uint8
+ ackNum seqnum.Value
+
+ // iss is the initial send sequence number, as defined in RFC 793.
+ iss seqnum.Value
+
+ // rcvWnd is the receive window, as defined in RFC 793.
+ rcvWnd seqnum.Size
+
+ // sndWnd is the send window, as defined in RFC 793.
+ sndWnd seqnum.Size
+
+ // mss is the maximum segment size received from the peer.
+ mss uint16
+
+ // sndWndScale is the send window scale, as defined in RFC 1323. A
+ // negative value means no scaling is supported by the peer.
+ sndWndScale int
+
+ // rcvWndScale is the receive window scale, as defined in RFC 1323.
+ rcvWndScale int
+}
+
+func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
+ h := handshake{
+ ep: ep,
+ active: true,
+ rcvWnd: rcvWnd,
+ rcvWndScale: FindWndScale(rcvWnd),
+ }
+ h.resetState()
+ return h
+}
+
+// FindWndScale determines the window scale to use for the given maximum window
+// size.
+func FindWndScale(wnd seqnum.Size) int {
+ if wnd < 0x10000 {
+ return 0
+ }
+
+ max := seqnum.Size(0xffff)
+ s := 0
+ for wnd > max && s < header.MaxWndScale {
+ s++
+ max <<= 1
+ }
+
+ return s
+}
+
+// resetState resets the state of the handshake object such that it becomes
+// ready for a new 3-way handshake.
+func (h *handshake) resetState() {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+
+ h.state = handshakeSynSent
+ h.flags = header.TCPFlagSyn
+ h.ackNum = 0
+ h.mss = 0
+ h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
+}
+
+// effectiveRcvWndScale returns the effective receive window scale to be used.
+// If the peer doesn't support window scaling, the effective rcv wnd scale is
+// zero; otherwise it's the value calculated based on the initial rcv wnd.
+func (h *handshake) effectiveRcvWndScale() uint8 {
+ if h.sndWndScale < 0 {
+ return 0
+ }
+ return uint8(h.rcvWndScale)
+}
+
+// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
+// state.
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, listenEP *endpoint) {
+ h.active = false
+ h.state = handshakeSynRcvd
+ h.flags = header.TCPFlagSyn | header.TCPFlagAck
+ h.iss = iss
+ h.ackNum = irs + 1
+ h.mss = opts.MSS
+ h.sndWndScale = opts.WS
+ h.listenEP = listenEP
+}
+
+// checkAck checks if the ACK number, if present, of a segment received during
+// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in
+// response.
+func (h *handshake) checkAck(s *segment) bool {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 {
+ // RFC 793, page 36, states that a reset must be generated when
+ // the connection is in any non-synchronized state and an
+ // incoming segment acknowledges something not yet sent. The
+ // connection remains in the same state.
+ ack := s.sequenceNumber.Add(s.logicalLen())
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, s.ackNumber, ack, 0)
+ return false
+ }
+
+ return true
+}
+
+// synSentState handles a segment received when the TCP 3-way handshake is in
+// the SYN-SENT state.
+func (h *handshake) synSentState(s *segment) *tcpip.Error {
+ // RFC 793, page 37, states that in the SYN-SENT state, a reset is
+ // acceptable if the ack field acknowledges the SYN.
+ if s.flagIsSet(header.TCPFlagRst) {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+ return tcpip.ErrConnectionRefused
+ }
+ return nil
+ }
+
+ if !h.checkAck(s) {
+ return nil
+ }
+
+ // We are in the SYN-SENT state. We only care about segments that have
+ // the SYN flag.
+ if !s.flagIsSet(header.TCPFlagSyn) {
+ return nil
+ }
+
+ // Parse the SYN options.
+ rcvSynOpts := parseSynSegmentOptions(s)
+
+ // Remember if the Timestamp option was negotiated.
+ h.ep.maybeEnableTimestamp(&rcvSynOpts)
+
+ // Remember if the SACKPermitted option was negotiated.
+ h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+
+ // Remember the sequence we'll ack from now on.
+ h.ackNum = s.sequenceNumber + 1
+ h.flags |= header.TCPFlagAck
+ h.mss = rcvSynOpts.MSS
+ h.sndWndScale = rcvSynOpts.WS
+
+ // If this is a SYN ACK response, we only need to acknowledge the SYN
+ // and the handshake is completed.
+ if s.flagIsSet(header.TCPFlagAck) {
+ h.state = handshakeCompleted
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
+ return nil
+ }
+
+ // A SYN segment was received, but no ACK in it. We acknowledge the SYN
+ // but resend our own SYN and wait for it to be acknowledged in the
+ // SYN-RCVD state.
+ h.state = handshakeSynRcvd
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: rcvSynOpts.TS,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+
+ // We only send SACKPermitted if the other side indicated it
+ // permits SACK. This is not explicitly defined in the RFC but
+ // this is the behaviour implemented by Linux.
+ SACKPermitted: rcvSynOpts.SACKPermitted,
+ }
+ sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+ return nil
+}
+
+// synRcvdState handles a segment received when the TCP 3-way handshake is in
+// the SYN-RCVD state.
+func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
+ if s.flagIsSet(header.TCPFlagRst) {
+ // RFC 793, page 37, states that in the SYN-RCVD state, a reset
+ // is acceptable if the sequence number is in the window.
+ if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
+ return tcpip.ErrConnectionRefused
+ }
+ return nil
+ }
+
+ if !h.checkAck(s) {
+ return nil
+ }
+
+ if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
+ // We received two SYN segments with different sequence
+ // numbers, so we reset this and restart the whole
+ // process, except that we don't reset the timer.
+ ack := s.sequenceNumber.Add(s.logicalLen())
+ seq := seqnum.Value(0)
+ if s.flagIsSet(header.TCPFlagAck) {
+ seq = s.ackNumber
+ }
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0)
+
+ if !h.active {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ h.resetState()
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: h.ep.sendTSOk,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+ SACKPermitted: h.ep.sackPermitted,
+ }
+ sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ return nil
+ }
+
+ // We have previously received (and acknowledged) the peer's SYN. If the
+ // peer acknowledges our SYN, the handshake is completed.
+ if s.flagIsSet(header.TCPFlagAck) {
+ // listenContext is also used by a tcp.Forwarder and in that
+ // context we do not have a listening endpoint to check the
+ // backlog. So skip this check if listenEP is nil.
+ if h.listenEP != nil && len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) {
+ // If there is no space in the accept queue to accept
+ // this endpoint then silently drop this ACK. The peer
+ // will anyway resend the ack and we can complete the
+ // connection the next time it's retransmitted.
+ h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+ // If the timestamp option is negotiated and the segment does
+ // not carry a timestamp option then the segment must be dropped
+ // as per https://tools.ietf.org/html/rfc7323#section-3.2.
+ if h.ep.sendTSOk && !s.parsedOptions.TS {
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
+ // Update timestamp if required. See RFC7323, section-4.3.
+ if h.ep.sendTSOk && s.parsedOptions.TS {
+ h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
+ }
+ h.state = handshakeCompleted
+ return nil
+ }
+
+ return nil
+}
+
+func (h *handshake) handleSegment(s *segment) *tcpip.Error {
+ h.sndWnd = s.window
+ if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 {
+ h.sndWnd <<= uint8(h.sndWndScale)
+ }
+
+ switch h.state {
+ case handshakeSynRcvd:
+ return h.synRcvdState(s)
+ case handshakeSynSent:
+ return h.synSentState(s)
+ }
+ return nil
+}
+
+// processSegments goes through the segment queue and processes up to
+// maxSegmentsPerWake (if they're available).
+func (h *handshake) processSegments() *tcpip.Error {
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := h.ep.segmentQueue.dequeue()
+ if s == nil {
+ return nil
+ }
+
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+
+ // We stop processing packets once the handshake is completed,
+ // otherwise we may process packets meant to be processed by
+ // the main protocol goroutine.
+ if h.state == handshakeCompleted {
+ break
+ }
+ }
+
+ // If the queue is not empty, make sure we'll wake up in the next
+ // iteration.
+ if !h.ep.segmentQueue.empty() {
+ h.ep.newSegmentWaker.Assert()
+ }
+
+ return nil
+}
+
+func (h *handshake) resolveRoute() *tcpip.Error {
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ resolutionWaker := &sleep.Waker{}
+ s.AddWaker(resolutionWaker, wakerForResolution)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ defer s.Done()
+
+ // Initial action is to resolve route.
+ index := wakerForResolution
+ for {
+ switch index {
+ case wakerForResolution:
+ if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ // Either success (err == nil) or failure.
+ return err
+ }
+ // Resolution not completed. Keep trying...
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if n&notifyClose != 0 {
+ h.ep.route.RemoveWaker(resolutionWaker)
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
+ }
+
+ // Wait for notification.
+ index, _ = s.Fetch(true)
+ }
+}
+
+// execute executes the TCP 3-way handshake.
+func (h *handshake) execute() *tcpip.Error {
+ if h.ep.route.IsResolutionRequired() {
+ if err := h.resolveRoute(); err != nil {
+ return err
+ }
+ }
+
+ // Initialize the resend timer.
+ resendWaker := sleep.Waker{}
+ timeOut := time.Duration(time.Second)
+ rt := time.AfterFunc(timeOut, func() {
+ resendWaker.Assert()
+ })
+ defer rt.Stop()
+
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ s.AddWaker(&resendWaker, wakerForResend)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ defer s.Done()
+
+ var sackEnabled SACKEnabled
+ if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
+ // If stack returned an error when checking for SACKEnabled
+ // status then just default to switching off SACK negotiation.
+ sackEnabled = false
+ }
+
+ // Send the initial SYN segment and loop until the handshake is
+ // completed.
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: true,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+ SACKPermitted: bool(sackEnabled),
+ }
+
+ // Execute is also called in a listen context so we want to make sure we
+ // only send the TS/SACK option when we received the TS/SACK in the
+ // initial SYN.
+ if h.state == handshakeSynRcvd {
+ synOpts.TS = h.ep.sendTSOk
+ synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+ }
+ sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ for h.state != handshakeCompleted {
+ switch index, _ := s.Fetch(true); index {
+ case wakerForResend:
+ timeOut *= 2
+ if timeOut > 60*time.Second {
+ return tcpip.ErrTimeout
+ }
+ rt.Reset(timeOut)
+ sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if n&notifyClose != 0 {
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ for !h.ep.segmentQueue.empty() {
+ s := h.ep.segmentQueue.dequeue()
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+ if h.state == handshakeCompleted {
+ return nil
+ }
+ }
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
+
+ case wakerForNewSegment:
+ if err := h.processSegments(); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
+ synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
+ if synOpts.TS {
+ s.parsedOptions.TSVal = synOpts.TSVal
+ s.parsedOptions.TSEcr = synOpts.TSEcr
+ }
+ return synOpts
+}
+
+var optionPool = sync.Pool{
+ New: func() interface{} {
+ return make([]byte, maxOptionSize)
+ },
+}
+
+func getOptions() []byte {
+ return optionPool.Get().([]byte)
+}
+
+func putOptions(options []byte) {
+ // Reslice to full capacity.
+ optionPool.Put(options[0:cap(options)])
+}
+
+func makeSynOptions(opts header.TCPSynOptions) []byte {
+ // Emulate linux option order. This is as follows:
+ //
+ // if md5: NOP NOP MD5SIG 18 md5sig(16)
+ // if mss: MSS 4 mss(2)
+ // if ts and sack_advertise:
+ // SACK 2 TIMESTAMP 2 timestamp(8)
+ // elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
+ // elif sack: NOP NOP SACK 2
+ // if wscale: NOP WINDOW 3 ws(1)
+ // if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8))
+ // [for each block] start_seq(4) end_seq(4)
+ // if fastopen_cookie:
+ // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
+ // else: FASTOPEN (2 + len(cookie))
+ // cookie(variable) [padding to four bytes]
+ //
+ options := getOptions()
+
+ // Always encode the mss.
+ offset := header.EncodeMSSOption(uint32(opts.MSS), options)
+
+ // Special ordering is required here. If both TS and SACK are enabled,
+ // then the SACK option precedes TS, with no padding. If they are
+ // enabled individually, then we see padding before the option.
+ if opts.TS && opts.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.TS {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.SACKPermitted {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ }
+
+ // Initialize the WS option.
+ if opts.WS >= 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeWSOption(opts.WS, options[offset:])
+ }
+
+ // Padding to the end; note that this never apply unless we add a
+ // fastopen option, we always expect the offset to remain the same.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
+
+func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
+ // The MSS in opts is automatically calculated as this function is
+ // called from many places and we don't want every call point being
+ // embedded with the MSS calculation.
+ if opts.MSS == 0 {
+ opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
+ }
+
+ options := makeSynOptions(opts)
+ err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil)
+ putOptions(options)
+ return err
+}
+
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ // Allocate a buffer for the TCP header.
+ hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
+
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ // Initialize the header.
+ tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
+ tcp.Encode(&header.TCPFields{
+ SrcPort: id.LocalPort,
+ DstPort: id.RemotePort,
+ SeqNum: uint32(seq),
+ AckNum: uint32(ack),
+ DataOffset: uint8(header.TCPMinimumSize + optLen),
+ Flags: flags,
+ WindowSize: uint16(rcvWnd),
+ })
+ copy(tcp[header.TCPMinimumSize:], opts)
+
+ length := uint16(hdr.UsedLength() + data.Size())
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ // Only calculate the checksum if offloading isn't supported.
+ if gso != nil && gso.NeedsCsum {
+ // This is called CHECKSUM_PARTIAL in the Linux kernel. We
+ // calculate a checksum of the pseudo-header and save it in the
+ // TCP header, then the kernel calculate a checksum of the
+ // header and data and get the right sum of the TCP packet.
+ tcp.SetChecksum(xsum)
+ } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ xsum = header.ChecksumVV(data, xsum)
+ tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
+ }
+
+ r.Stats().TCP.SegmentsSent.Increment()
+ if (flags & header.TCPFlagRst) != 0 {
+ r.Stats().TCP.ResetsSent.Increment()
+ }
+
+ return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl)
+}
+
+// makeOptions makes an options slice.
+func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
+ options := getOptions()
+ offset := 0
+
+ // N.B. the ordering here matches the ordering used by Linux internally
+ // and described in the raw makeOptions function. We don't include
+ // unnecessary cases here (post connection.)
+ if e.sendTSOk {
+ // Embed the timestamp if timestamp has been enabled.
+ //
+ // We only use the lower 32 bits of the unix time in
+ // milliseconds. This is similar to what Linux does where it
+ // uses the lower 32 bits of the jiffies value in the tsVal
+ // field of the timestamp option.
+ //
+ // Further, RFC7323 section-5.4 recommends millisecond
+ // resolution as the lowest recommended resolution for the
+ // timestamp clock.
+ //
+ // Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
+ }
+ if e.sackPermitted && len(sackBlocks) > 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+ }
+
+ // We expect the above to produce an aligned offset.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
+
+// sendRaw sends a TCP segment to the endpoint's peer.
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+ var sackBlocks []header.SACKBlock
+ if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
+ }
+ options := e.makeOptions(sackBlocks)
+ err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso)
+ putOptions(options)
+ return err
+}
+
+func (e *endpoint) handleWrite() *tcpip.Error {
+ // Move packets from send queue to send list. The queue is accessible
+ // from other goroutines and protected by the send mutex, while the send
+ // list is only accessible from the handler goroutine, so it needs no
+ // mutexes.
+ e.sndBufMu.Lock()
+
+ first := e.sndQueue.Front()
+ if first != nil {
+ e.snd.writeList.PushBackList(&e.sndQueue)
+ e.snd.sndNxtList.UpdateForward(e.sndBufInQueue)
+ e.sndBufInQueue = 0
+ }
+
+ e.sndBufMu.Unlock()
+
+ // Initialize the next segment to write if it's currently nil.
+ if e.snd.writeNext == nil {
+ e.snd.writeNext = first
+ }
+
+ // Push out any new packets.
+ e.snd.sendData()
+
+ return nil
+}
+
+func (e *endpoint) handleClose() *tcpip.Error {
+ // Drain the send queue.
+ e.handleWrite()
+
+ // Mark send side as closed.
+ e.snd.closed = true
+
+ return nil
+}
+
+// resetConnectionLocked sends a RST segment and puts the endpoint in an error
+// state with the given error code. This method must only be called from the
+// protocol goroutine.
+func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+
+ e.state = stateError
+ e.hardError = err
+}
+
+// completeWorkerLocked is called by the worker goroutine when it's about to
+// exit. It marks the worker as completed and performs cleanup work if requested
+// by Close().
+func (e *endpoint) completeWorkerLocked() {
+ e.workerRunning = false
+ if e.workerCleanup {
+ e.cleanupLocked()
+ }
+}
+
+// handleSegments pulls segments from the queue and processes them. It returns
+// no error if the protocol loop should continue, an error otherwise.
+func (e *endpoint) handleSegments() *tcpip.Error {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+
+ // Invoke the tcp probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+ s.decRef()
+ return tcpip.ErrConnectionReset
+ }
+ } else if s.flagIsSet(header.TCPFlagAck) {
+ // Patch the window size in the segment according to the
+ // send window scale.
+ s.window <<= e.snd.sndWndScale
+
+ // RFC 793, page 41 states that "once in the ESTABLISHED
+ // state all segments must carry current acknowledgment
+ // information."
+ e.rcv.handleRcvdSegment(s)
+ e.snd.handleRcvdSegment(s)
+ }
+ s.decRef()
+ }
+
+ // If the queue is not empty, make sure we'll wake up in the next
+ // iteration.
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+
+ // Send an ACK for all processed packets if needed.
+ if e.rcv.rcvNxt != e.snd.maxSentAck {
+ e.snd.sendAck()
+ }
+
+ e.resetKeepaliveTimer(true)
+
+ return nil
+}
+
+// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
+// keepalive packets periodically when the connection is idle. If we don't hear
+// from the other side after a number of tries, we terminate the connection.
+func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ e.keepalive.Lock()
+ if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
+ e.keepalive.Unlock()
+ return nil
+ }
+
+ if e.keepalive.unacked >= e.keepalive.count {
+ e.keepalive.Unlock()
+ return tcpip.ErrConnectionReset
+ }
+
+ // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
+ // seg.seq = snd.nxt-1.
+ e.keepalive.unacked++
+ e.keepalive.Unlock()
+ e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1)
+ e.resetKeepaliveTimer(false)
+ return nil
+}
+
+// resetKeepaliveTimer restarts or stops the keepalive timer, depending on
+// whether it is enabled for this endpoint.
+func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
+ e.keepalive.Lock()
+ defer e.keepalive.Unlock()
+ if receivedData {
+ e.keepalive.unacked = 0
+ }
+ // Start the keepalive timer IFF it's enabled and there is no pending
+ // data to send.
+ if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ e.keepalive.timer.disable()
+ return
+ }
+ if e.keepalive.unacked > 0 {
+ e.keepalive.timer.enable(e.keepalive.interval)
+ } else {
+ e.keepalive.timer.enable(e.keepalive.idle)
+ }
+}
+
+// disableKeepaliveTimer stops the keepalive timer.
+func (e *endpoint) disableKeepaliveTimer() {
+ e.keepalive.Lock()
+ e.keepalive.timer.disable()
+ e.keepalive.Unlock()
+}
+
+// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments.
+func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
+ var closeTimer *time.Timer
+ var closeWaker sleep.Waker
+
+ epilogue := func() {
+ // e.mu is expected to be hold upon entering this section.
+
+ if e.snd != nil {
+ e.snd.resendTimer.cleanup()
+ }
+
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+
+ e.mu.Unlock()
+
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+
+ if handshake {
+ // This is an active connection, so we must initiate the 3-way
+ // handshake, and then inform potential waiters about its
+ // completion.
+ h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable()))
+ if err := h.execute(); err != nil {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ e.mu.Lock()
+ e.state = stateError
+ e.hardError = err
+ // Lock released below.
+ epilogue()
+
+ return err
+ }
+
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ e.rcvListMu.Unlock()
+ }
+
+ e.keepalive.timer.init(&e.keepalive.waker)
+ defer e.keepalive.timer.cleanup()
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.mu.Lock()
+ e.state = stateConnected
+ drained := e.drainDone != nil
+ e.mu.Unlock()
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
+
+ e.waiterQueue.Notify(waiter.EventOut)
+
+ // Set up the functions that will be called when the main protocol loop
+ // wakes up.
+ funcs := []struct {
+ w *sleep.Waker
+ f func() *tcpip.Error
+ }{
+ {
+ w: &e.sndWaker,
+ f: e.handleWrite,
+ },
+ {
+ w: &e.sndCloseWaker,
+ f: e.handleClose,
+ },
+ {
+ w: &e.newSegmentWaker,
+ f: e.handleSegments,
+ },
+ {
+ w: &closeWaker,
+ f: func() *tcpip.Error {
+ return tcpip.ErrConnectionAborted
+ },
+ },
+ {
+ w: &e.snd.resendWaker,
+ f: func() *tcpip.Error {
+ if !e.snd.retransmitTimerExpired() {
+ return tcpip.ErrTimeout
+ }
+ return nil
+ },
+ },
+ {
+ w: &e.keepalive.waker,
+ f: e.keepaliveTimerExpired,
+ },
+ {
+ w: &e.notificationWaker,
+ f: func() *tcpip.Error {
+ n := e.fetchNotifications()
+ if n&notifyNonZeroReceiveWindow != 0 {
+ e.rcv.nonZeroWindow()
+ }
+
+ if n&notifyReceiveWindowChanged != 0 {
+ e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
+ }
+
+ if n&notifyMTUChanged != 0 {
+ e.sndBufMu.Lock()
+ count := e.packetTooBigCount
+ e.packetTooBigCount = 0
+ mtu := e.sndMTU
+ e.sndBufMu.Unlock()
+
+ e.snd.updateMaxPayloadSize(mtu, count)
+ }
+
+ if n&notifyReset != 0 {
+ e.mu.Lock()
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.mu.Unlock()
+ }
+ if n&notifyClose != 0 && closeTimer == nil {
+ // Reset the connection 3 seconds after
+ // the endpoint has been closed.
+ //
+ // The timer could fire in background
+ // when the endpoint is drained. That's
+ // OK as the loop here will not honor
+ // the firing until the undrain arrives.
+ closeTimer = time.AfterFunc(3*time.Second, func() {
+ closeWaker.Assert()
+ })
+ }
+
+ if n&notifyKeepaliveChanged != 0 {
+ // The timer could fire in background
+ // when the endpoint is drained. That's
+ // OK. See above.
+ e.resetKeepaliveTimer(true)
+ }
+
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ if err := e.handleSegments(); err != nil {
+ return err
+ }
+ }
+ if e.state != stateError {
+ close(e.drainDone)
+ <-e.undrain
+ }
+ }
+
+ return nil
+ },
+ },
+ }
+
+ // Initialize the sleeper based on the wakers in funcs.
+ s := sleep.Sleeper{}
+ for i := range funcs {
+ s.AddWaker(funcs[i].w, i)
+ }
+
+ // The following assertions and notifications are needed for restored
+ // endpoints. Fresh newly created endpoints have empty states and should
+ // not invoke any.
+ e.segmentQueue.mu.Lock()
+ if !e.segmentQueue.list.Empty() {
+ e.newSegmentWaker.Assert()
+ }
+ e.segmentQueue.mu.Unlock()
+
+ e.rcvListMu.Lock()
+ if !e.rcvList.Empty() {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ e.rcvListMu.Unlock()
+
+ e.mu.RLock()
+ if e.workerCleanup {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+ e.mu.RUnlock()
+
+ // Main loop. Handle segments until both send and receive ends of the
+ // connection have completed.
+ for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
+ e.workMu.Unlock()
+ v, _ := s.Fetch(true)
+ e.workMu.Lock()
+ if err := funcs[v].f(); err != nil {
+ e.mu.Lock()
+ e.resetConnectionLocked(err)
+ // Lock released below.
+ epilogue()
+
+ return nil
+ }
+ }
+
+ // Mark endpoint as closed.
+ e.mu.Lock()
+ if e.state != stateError {
+ e.state = stateClosed
+ }
+ // Lock released below.
+ epilogue()
+
+ return nil
+}
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
new file mode 100644
index 000000000..e618cd2b9
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -0,0 +1,233 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "time"
+)
+
+// cubicState stores the variables related to TCP CUBIC congestion
+// control algorithm state.
+//
+// See: https://tools.ietf.org/html/rfc8312.
+type cubicState struct {
+ // wLastMax is the previous wMax value.
+ wLastMax float64
+
+ // wMax is the value of the congestion window at the
+ // time of last congestion event.
+ wMax float64
+
+ // t denotes the time when the current congestion avoidance
+ // was entered.
+ t time.Time
+
+ // numCongestionEvents tracks the number of congestion events since last
+ // RTO.
+ numCongestionEvents int
+
+ // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
+ // per RFC.
+ c float64
+
+ // k is the time period that the above function takes to increase the
+ // current window size to W_max if there are no further congestion
+ // events and is calculated using the following equation:
+ //
+ // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
+ k float64
+
+ // beta is the CUBIC multiplication decrease factor. that is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // W_cubic(0)=W_max*beta_cubic.
+ beta float64
+
+ // wC is window computed by CUBIC at time t. It's calculated using the
+ // formula:
+ //
+ // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
+ wC float64
+
+ // wEst is the window computed by CUBIC at time t+RTT i.e
+ // W_cubic(t+RTT).
+ wEst float64
+
+ s *sender
+}
+
+// newCubicCC returns a partially initialized cubic state with the constants
+// beta and c set and t set to current time.
+func newCubicCC(s *sender) *cubicState {
+ return &cubicState{
+ t: time.Now(),
+ beta: 0.7,
+ c: 0.4,
+ s: s,
+ }
+}
+
+// enterCongestionAvoidance is used to initialize cubic in cases where we exit
+// SlowStart without a real congestion event taking place. This can happen when
+// a connection goes back to slow start due to a retransmit and we exceed the
+// previously lowered ssThresh without experiencing packet loss.
+//
+// Refer: https://tools.ietf.org/html/rfc8312#section-4.8
+func (c *cubicState) enterCongestionAvoidance() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.7 &
+ // https://tools.ietf.org/html/rfc8312#section-4.8
+ if c.numCongestionEvents == 0 {
+ c.k = 0
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+ }
+}
+
+// updateSlowStart will update the congestion window as per the slow-start
+// algorithm used by NewReno. If after adjusting the congestion window we cross
+// the ssThresh then it will return the number of packets that must be consumed
+// in congestion avoidance mode.
+func (c *cubicState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := c.s.sndCwnd + packetsAcked
+ enterCA := false
+ if newcwnd >= c.s.sndSsthresh {
+ newcwnd = c.s.sndSsthresh
+ c.s.sndCAAckCount = 0
+ enterCA = true
+ }
+
+ packetsAcked -= newcwnd - c.s.sndCwnd
+ c.s.sndCwnd = newcwnd
+ if enterCA {
+ c.enterCongestionAvoidance()
+ }
+ return packetsAcked
+}
+
+// Update updates cubic's internal state variables. It must be called on every
+// ACK received.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) Update(packetsAcked int) {
+ if c.s.sndCwnd < c.s.sndSsthresh {
+ packetsAcked = c.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ } else {
+ c.s.rtt.Lock()
+ srtt := c.s.rtt.srtt
+ c.s.rtt.Unlock()
+ c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt)
+ }
+}
+
+// cubicCwnd computes the CUBIC congestion window after t seconds from last
+// congestion event.
+func (c *cubicState) cubicCwnd(t float64) float64 {
+ return c.c*math.Pow(t, 3.0) + c.wMax
+}
+
+// getCwnd returns the current congestion window as computed by CUBIC.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
+ elapsed := time.Since(c.t).Seconds()
+
+ // Compute the window as per Cubic after 'elapsed' time
+ // since last congestion event.
+ c.wC = c.cubicCwnd(elapsed - c.k)
+
+ // Compute the TCP friendly estimate of the congestion window.
+ c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds())
+
+ // Make sure in the TCP friendly region CUBIC performs at least
+ // as well as Reno.
+ if c.wC < c.wEst && float64(sndCwnd) < c.wEst {
+ // TCP Friendly region of cubic.
+ return int(c.wEst)
+ }
+
+ // In Concave/Convex region of CUBIC, calculate what CUBIC window
+ // will be after 1 RTT and use that to grow congestion window
+ // for every ack.
+ tEst := (time.Since(c.t) + srtt).Seconds()
+ wtRtt := c.cubicCwnd(tEst - c.k)
+ // As per 4.3 for each received ACK cwnd must be incremented
+ // by (w_cubic(t+RTT) - cwnd/cwnd.
+ cwnd := float64(sndCwnd)
+ for i := 0; i < packetsAcked; i++ {
+ // Concave/Convex regions of cubic have the same formulas.
+ // See: https://tools.ietf.org/html/rfc8312#section-4.3
+ cwnd += (wtRtt - cwnd) / cwnd
+ }
+ return int(cwnd)
+}
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (c *cubicState) HandleNDupAcks() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.5
+ c.numCongestionEvents++
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+ c.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionContrl.HandleRTOExpired.
+func (c *cubicState) HandleRTOExpired() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.6
+ c.t = time.Now()
+ c.numCongestionEvents = 0
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+
+ // We lost a packet, so reduce ssthresh.
+ c.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ c.s.sndCwnd = 1
+}
+
+// fastConvergence implements the logic for Fast Convergence algorithm as
+// described in https://tools.ietf.org/html/rfc8312#section-4.6.
+func (c *cubicState) fastConvergence() {
+ if c.wMax < c.wLastMax {
+ c.wLastMax = c.wMax
+ c.wMax = c.wMax * (1.0 + c.beta) / 2.0
+ } else {
+ c.wLastMax = c.wMax
+ }
+ // Recompute k as wMax may have changed.
+ c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
+}
+
+// PostRecovery implemements congestionControl.PostRecovery.
+func (c *cubicState) PostRecovery() {
+ c.t = time.Now()
+}
+
+// reduceSlowStartThreshold returns new SsThresh as described in
+// https://tools.ietf.org/html/rfc8312#section-4.7.
+func (c *cubicState) reduceSlowStartThreshold() {
+ c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
new file mode 100644
index 000000000..fd697402e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -0,0 +1,1741 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tmutex"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateListen
+ stateConnecting
+ stateConnected
+ stateClosed
+ stateError
+)
+
+// Reasons for notifying the protocol goroutine.
+const (
+ notifyNonZeroReceiveWindow = 1 << iota
+ notifyReceiveWindowChanged
+ notifyClose
+ notifyMTUChanged
+ notifyDrain
+ notifyReset
+ notifyKeepaliveChanged
+)
+
+// SACKInfo holds TCP SACK related information for a given endpoint.
+//
+// +stateify savable
+type SACKInfo struct {
+ // Blocks is the maximum number of SACK blocks we track
+ // per endpoint.
+ Blocks [MaxSACKBlocks]header.SACKBlock
+
+ // NumBlocks is the number of valid SACK blocks stored in the
+ // blocks array above.
+ NumBlocks int
+}
+
+// endpoint represents a TCP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized. The protocol implementation, however, runs in a single
+// goroutine.
+//
+// +stateify savable
+type endpoint struct {
+ // workMu is used to arbitrate which goroutine may perform protocol
+ // work. Only the main protocol goroutine is expected to call Lock() on
+ // it, but other goroutines (e.g., send) may call TryLock() to eagerly
+ // perform work without having to wait for the main one to wake up.
+ workMu tmutex.Mutex `state:"nosave"`
+
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue `state:"wait"`
+
+ // lastError represents the last error that the endpoint reported;
+ // access to it is protected by the following mutex.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
+
+ // The following fields are used to manage the receive queue. The
+ // protocol goroutine adds ready-for-delivery segments to rcvList,
+ // which are returned by Read() calls to users.
+ //
+ // Once the peer has closed its send side, rcvClosed is set to true
+ // to indicate to users that no more data is coming.
+ //
+ // rcvListMu can be taken after the endpoint mu below.
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ rcvBufSize int
+ rcvBufUsed int
+
+ // The following fields are protected by the mutex.
+ mu sync.RWMutex `state:"nosave"`
+ id stack.TransportEndpointID
+ state endpointState `state:".(endpointState)"`
+ isPortReserved bool `state:"manual"`
+ isRegistered bool
+ boundNICID tcpip.NICID `state:"manual"`
+ route stack.Route `state:"manual"`
+ v6only bool
+ isConnectNotified bool
+ // TCP should never broadcast but Linux nevertheless supports enabling/
+ // disabling SO_BROADCAST, albeit as a NOOP.
+ broadcast bool
+
+ // effectiveNetProtos contains the network protocols actually in use. In
+ // most cases it will only contain "netProto", but in cases like IPv6
+ // endpoints with v6only set to false, this could include multiple
+ // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+ // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+ // address).
+ effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
+
+ // hardError is meaningful only when state is stateError, it stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. hardError is protected by mu.
+ hardError *tcpip.Error `state:".(string)"`
+
+ // workerRunning specifies if a worker goroutine is running.
+ workerRunning bool
+
+ // workerCleanup specifies if the worker goroutine must perform cleanup
+ // before exitting. This can only be set to true when workerRunning is
+ // also true, and they're both protected by the mutex.
+ workerCleanup bool
+
+ // sendTSOk is used to indicate when the TS Option has been negotiated.
+ // When sendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1
+ sendTSOk bool
+
+ // recentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field is
+ // updated if required when a new segment is received by this endpoint.
+ recentTS uint32
+
+ // tsOffset is a randomized offset added to the value of the
+ // TSVal field in the timestamp option.
+ tsOffset uint32
+
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+ // sackPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ sackPermitted bool
+
+ // sack holds TCP SACK related information for this endpoint.
+ sack SACKInfo
+
+ // reusePort is set to true if SO_REUSEPORT is enabled.
+ reusePort bool
+
+ // delay enables Nagle's algorithm.
+ //
+ // delay is a boolean (0 is false) and must be accessed atomically.
+ delay uint32
+
+ // cork holds back segments until full.
+ //
+ // cork is a boolean (0 is false) and must be accessed atomically.
+ cork uint32
+
+ // scoreboard holds TCP SACK Scoreboard information for this endpoint.
+ scoreboard *SACKScoreboard
+
+ // The options below aren't implemented, but we remember the user
+ // settings because applications expect to be able to set/query these
+ // options.
+ reuseAddr bool
+
+ // slowAck holds the negated state of quick ack. It is stubbed out and
+ // does nothing.
+ //
+ // slowAck is a boolean (0 is false) and must be accessed atomically.
+ slowAck uint32
+
+ // segmentQueue is used to hand received segments to the protocol
+ // goroutine. Segments are queued as long as the queue is not full,
+ // and dropped when it is.
+ segmentQueue segmentQueue `state:"wait"`
+
+ // synRcvdCount is the number of connections for this endpoint that are
+ // in SYN-RCVD state.
+ synRcvdCount int
+
+ // The following fields are used to manage the send buffer. When
+ // segments are ready to be sent, they are added to sndQueue and the
+ // protocol goroutine is signaled via sndWaker.
+ //
+ // When the send side is closed, the protocol goroutine is notified via
+ // sndCloseWaker, and sndClosed is set to true.
+ sndBufMu sync.Mutex `state:"nosave"`
+ sndBufSize int
+ sndBufUsed int
+ sndClosed bool
+ sndBufInQueue seqnum.Size
+ sndQueue segmentList `state:"wait"`
+ sndWaker sleep.Waker `state:"manual"`
+ sndCloseWaker sleep.Waker `state:"manual"`
+
+ // cc stores the name of the Congestion Control algorithm to use for
+ // this endpoint.
+ cc CongestionControlOption
+
+ // The following are used when a "packet too big" control packet is
+ // received. They are protected by sndBufMu. They are used to
+ // communicate to the main protocol goroutine how many such control
+ // messages have been received since the last notification was processed
+ // and what was the smallest MTU seen.
+ packetTooBigCount int
+ sndMTU int
+
+ // newSegmentWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and handle new segments queued to it.
+ newSegmentWaker sleep.Waker `state:"manual"`
+
+ // notificationWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and check for notifications.
+ notificationWaker sleep.Waker `state:"manual"`
+
+ // notifyFlags is a bitmask of flags used to indicate to the protocol
+ // goroutine what it was notified; this is only accessed atomically.
+ notifyFlags uint32 `state:"nosave"`
+
+ // keepalive manages TCP keepalive state. When the connection is idle
+ // (no data sent or received) for keepaliveIdle, we start sending
+ // keepalives every keepalive.interval. If we send keepalive.count
+ // without hearing a response, the connection is closed.
+ keepalive keepalive
+
+ // acceptedChan is used by a listening endpoint protocol goroutine to
+ // send newly accepted connections to the endpoint so that they can be
+ // read by Accept() calls.
+ acceptedChan chan *endpoint `state:".([]*endpoint)"`
+
+ // The following are only used from the protocol goroutine, and
+ // therefore don't need locks to protect them.
+ rcv *receiver `state:"wait"`
+ snd *sender `state:"wait"`
+
+ // The goroutine drain completion notification channel.
+ drainDone chan struct{} `state:"nosave"`
+
+ // The goroutine undrain notification channel.
+ undrain chan struct{} `state:"nosave"`
+
+ // probe if not nil is invoked on every received segment. It is passed
+ // a copy of the current state of the endpoint.
+ probe stack.TCPProbeFunc `state:"nosave"`
+
+ // The following are only used to assist the restore run to re-connect.
+ bindAddress tcpip.Address
+ connectingAddress tcpip.Address
+
+ gso *stack.GSO
+}
+
+// StopWork halts packet processing. Only to be used in tests.
+func (e *endpoint) StopWork() {
+ e.workMu.Lock()
+}
+
+// ResumeWork resumes packet processing. Only to be used in tests.
+func (e *endpoint) ResumeWork() {
+ e.workMu.Unlock()
+}
+
+// keepalive is a synchronization wrapper used to appease stateify. See the
+// comment in endpoint, where it is used.
+//
+// +stateify savable
+type keepalive struct {
+ sync.Mutex `state:"nosave"`
+ enabled bool
+ idle time.Duration
+ interval time.Duration
+ count int
+ unacked int
+ timer timer `state:"nosave"`
+ waker sleep.Waker `state:"nosave"`
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ e := &endpoint{
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ rcvBufSize: DefaultBufferSize,
+ sndBufSize: DefaultBufferSize,
+ sndMTU: int(math.MaxInt32),
+ reuseAddr: true,
+ keepalive: keepalive{
+ // Linux defaults.
+ idle: 2 * time.Hour,
+ interval: 75 * time.Second,
+ count: 9,
+ },
+ }
+
+ var ss SendBufferSizeOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ e.sndBufSize = ss.Default
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ e.rcvBufSize = rs.Default
+ }
+
+ var cs CongestionControlOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
+ e.cc = cs
+ }
+
+ if p := stack.GetTCPProbe(); p != nil {
+ e.probe = p
+ }
+
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.workMu.Init()
+ e.workMu.Lock()
+ e.tsOffset = timeStampOffset()
+ return e
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ result := waiter.EventMask(0)
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ switch e.state {
+ case stateInitial, stateBound, stateConnecting:
+ // Ready for nothing.
+
+ case stateClosed, stateError:
+ // Ready for anything.
+ result = mask
+
+ case stateListen:
+ // Check if there's anything in the accepted channel.
+ if (mask & waiter.EventIn) != 0 {
+ if len(e.acceptedChan) > 0 {
+ result |= waiter.EventIn
+ }
+ }
+
+ case stateConnected:
+ // Determine if the endpoint is writable if requested.
+ if (mask & waiter.EventOut) != 0 {
+ e.sndBufMu.Lock()
+ if e.sndClosed || e.sndBufUsed < e.sndBufSize {
+ result |= waiter.EventOut
+ }
+ e.sndBufMu.Unlock()
+ }
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvListMu.Lock()
+ if e.rcvBufUsed > 0 || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvListMu.Unlock()
+ }
+ }
+
+ return result
+}
+
+func (e *endpoint) fetchNotifications() uint32 {
+ return atomic.SwapUint32(&e.notifyFlags, 0)
+}
+
+func (e *endpoint) notifyProtocolGoroutine(n uint32) {
+ for {
+ v := atomic.LoadUint32(&e.notifyFlags)
+ if v&n == n {
+ // The flags are already set.
+ return
+ }
+
+ if atomic.CompareAndSwapUint32(&e.notifyFlags, v, v|n) {
+ if v == 0 {
+ // We are causing a transition from no flags to
+ // at least one flag set, so we must cause the
+ // protocol goroutine to wake up.
+ e.notificationWaker.Assert()
+ }
+ return
+ }
+ }
+}
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it. It must be called only once and with no other concurrent calls to
+// the endpoint.
+func (e *endpoint) Close() {
+ // Issue a shutdown so that the peer knows we won't send any more data
+ // if we're connected, or stop accepting if we're listening.
+ e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+
+ e.mu.Lock()
+
+ // For listening sockets, we always release ports inline so that they
+ // are immediately available for reuse after Close() is called. If also
+ // registered, we unregister as well otherwise the next user would fail
+ // in Listen() when trying to register.
+ if e.state == stateListen && e.isPortReserved {
+ if e.isRegistered {
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.isRegistered = false
+ }
+
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.isPortReserved = false
+ }
+
+ // Either perform the local cleanup or kick the worker to make sure it
+ // knows it needs to cleanup.
+ tcpip.AddDanglingEndpoint(e)
+ if !e.workerRunning {
+ e.cleanupLocked()
+ } else {
+ e.workerCleanup = true
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+
+ e.mu.Unlock()
+}
+
+// cleanupLocked frees all resources associated with the endpoint. It is called
+// after Close() is called and the worker goroutine (if any) is done with its
+// work.
+func (e *endpoint) cleanupLocked() {
+ // Close all endpoints that might have been accepted by TCP but not by
+ // the client.
+ if e.acceptedChan != nil {
+ close(e.acceptedChan)
+ for n := range e.acceptedChan {
+ n.mu.Lock()
+ n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.mu.Unlock()
+ n.Close()
+ }
+ e.acceptedChan = nil
+ }
+ e.workerCleanup = false
+
+ if e.isRegistered {
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.isRegistered = false
+ }
+
+ if e.isPortReserved {
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.isPortReserved = false
+ }
+
+ e.route.Release()
+ tcpip.DeleteDanglingEndpoint(e)
+}
+
+// Read reads data from the endpoint.
+func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.mu.RLock()
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data. Also note that a RST being received
+ // would cause the state to become stateError so we should allow the
+ // reads to proceed before returning a ECONNRESET.
+ e.rcvListMu.Lock()
+ bufUsed := e.rcvBufUsed
+ if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 {
+ e.rcvListMu.Unlock()
+ he := e.hardError
+ e.mu.RUnlock()
+ if s == stateError {
+ return buffer.View{}, tcpip.ControlMessages{}, he
+ }
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ }
+
+ v, err := e.readLocked()
+ e.rcvListMu.Unlock()
+
+ e.mu.RUnlock()
+
+ return v, tcpip.ControlMessages{}, err
+}
+
+func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
+ if e.rcvBufUsed == 0 {
+ if e.rcvClosed || e.state != stateConnected {
+ return buffer.View{}, tcpip.ErrClosedForReceive
+ }
+ return buffer.View{}, tcpip.ErrWouldBlock
+ }
+
+ s := e.rcvList.Front()
+ views := s.data.Views()
+ v := views[s.viewToDeliver]
+ s.viewToDeliver++
+
+ if s.viewToDeliver >= len(views) {
+ e.rcvList.Remove(s)
+ s.decRef()
+ }
+
+ scale := e.rcv.rcvWndScale
+ wasZero := e.zeroReceiveWindow(scale)
+ e.rcvBufUsed -= len(v)
+ if wasZero && !e.zeroReceiveWindow(scale) {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
+
+ return v, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // The endpoint cannot be written to if it's not connected.
+ if e.state != stateConnected {
+ switch e.state {
+ case stateError:
+ return 0, nil, e.hardError
+ default:
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+ }
+
+ // Nothing to do if the buffer is empty.
+ if p.Size() == 0 {
+ return 0, nil, nil
+ }
+
+ e.sndBufMu.Lock()
+
+ // Check if the connection has already been closed for sends.
+ if e.sndClosed {
+ e.sndBufMu.Unlock()
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Check against the limit.
+ avail := e.sndBufSize - e.sndBufUsed
+ if avail <= 0 {
+ e.sndBufMu.Unlock()
+ return 0, nil, tcpip.ErrWouldBlock
+ }
+
+ v, perr := p.Get(avail)
+ if perr != nil {
+ e.sndBufMu.Unlock()
+ return 0, nil, perr
+ }
+
+ l := len(v)
+ s := newSegmentFromView(&e.route, e.id, v)
+
+ // Add data to the send queue.
+ e.sndBufUsed += l
+ e.sndBufInQueue += seqnum.Size(l)
+ e.sndQueue.PushBack(s)
+
+ e.sndBufMu.Unlock()
+
+ if e.workMu.TryLock() {
+ // Do the work inline.
+ e.handleWrite()
+ e.workMu.Unlock()
+ } else {
+ // Let the protocol goroutine do the work.
+ e.sndWaker.Assert()
+ }
+ return uintptr(l), nil, nil
+}
+
+// Peek reads data without consuming it from the endpoint.
+//
+// This method does not block if there is no data pending.
+func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data.
+ if s := e.state; s != stateConnected && s != stateClosed {
+ if s == stateError {
+ return 0, tcpip.ControlMessages{}, e.hardError
+ }
+ return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ }
+
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ if e.rcvBufUsed == 0 {
+ if e.rcvClosed || e.state != stateConnected {
+ return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ }
+ return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
+ // Make a copy of vec so we can modify the slide headers.
+ vec = append([][]byte(nil), vec...)
+
+ var num uintptr
+
+ for s := e.rcvList.Front(); s != nil; s = s.Next() {
+ views := s.data.Views()
+
+ for i := s.viewToDeliver; i < len(views); i++ {
+ v := views[i]
+
+ for len(v) > 0 {
+ if len(vec) == 0 {
+ return num, tcpip.ControlMessages{}, nil
+ }
+ if len(vec[0]) == 0 {
+ vec = vec[1:]
+ continue
+ }
+
+ n := copy(vec[0], v)
+ v = v[n:]
+ vec[0] = vec[0][n:]
+ num += uintptr(n)
+ }
+ }
+ }
+
+ return num, tcpip.ControlMessages{}, nil
+}
+
+// zeroReceiveWindow checks if the receive window to be announced now would be
+// zero, based on the amount of available buffer and the receive window scaling.
+//
+// It must be called with rcvListMu held.
+func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
+ if e.rcvBufUsed >= e.rcvBufSize {
+ return true
+ }
+
+ return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs ReceiveBufferSizeOption
+ size := int(v)
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if size < rs.Min {
+ size = rs.Min
+ }
+ if size > rs.Max {
+ size = rs.Max
+ }
+ }
+
+ mask := uint32(notifyReceiveWindowChanged)
+
+ e.rcvListMu.Lock()
+
+ // Make sure the receive buffer size allows us to send a
+ // non-zero window size.
+ scale := uint8(0)
+ if e.rcv != nil {
+ scale = e.rcv.rcvWndScale
+ }
+ if size>>scale == 0 {
+ size = 1 << scale
+ }
+
+ // Make sure 2*size doesn't overflow.
+ if size > math.MaxInt32/2 {
+ size = math.MaxInt32 / 2
+ }
+
+ wasZero := e.zeroReceiveWindow(scale)
+ e.rcvBufSize = size
+ if wasZero && !e.zeroReceiveWindow(scale) {
+ mask |= notifyNonZeroReceiveWindow
+ }
+ e.rcvListMu.Unlock()
+
+ e.notifyProtocolGoroutine(mask)
+ return nil
+
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ size := int(v)
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if size < ss.Min {
+ size = ss.Min
+ }
+ if size > ss.Max {
+ size = ss.Max
+ }
+ }
+
+ e.sndBufMu.Lock()
+ e.sndBufSize = size
+ e.sndBufMu.Unlock()
+ return nil
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v != 0
+ return nil
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ e.keepalive.enabled = v != 0
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ e.keepalive.idle = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ e.keepalive.interval = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ e.keepalive.count = int(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v != 0
+ e.mu.Unlock()
+ return nil
+
+ default:
+ return nil
+ }
+}
+
+// readyReceiveSize returns the number of bytes ready to be received.
+func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // The endpoint cannot be in listen state.
+ if e.state == stateListen {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ return e.rcvBufUsed, nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ e.lastErrorMu.Lock()
+ err := e.lastError
+ e.lastError = nil
+ e.lastErrorMu.Unlock()
+ return err
+
+ case *tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.sndBufMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
+ e.rcvListMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ v, err := e.readyReceiveSize()
+ if err != nil {
+ return err
+ }
+
+ *o = tcpip.ReceiveQueueSizeOption(v)
+ return nil
+
+ case *tcpip.DelayOption:
+ *o = 0
+ if v := atomic.LoadUint32(&e.delay); v != 0 {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.CorkOption:
+ *o = 0
+ if v := atomic.LoadUint32(&e.cork); v != 0 {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReuseAddressOption:
+ e.mu.RLock()
+ v := e.reuseAddr
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.QuickAckOption:
+ *o = 1
+ if v := atomic.LoadUint32(&e.slowAck); v != 0 {
+ *o = 0
+ }
+ return nil
+
+ case *tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.TCPInfoOption:
+ *o = tcpip.TCPInfoOption{}
+ e.mu.RLock()
+ snd := e.snd
+ e.mu.RUnlock()
+ if snd != nil {
+ snd.rtt.Lock()
+ o.RTT = snd.rtt.srtt
+ o.RTTVar = snd.rtt.rttvar
+ snd.rtt.Unlock()
+ }
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ v := e.keepalive.enabled
+ e.keepalive.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
+ e.keepalive.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
+ e.keepalive.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveCountOption(e.keepalive.count)
+ e.keepalive.Unlock()
+ return nil
+
+ case *tcpip.OutOfBandInlineOption:
+ // We don't currently support disabling this option.
+ *o = 1
+ return nil
+
+ case *tcpip.BroadcastOption:
+ e.mu.Lock()
+ v := e.broadcast
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ // Fail if using a v4 mapped address on a v6only endpoint.
+ if e.v6only {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == "\x00\x00\x00\x00" {
+ addr.Addr = ""
+ }
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
+
+// Connect connects the endpoint to its peer.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return e.connect(addr, true, true)
+}
+
+// connect connects the endpoint to its peer. In the normal non-S/R case, the
+// new connection is expected to run the main goroutine and perform handshake.
+// In restore of previously connected endpoints, both ends will be passively
+// created (so no new handshaking is done); for stack-accepted connections not
+// yet accepted by the app, they are restored without running the main goroutine
+// here.
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ defer func() {
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ }
+ }()
+
+ connectingAddr := addr.Addr
+
+ netProto, err := e.checkV4Mapped(&addr)
+ if err != nil {
+ return err
+ }
+
+ nicid := addr.NIC
+ switch e.state {
+ case stateBound:
+ // If we're already bound to a NIC but the caller is requesting
+ // that we use a different one now, we cannot proceed.
+ if e.boundNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.boundNICID {
+ return tcpip.ErrNoRoute
+ }
+
+ nicid = e.boundNICID
+
+ case stateInitial:
+ // Nothing to do. We'll eventually fill-in the gaps in the ID
+ // (if any) when we find a route.
+
+ case stateConnecting:
+ // A connection request has already been issued but hasn't
+ // completed yet.
+ return tcpip.ErrAlreadyConnecting
+
+ case stateConnected:
+ // The endpoint is already connected. If caller hasn't been notified yet, return success.
+ if !e.isConnectNotified {
+ e.isConnectNotified = true
+ return nil
+ }
+ // Otherwise return that it's already connected.
+ return tcpip.ErrAlreadyConnected
+
+ case stateError:
+ return e.hardError
+
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ origID := e.id
+
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ e.id.LocalAddress = r.LocalAddress
+ e.id.RemoteAddress = r.RemoteAddress
+ e.id.RemotePort = addr.Port
+
+ if e.id.LocalPort != 0 {
+ // The endpoint is bound to a port, attempt to register it.
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ if err != nil {
+ return err
+ }
+ } else {
+ // The endpoint doesn't have a local port yet, so try to get
+ // one. Make sure that it isn't one that will result in the same
+ // address/port for both local and remote (otherwise this
+ // endpoint would be trying to connect to itself).
+ sameAddr := e.id.LocalAddress == e.id.RemoteAddress
+ if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ if sameAddr && p == e.id.RemotePort {
+ return false, nil
+ }
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ return false, nil
+ }
+
+ id := e.id
+ id.LocalPort = p
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ case nil:
+ e.id = id
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ }); err != nil {
+ return err
+ }
+ }
+
+ // Remove the port reservation. This can happen when Bind is called
+ // before Connect: in such a case we don't want to hold on to
+ // reservations anymore.
+ if e.isPortReserved {
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.isPortReserved = false
+ }
+
+ e.isRegistered = true
+ e.state = stateConnecting
+ e.route = r.Clone()
+ e.boundNICID = nicid
+ e.effectiveNetProtos = netProtos
+ e.connectingAddress = connectingAddr
+
+ e.initGSO()
+
+ // Connect in the restore phase does not perform handshake. Restore its
+ // connection setting here.
+ if !handshake {
+ e.segmentQueue.mu.Lock()
+ for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
+ for s := l.Front(); s != nil; s = s.Next() {
+ s.id = e.id
+ s.route = r.Clone()
+ e.sndWaker.Assert()
+ }
+ }
+ e.segmentQueue.mu.Unlock()
+ e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
+ e.state = stateConnected
+ }
+
+ if run {
+ e.workerRunning = true
+ e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+ go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save.
+ }
+
+ return tcpip.ErrConnectStarted
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.shutdownFlags |= flags
+
+ switch e.state {
+ case stateConnected:
+ // Close for read.
+ if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ // Mark read side as closed.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ rcvBufUsed := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // If we're fully closed and we have unread data we need to abort
+ // the connection with a RST.
+ if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
+ e.notifyProtocolGoroutine(notifyReset)
+ return nil
+ }
+ }
+
+ // Close for write.
+ if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+ e.sndBufMu.Lock()
+
+ if e.sndClosed {
+ // Already closed.
+ e.sndBufMu.Unlock()
+ break
+ }
+
+ // Queue fin segment.
+ s := newSegmentFromView(&e.route, e.id, nil)
+ e.sndQueue.PushBack(s)
+ e.sndBufInQueue++
+
+ // Mark endpoint as closed.
+ e.sndClosed = true
+
+ e.sndBufMu.Unlock()
+
+ // Tell protocol goroutine to close.
+ e.sndCloseWaker.Assert()
+ }
+
+ case stateListen:
+ // Tell protocolListenLoop to stop.
+ if flags&tcpip.ShutdownRead != 0 {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+
+ default:
+ return tcpip.ErrNotConnected
+ }
+
+ return nil
+}
+
+// Listen puts the endpoint in "listen" mode, which allows it to accept
+// new connections.
+func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ defer func() {
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ }
+ }()
+
+ // Allow the backlog to be adjusted if the endpoint is not shutting down.
+ // When the endpoint shuts down, it sets workerCleanup to true, and from
+ // that point onward, acceptedChan is the responsibility of the cleanup()
+ // method (and should not be touched anywhere else, including here).
+ if e.state == stateListen && !e.workerCleanup {
+ // Adjust the size of the channel iff we can fix existing
+ // pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return tcpip.ErrInvalidEndpointState
+ }
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *endpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
+ return nil
+ }
+
+ // Endpoint must be bound before it can transition to listen mode.
+ if e.state != stateBound {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Register the endpoint.
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ return err
+ }
+
+ e.isRegistered = true
+ e.state = stateListen
+ if e.acceptedChan == nil {
+ e.acceptedChan = make(chan *endpoint, backlog)
+ }
+ e.workerRunning = true
+
+ go e.protocolListenLoop( // S/R-SAFE: drained on save.
+ seqnum.Size(e.receiveBufferAvailable()))
+
+ return nil
+}
+
+// startAcceptedLoop sets up required state and starts a goroutine with the
+// main loop for accepted connections.
+func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+ e.waiterQueue = waiterQueue
+ e.workerRunning = true
+ go e.protocolMainLoop(false) // S/R-SAFE: drained on save.
+}
+
+// Accept returns a new endpoint if a peer has established a connection
+// to an endpoint previously set to listen mode.
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // Endpoint must be in listen state before it can accept connections.
+ if e.state != stateListen {
+ return nil, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Get the new accepted endpoint.
+ var n *endpoint
+ select {
+ case n = <-e.acceptedChan:
+ default:
+ return nil, nil, tcpip.ErrWouldBlock
+ }
+
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+ return n, wq, nil
+}
+
+// Bind binds the endpoint to a specific local port and optionally address.
+func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore. This is because once the endpoint goes into a connected or
+ // listen state, it is already bound.
+ if e.state != stateInitial {
+ return tcpip.ErrAlreadyBound
+ }
+
+ e.bindAddress = addr.Addr
+ netProto, err := e.checkV4Mapped(&addr)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
+ }
+
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ if err != nil {
+ return err
+ }
+
+ e.isPortReserved = true
+ e.effectiveNetProtos = netProtos
+ e.id.LocalPort = port
+
+ // Any failures beyond this point must remove the port registration.
+ defer func() {
+ if err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.isPortReserved = false
+ e.effectiveNetProtos = nil
+ e.id.LocalPort = 0
+ e.id.LocalAddress = ""
+ e.boundNICID = 0
+ }
+ }()
+
+ // If an address is specified, we must ensure that it's one of our
+ // local addresses.
+ if len(addr.Addr) != 0 {
+ nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ e.boundNICID = nic
+ e.id.LocalAddress = addr.Addr
+ }
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ s := newSegment(r, id, vv)
+ if !s.parse() {
+ e.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ e.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stack.Stats().TCP.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ e.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ // Send packet to worker goroutine.
+ if e.segmentQueue.enqueue(s) {
+ e.newSegmentWaker.Assert()
+ } else {
+ // The queue is full, so we drop the segment.
+ e.stack.Stats().DroppedPackets.Increment()
+ s.decRef()
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ switch typ {
+ case stack.ControlPacketTooBig:
+ e.sndBufMu.Lock()
+ e.packetTooBigCount++
+ if v := int(extra); v < e.sndMTU {
+ e.sndMTU = v
+ }
+ e.sndBufMu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyMTUChanged)
+ }
+}
+
+// updateSndBufferUsage is called by the protocol goroutine when room opens up
+// in the send buffer. The number of newly available bytes is v.
+func (e *endpoint) updateSndBufferUsage(v int) {
+ e.sndBufMu.Lock()
+ notify := e.sndBufUsed >= e.sndBufSize>>1
+ e.sndBufUsed -= v
+ // We only notify when there is half the sndBufSize available after
+ // a full buffer event occurs. This ensures that we don't wake up
+ // writers to queue just 1-2 segments and go back to sleep.
+ notify = notify && e.sndBufUsed < e.sndBufSize>>1
+ e.sndBufMu.Unlock()
+
+ if notify {
+ e.waiterQueue.Notify(waiter.EventOut)
+ }
+}
+
+// readyToRead is called by the protocol goroutine when a new segment is ready
+// to be read, or when the connection is closed for receiving (in which case
+// s will be nil).
+func (e *endpoint) readyToRead(s *segment) {
+ e.rcvListMu.Lock()
+ if s != nil {
+ s.incRef()
+ e.rcvBufUsed += s.data.Size()
+ e.rcvList.PushBack(s)
+ } else {
+ e.rcvClosed = true
+ }
+ e.rcvListMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventIn)
+}
+
+// receiveBufferAvailable calculates how many bytes are still available in the
+// receive buffer.
+func (e *endpoint) receiveBufferAvailable() int {
+ e.rcvListMu.Lock()
+ size := e.rcvBufSize
+ used := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // We may use more bytes than the buffer size when the receive buffer
+ // shrinks.
+ if used >= size {
+ return 0
+ }
+
+ return size - used
+}
+
+func (e *endpoint) receiveBufferSize() int {
+ e.rcvListMu.Lock()
+ size := e.rcvBufSize
+ e.rcvListMu.Unlock()
+
+ return size
+}
+
+// updateRecentTimestamp updates the recent timestamp using the algorithm
+// described in https://tools.ietf.org/html/rfc7323#section-4.3
+func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
+ if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ e.recentTS = tsVal
+ }
+}
+
+// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
+// the SYN options indicate that timestamp option was negotiated. It also
+// initializes the recentTS with the value provided in synOpts.TSval.
+func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+ if synOpts.TS {
+ e.sendTSOk = true
+ e.recentTS = synOpts.TSVal
+ }
+}
+
+// timestamp returns the timestamp value to be used in the TSVal field of the
+// timestamp option for outgoing TCP segments for a given endpoint.
+func (e *endpoint) timestamp() uint32 {
+ return tcpTimeStamp(e.tsOffset)
+}
+
+// tcpTimeStamp returns a timestamp offset by the provided offset. This is
+// not inlined above as it's used when SYN cookies are in use and endpoint
+// is not created at the time when the SYN cookie is sent.
+func tcpTimeStamp(offset uint32) uint32 {
+ now := time.Now()
+ return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+}
+
+// timeStampOffset returns a randomized timestamp offset to be used when sending
+// timestamp values in a timestamp option for a TCP segment.
+func timeStampOffset() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ // Initialize a random tsOffset that will be added to the recentTS
+ // everytime the timestamp is sent when the Timestamp option is enabled.
+ //
+ // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+ // why this is required.
+ //
+ // NOTE: This is not completely to spec as normally this should be
+ // initialized in a manner analogous to how sequence numbers are
+ // randomized per connection basis. But for now this is sufficient.
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
+// if the SYN options indicate that the SACK option was negotiated and the TCP
+// stack is configured to enable TCP SACK option.
+func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+ var v SACKEnabled
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ // Stack doesn't support SACK. So just return.
+ return
+ }
+ if bool(v) && synOpts.SACKPermitted {
+ e.sackPermitted = true
+ }
+}
+
+// maxOptionSize return the maximum size of TCP options.
+func (e *endpoint) maxOptionSize() (size int) {
+ var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock
+ options := e.makeOptions(maxSackBlocks[:])
+ size = len(options)
+ putOptions(options)
+
+ return size
+}
+
+// completeState makes a full copy of the endpoint and returns it. This is used
+// before invoking the probe. The state returned may not be fully consistent if
+// there are intervening syscalls when the state is being copied.
+func (e *endpoint) completeState() stack.TCPEndpointState {
+ var s stack.TCPEndpointState
+ s.SegTime = time.Now()
+
+ // Copy EndpointID.
+ e.mu.Lock()
+ s.ID = stack.TCPEndpointID(e.id)
+ e.mu.Unlock()
+
+ // Copy endpoint rcv state.
+ e.rcvListMu.Lock()
+ s.RcvBufSize = e.rcvBufSize
+ s.RcvBufUsed = e.rcvBufUsed
+ s.RcvClosed = e.rcvClosed
+ e.rcvListMu.Unlock()
+
+ // Endpoint TCP Option state.
+ s.SendTSOk = e.sendTSOk
+ s.RecentTS = e.recentTS
+ s.TSOffset = e.tsOffset
+ s.SACKPermitted = e.sackPermitted
+ s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
+ copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])
+ s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy()
+
+ // Copy endpoint send state.
+ e.sndBufMu.Lock()
+ s.SndBufSize = e.sndBufSize
+ s.SndBufUsed = e.sndBufUsed
+ s.SndClosed = e.sndClosed
+ s.SndBufInQueue = e.sndBufInQueue
+ s.PacketTooBigCount = e.packetTooBigCount
+ s.SndMTU = e.sndMTU
+ e.sndBufMu.Unlock()
+
+ // Copy receiver state.
+ s.Receiver = stack.TCPReceiverState{
+ RcvNxt: e.rcv.rcvNxt,
+ RcvAcc: e.rcv.rcvAcc,
+ RcvWndScale: e.rcv.rcvWndScale,
+ PendingBufUsed: e.rcv.pendingBufUsed,
+ PendingBufSize: e.rcv.pendingBufSize,
+ }
+
+ // Copy sender state.
+ s.Sender = stack.TCPSenderState{
+ LastSendTime: e.snd.lastSendTime,
+ DupAckCount: e.snd.dupAckCount,
+ FastRecovery: stack.TCPFastRecoveryState{
+ Active: e.snd.fr.active,
+ First: e.snd.fr.first,
+ Last: e.snd.fr.last,
+ MaxCwnd: e.snd.fr.maxCwnd,
+ HighRxt: e.snd.fr.highRxt,
+ RescueRxt: e.snd.fr.rescueRxt,
+ },
+ SndCwnd: e.snd.sndCwnd,
+ Ssthresh: e.snd.sndSsthresh,
+ SndCAAckCount: e.snd.sndCAAckCount,
+ Outstanding: e.snd.outstanding,
+ SndWnd: e.snd.sndWnd,
+ SndUna: e.snd.sndUna,
+ SndNxt: e.snd.sndNxt,
+ RTTMeasureSeqNum: e.snd.rttMeasureSeqNum,
+ RTTMeasureTime: e.snd.rttMeasureTime,
+ Closed: e.snd.closed,
+ RTO: e.snd.rto,
+ SRTTInited: e.snd.srttInited,
+ MaxPayloadSize: e.snd.maxPayloadSize,
+ SndWndScale: e.snd.sndWndScale,
+ MaxSentAck: e.snd.maxSentAck,
+ }
+ e.snd.rtt.Lock()
+ s.Sender.SRTT = e.snd.rtt.srtt
+ e.snd.rtt.Unlock()
+
+ if cubic, ok := e.snd.cc.(*cubicState); ok {
+ s.Sender.Cubic = stack.TCPCubicState{
+ WMax: cubic.wMax,
+ WLastMax: cubic.wLastMax,
+ T: cubic.t,
+ TimeSinceLastCongestion: time.Since(cubic.t),
+ C: cubic.c,
+ K: cubic.k,
+ Beta: cubic.beta,
+ WC: cubic.wC,
+ WEst: cubic.wEst,
+ }
+ }
+ return s
+}
+
+func (e *endpoint) initGSO() {
+ if e.route.Capabilities()&stack.CapabilityGSO == 0 {
+ return
+ }
+
+ gso := &stack.GSO{}
+ switch e.route.NetProto {
+ case header.IPv4ProtocolNumber:
+ gso.Type = stack.GSOTCPv4
+ gso.L3HdrLen = header.IPv4MinimumSize
+ case header.IPv6ProtocolNumber:
+ gso.Type = stack.GSOTCPv6
+ gso.L3HdrLen = header.IPv6MinimumSize
+ default:
+ panic(fmt.Sprintf("Unknown netProto: %v", e.netProto))
+ }
+ gso.NeedsCsum = true
+ gso.CsumOffset = header.TCPChecksumOffset
+ gso.MaxSize = e.route.GSOMaxSize()
+ e.gso = gso
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
new file mode 100644
index 000000000..e8aed2875
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -0,0 +1,362 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+func (e *endpoint) drainSegmentLocked() {
+ // Drain only up to once.
+ if e.drainDone != nil {
+ return
+ }
+
+ e.drainDone = make(chan struct{})
+ e.undrain = make(chan struct{})
+ e.mu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyDrain)
+ <-e.drainDone
+
+ e.mu.Lock()
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets.
+ e.segmentQueue.setLimit(0)
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch e.state {
+ case stateInitial, stateBound:
+ case stateConnected:
+ if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
+ if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
+ panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
+ }
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.mu.Unlock()
+ e.Close()
+ e.mu.Lock()
+ }
+ if !e.workerRunning {
+ // The endpoint must be in acceptedChan or has been just
+ // disconnected and closed.
+ break
+ }
+ fallthrough
+ case stateListen, stateConnecting:
+ e.drainSegmentLocked()
+ if e.state != stateClosed && e.state != stateError {
+ if !e.workerRunning {
+ panic("endpoint has no worker running in listen, connecting, or connected state")
+ }
+ break
+ }
+ fallthrough
+ case stateError, stateClosed:
+ for e.state == stateError && e.workerRunning {
+ e.mu.Unlock()
+ time.Sleep(100 * time.Millisecond)
+ e.mu.Lock()
+ }
+ if e.workerRunning {
+ panic("endpoint still has worker running in closed or error state")
+ }
+ default:
+ panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
+ }
+
+ if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
+ panic("endpoint still has waiters upon save")
+ }
+
+ if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) {
+ panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
+ }
+}
+
+// saveAcceptedChan is invoked by stateify.
+func (e *endpoint) saveAcceptedChan() []*endpoint {
+ if e.acceptedChan == nil {
+ return nil
+ }
+ acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan))
+ for i := 0; i < len(acceptedEndpoints); i++ {
+ select {
+ case ep := <-e.acceptedChan:
+ acceptedEndpoints[i] = ep
+ default:
+ panic("endpoint acceptedChan buffer got consumed by background context")
+ }
+ }
+ for i := 0; i < len(acceptedEndpoints); i++ {
+ select {
+ case e.acceptedChan <- acceptedEndpoints[i]:
+ default:
+ panic("endpoint acceptedChan buffer got populated by background context")
+ }
+ }
+ return acceptedEndpoints
+}
+
+// loadAcceptedChan is invoked by stateify.
+func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
+ if cap(acceptedEndpoints) > 0 {
+ e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints))
+ for _, ep := range acceptedEndpoints {
+ e.acceptedChan <- ep
+ }
+ }
+}
+
+// saveState is invoked by stateify.
+func (e *endpoint) saveState() endpointState {
+ return e.state
+}
+
+// Endpoint loading must be done in the following ordering by their state, to
+// avoid dangling connecting w/o listening peer, and to avoid conflicts in port
+// reservation.
+var connectedLoading sync.WaitGroup
+var listenLoading sync.WaitGroup
+var connectingLoading sync.WaitGroup
+
+// Bound endpoint loading happens last.
+
+// loadState is invoked by stateify.
+func (e *endpoint) loadState(state endpointState) {
+ // This is to ensure that the loading wait groups include all applicable
+ // endpoints before any asynchronous calls to the Wait() methods.
+ switch state {
+ case stateConnected:
+ connectedLoading.Add(1)
+ case stateListen:
+ listenLoading.Add(1)
+ case stateConnecting:
+ connectingLoading.Add(1)
+ }
+ e.state = state
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.workMu.Init()
+
+ state := e.state
+ switch state {
+ case stateInitial, stateBound, stateListen, stateConnecting, stateConnected:
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+ }
+ if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+ }
+ }
+
+ bind := func() {
+ e.state = stateInitial
+ if len(e.bindAddress) == 0 {
+ e.bindAddress = e.id.LocalAddress
+ }
+ if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
+ panic("endpoint binding failed: " + err.String())
+ }
+ }
+
+ switch state {
+ case stateConnected:
+ bind()
+ if len(e.connectingAddress) == 0 {
+ // This endpoint is accepted by netstack but not yet by
+ // the app. If the endpoint is IPv6 but the remote
+ // address is IPv4, we need to connect as IPv6 so that
+ // dual-stack mode can be properly activated.
+ if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
+ } else {
+ e.connectingAddress = e.id.RemoteAddress
+ }
+ }
+ // Reset the scoreboard to reinitialize the sack information as
+ // we do not restore SACK information.
+ e.scoreboard.Reset()
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectedLoading.Done()
+ case stateListen:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ bind()
+ backlog := cap(e.acceptedChan)
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ listenLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateConnecting:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ bind()
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectingLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateBound:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateClosed:
+ if e.isPortReserved {
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ e.state = stateClosed
+ tcpip.AsyncLoading.Done()
+ }()
+ }
+ fallthrough
+ case stateError:
+ tcpip.DeleteDanglingEndpoint(e)
+ }
+}
+
+// saveLastError is invoked by stateify.
+func (e *endpoint) saveLastError() string {
+ if e.lastError == nil {
+ return ""
+ }
+
+ return e.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (e *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.lastError = loadError(s)
+}
+
+// saveHardError is invoked by stateify.
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
+ return ""
+ }
+
+ return e.hardError.String()
+}
+
+// loadHardError is invoked by stateify.
+func (e *endpoint) loadHardError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.hardError = loadError(s)
+}
+
+var messageToError map[string]*tcpip.Error
+
+var populate sync.Once
+
+func loadError(s string) *tcpip.Error {
+ populate.Do(func() {
+ var errors = []*tcpip.Error{
+ tcpip.ErrUnknownProtocol,
+ tcpip.ErrUnknownNICID,
+ tcpip.ErrUnknownDevice,
+ tcpip.ErrUnknownProtocolOption,
+ tcpip.ErrDuplicateNICID,
+ tcpip.ErrDuplicateAddress,
+ tcpip.ErrNoRoute,
+ tcpip.ErrBadLinkEndpoint,
+ tcpip.ErrAlreadyBound,
+ tcpip.ErrInvalidEndpointState,
+ tcpip.ErrAlreadyConnecting,
+ tcpip.ErrAlreadyConnected,
+ tcpip.ErrNoPortAvailable,
+ tcpip.ErrPortInUse,
+ tcpip.ErrBadLocalAddress,
+ tcpip.ErrClosedForSend,
+ tcpip.ErrClosedForReceive,
+ tcpip.ErrWouldBlock,
+ tcpip.ErrConnectionRefused,
+ tcpip.ErrTimeout,
+ tcpip.ErrAborted,
+ tcpip.ErrConnectStarted,
+ tcpip.ErrDestinationRequired,
+ tcpip.ErrNotSupported,
+ tcpip.ErrQueueSizeNotSupported,
+ tcpip.ErrNotConnected,
+ tcpip.ErrConnectionReset,
+ tcpip.ErrConnectionAborted,
+ tcpip.ErrNoSuchFile,
+ tcpip.ErrInvalidOptionValue,
+ tcpip.ErrNoLinkAddress,
+ tcpip.ErrBadAddress,
+ tcpip.ErrNetworkUnreachable,
+ tcpip.ErrMessageTooLong,
+ tcpip.ErrNoBufferSpace,
+ tcpip.ErrBroadcastDisabled,
+ tcpip.ErrNotPermitted,
+ }
+
+ messageToError = make(map[string]*tcpip.Error)
+ for _, e := range errors {
+ if messageToError[e.String()] != nil {
+ panic("tcpip errors with duplicated message: " + e.String())
+ }
+ messageToError[e.String()] = e
+ }
+ })
+
+ e, ok := messageToError[s]
+ if !ok {
+ panic("unknown error message: " + s)
+ }
+
+ return e
+}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
new file mode 100644
index 000000000..c30b45c2c
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -0,0 +1,171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Forwarder is a connection request forwarder, which allows clients to decide
+// what to do with a connection request, for example: ignore it, send a RST, or
+// attempt to complete the 3-way handshake.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ maxInFlight int
+ handler func(*ForwarderRequest)
+
+ mu sync.Mutex
+ inFlight map[stack.TransportEndpointID]struct{}
+ listen *listenContext
+}
+
+// NewForwarder allocates and initializes a new forwarder with the given
+// maximum number of in-flight connection attempts. Once the maximum is reached
+// new incoming connection requests will be ignored.
+//
+// If rcvWnd is set to zero, the default buffer size is used instead.
+func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
+ if rcvWnd == 0 {
+ rcvWnd = DefaultBufferSize
+ }
+ return &Forwarder{
+ maxInFlight: maxInFlight,
+ handler: handler,
+ inFlight: make(map[stack.TransportEndpointID]struct{}),
+ listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ }
+}
+
+// HandlePacket handles a packet if it is of interest to the forwarder (i.e., if
+// it's a SYN packet), returning true if it's the case. Otherwise the packet
+// is not handled and false is returned.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ s := newSegment(r, id, vv)
+ defer s.decRef()
+
+ // We only care about well-formed SYN packets.
+ if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
+ return false
+ }
+
+ opts := parseSynSegmentOptions(s)
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // We have an inflight request for this id, ignore this one for now.
+ if _, ok := f.inFlight[id]; ok {
+ return true
+ }
+
+ // Ignore the segment if we're beyond the limit.
+ if len(f.inFlight) >= f.maxInFlight {
+ return true
+ }
+
+ // Launch a new goroutine to handle the request.
+ f.inFlight[id] = struct{}{}
+ s.incRef()
+ go f.handler(&ForwarderRequest{ // S/R-SAFE: not used by Sentry.
+ forwarder: f,
+ segment: s,
+ synOptions: opts,
+ })
+
+ return true
+}
+
+// ForwarderRequest represents a connection request received by the forwarder
+// and passed to the client. Clients must eventually call Complete() on it, and
+// may optionally create an endpoint to represent it via CreateEndpoint.
+type ForwarderRequest struct {
+ mu sync.Mutex
+ forwarder *Forwarder
+ segment *segment
+ synOptions header.TCPSynOptions
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the connection request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+ return r.segment.id
+}
+
+// Complete completes the request, and optionally sends a RST segment back to the
+// sender.
+func (r *ForwarderRequest) Complete(sendReset bool) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.segment == nil {
+ panic("Completing already completed forwarder request")
+ }
+
+ // Remove request from the forwarder.
+ r.forwarder.mu.Lock()
+ delete(r.forwarder.inFlight, r.segment.id)
+ r.forwarder.mu.Unlock()
+
+ // If the caller requested, send a reset.
+ if sendReset {
+ replyWithReset(r.segment)
+ }
+
+ // Release all resources.
+ r.segment.decRef()
+ r.segment = nil
+ r.forwarder = nil
+}
+
+// CreateEndpoint creates a TCP endpoint for the connection request, performing
+// the 3-way handshake in the process.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.segment == nil {
+ return nil, tcpip.ErrInvalidEndpointState
+ }
+
+ f := r.forwarder
+ ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+ MSS: r.synOptions.MSS,
+ WS: r.synOptions.WS,
+ TS: r.synOptions.TS,
+ TSVal: r.synOptions.TSVal,
+ TSEcr: r.synOptions.TSEcr,
+ SACKPermitted: r.synOptions.SACKPermitted,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Start the protocol goroutine.
+ ep.startAcceptedLoop(queue)
+
+ return ep, nil
+}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
new file mode 100644
index 000000000..b31bcccfa
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -0,0 +1,250 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcp contains the implementation of the TCP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing tcp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package tcp
+
+import (
+ "strings"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName is the string representation of the tcp protocol name.
+ ProtocolName = "tcp"
+
+ // ProtocolNumber is the tcp protocol number.
+ ProtocolNumber = header.TCPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ minBufferSize = 4 << 10 // 4096 bytes.
+
+ // DefaultBufferSize is the default size of the receive and send buffers.
+ DefaultBufferSize = 1 << 20 // 1MB
+
+ // MaxBufferSize is the largest size a receive and send buffer can grow to.
+ maxBufferSize = 4 << 20 // 4MB
+
+ // MaxUnprocessedSegments is the maximum number of unprocessed segments
+ // that can be queued for a given endpoint.
+ MaxUnprocessedSegments = 300
+)
+
+// SACKEnabled option can be used to enable SACK support in the TCP
+// protocol. See: https://tools.ietf.org/html/rfc2018.
+type SACKEnabled bool
+
+// SendBufferSizeOption allows the default, min and max send buffer sizes for
+// TCP endpoints to be queried or configured.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption allows the default, min and max receive buffer size
+// for TCP endpoints to be queried or configured.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
+)
+
+// CongestionControlOption sets the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption returns the supported congestion control
+// algorithms.
+type AvailableCongestionControlOption string
+
+type protocol struct {
+ mu sync.Mutex
+ sackEnabled bool
+ sendBufferSize SendBufferSizeOption
+ recvBufferSize ReceiveBufferSizeOption
+ congestionControl string
+ availableCongestionControl []string
+ allowedCongestionControl []string
+}
+
+// Number returns the tcp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+ return ProtocolNumber
+}
+
+// NewEndpoint creates a new tcp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid tcp packet size.
+func (*protocol) MinimumPacketSize() int {
+ return header.TCPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given tcp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ h := header.TCP(v)
+ return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+//
+// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then
+// a reset is sent in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected by this
+// means."
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+ s := newSegment(r, id, vv)
+ defer s.decRef()
+
+ if !s.parse() || !s.csumValid {
+ return false
+ }
+
+ // There's nothing to do if this is already a reset packet.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return true
+ }
+
+ replyWithReset(s)
+ return true
+}
+
+// replyWithReset replies to the given segment with a reset segment.
+func replyWithReset(s *segment) {
+ // Get the seqnum from the packet if the ack flag is set.
+ seq := seqnum.Value(0)
+ if s.flagIsSet(header.TCPFlagAck) {
+ seq = s.ackNumber
+ }
+
+ ack := s.sequenceNumber.Add(s.logicalLen())
+
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */)
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SACKEnabled:
+ p.mu.Lock()
+ p.sackEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
+ case SendBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.sendBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.recvBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case CongestionControlOption:
+ for _, c := range p.availableCongestionControl {
+ if string(v) == c {
+ p.mu.Lock()
+ p.congestionControl = string(v)
+ p.mu.Unlock()
+ return nil
+ }
+ }
+ return tcpip.ErrInvalidOptionValue
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SACKEnabled:
+ p.mu.Lock()
+ *v = SACKEnabled(p.sackEnabled)
+ p.mu.Unlock()
+ return nil
+
+ case *SendBufferSizeOption:
+ p.mu.Lock()
+ *v = p.sendBufferSize
+ p.mu.Unlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ p.mu.Lock()
+ *v = p.recvBufferSize
+ p.mu.Unlock()
+ return nil
+ case *CongestionControlOption:
+ p.mu.Lock()
+ *v = CongestionControlOption(p.congestionControl)
+ p.mu.Unlock()
+ return nil
+ case *AvailableCongestionControlOption:
+ p.mu.Lock()
+ *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ p.mu.Unlock()
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
+ })
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
new file mode 100644
index 000000000..b08a0e356
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -0,0 +1,221 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "container/heap"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+// receiver holds the state necessary to receive TCP segments and turn them
+// into a stream of bytes.
+//
+// +stateify savable
+type receiver struct {
+ ep *endpoint
+
+ rcvNxt seqnum.Value
+
+ // rcvAcc is one beyond the last acceptable sequence number. That is,
+ // the "largest" sequence value that the receiver has announced to the
+ // its peer that it's willing to accept. This may be different than
+ // rcvNxt + rcvWnd if the receive window is reduced; in that case we
+ // have to reduce the window as we receive more data instead of
+ // shrinking it.
+ rcvAcc seqnum.Value
+
+ rcvWndScale uint8
+
+ closed bool
+
+ pendingRcvdSegments segmentHeap
+ pendingBufUsed seqnum.Size
+ pendingBufSize seqnum.Size
+}
+
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
+ return &receiver{
+ ep: ep,
+ rcvNxt: irs + 1,
+ rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWndScale: rcvWndScale,
+ pendingBufSize: rcvWnd,
+ }
+}
+
+// acceptable checks if the segment sequence number range is acceptable
+// according to the table on page 26 of RFC 793.
+func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
+ rcvWnd := r.rcvNxt.Size(r.rcvAcc)
+ if rcvWnd == 0 {
+ return segLen == 0 && segSeq == r.rcvNxt
+ }
+
+ return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
+ seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
+}
+
+// getSendParams returns the parameters needed by the sender when building
+// segments to send.
+func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
+ // Calculate the window size based on the current buffer size.
+ n := r.ep.receiveBufferAvailable()
+ acc := r.rcvNxt.Add(seqnum.Size(n))
+ if r.rcvAcc.LessThan(acc) {
+ r.rcvAcc = acc
+ }
+
+ return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale
+}
+
+// nonZeroWindow is called when the receive window grows from zero to nonzero;
+// in such cases we may need to send an ack to indicate to our peer that it can
+// resume sending data.
+func (r *receiver) nonZeroWindow() {
+ if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
+ // We never got around to announcing a zero window size, so we
+ // don't need to immediately announce a nonzero one.
+ return
+ }
+
+ // Immediately send an ack.
+ r.ep.snd.sendAck()
+}
+
+// consumeSegment attempts to consume a segment that was received by r. The
+// segment may have just been received or may have been received earlier but
+// wasn't ready to be consumed then.
+//
+// Returns true if the segment was consumed, false if it cannot be consumed
+// yet because of a missing segment.
+func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool {
+ if segLen > 0 {
+ // If the segment doesn't include the seqnum we're expecting to
+ // consume now, we're missing a segment. We cannot proceed until
+ // we receive that segment though.
+ if !r.rcvNxt.InWindow(segSeq, segLen) {
+ return false
+ }
+
+ // Trim segment to eliminate already acknowledged data.
+ if segSeq.LessThan(r.rcvNxt) {
+ diff := segSeq.Size(r.rcvNxt)
+ segLen -= diff
+ segSeq.UpdateForward(diff)
+ s.sequenceNumber.UpdateForward(diff)
+ s.data.TrimFront(int(diff))
+ }
+
+ // Move segment to ready-to-deliver list. Wakeup any waiters.
+ r.ep.readyToRead(s)
+
+ } else if segSeq != r.rcvNxt {
+ return false
+ }
+
+ // Update the segment that we're expecting to consume.
+ r.rcvNxt = segSeq.Add(segLen)
+
+ // Trim SACK Blocks to remove any SACK information that covers
+ // sequence numbers that have been consumed.
+ TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
+
+ if s.flagIsSet(header.TCPFlagFin) {
+ r.rcvNxt++
+
+ // Send ACK immediately.
+ r.ep.snd.sendAck()
+
+ // Tell any readers that no more data will come.
+ r.closed = true
+ r.ep.readyToRead(nil)
+
+ // Flush out any pending segments, except the very first one if
+ // it happens to be the one we're handling now because the
+ // caller is using it.
+ first := 0
+ if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s {
+ first = 1
+ }
+
+ for i := first; i < len(r.pendingRcvdSegments); i++ {
+ r.pendingRcvdSegments[i].decRef()
+ }
+ r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
+ }
+
+ return true
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) {
+ // We don't care about receive processing anymore if the receive side
+ // is closed.
+ if r.closed {
+ return
+ }
+
+ segLen := seqnum.Size(s.data.Size())
+ segSeq := s.sequenceNumber
+
+ // If the sequence number range is outside the acceptable range, just
+ // send an ACK. This is according to RFC 793, page 37.
+ if !r.acceptable(segSeq, segLen) {
+ r.ep.snd.sendAck()
+ return
+ }
+
+ // Defer segment processing if it can't be consumed now.
+ if !r.consumeSegment(s, segSeq, segLen) {
+ if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
+ // We only store the segment if it's within our buffer
+ // size limit.
+ if r.pendingBufUsed < r.pendingBufSize {
+ r.pendingBufUsed += s.logicalLen()
+ s.incRef()
+ heap.Push(&r.pendingRcvdSegments, s)
+ }
+
+ UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
+
+ // Immediately send an ack so that the peer knows it may
+ // have to retransmit.
+ r.ep.snd.sendAck()
+ }
+ return
+ }
+
+ // By consuming the current segment, we may have filled a gap in the
+ // sequence number domain that allows pending segments to be consumed
+ // now. So try to do it.
+ for !r.closed && r.pendingRcvdSegments.Len() > 0 {
+ s := r.pendingRcvdSegments[0]
+ segLen := seqnum.Size(s.data.Size())
+ segSeq := s.sequenceNumber
+
+ // Skip segment altogether if it has already been acknowledged.
+ if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
+ !r.consumeSegment(s, segSeq, segLen) {
+ break
+ }
+
+ heap.Pop(&r.pendingRcvdSegments)
+ r.pendingBufUsed -= s.logicalLen()
+ s.decRef()
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go
new file mode 100644
index 000000000..f83ebc717
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/reno.go
@@ -0,0 +1,103 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+// renoState stores the variables related to TCP New Reno congestion
+// control algorithm.
+//
+// +stateify savable
+type renoState struct {
+ s *sender
+}
+
+// newRenoCC initializes the state for the NewReno congestion control algorithm.
+func newRenoCC(s *sender) *renoState {
+ return &renoState{s: s}
+}
+
+// updateSlowStart will update the congestion window as per the slow-start
+// algorithm used by NewReno. If after adjusting the congestion window
+// we cross the SSthreshold then it will return the number of packets that
+// must be consumed in congestion avoidance mode.
+func (r *renoState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := r.s.sndCwnd + packetsAcked
+ if newcwnd >= r.s.sndSsthresh {
+ newcwnd = r.s.sndSsthresh
+ r.s.sndCAAckCount = 0
+ }
+
+ packetsAcked -= newcwnd - r.s.sndCwnd
+ r.s.sndCwnd = newcwnd
+ return packetsAcked
+}
+
+// updateCongestionAvoidance will update congestion window in congestion
+// avoidance mode as described in RFC5681 section 3.1
+func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
+ // Consume the packets in congestion avoidance mode.
+ r.s.sndCAAckCount += packetsAcked
+ if r.s.sndCAAckCount >= r.s.sndCwnd {
+ r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd
+ r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd
+ }
+}
+
+// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
+// page 6, eq. 4. It is called when we detect congestion in the network.
+func (r *renoState) reduceSlowStartThreshold() {
+ r.s.sndSsthresh = r.s.outstanding / 2
+ if r.s.sndSsthresh < 2 {
+ r.s.sndSsthresh = 2
+ }
+
+}
+
+// Update updates the congestion state based on the number of packets that
+// were acknowledged.
+// Update implements congestionControl.Update.
+func (r *renoState) Update(packetsAcked int) {
+ if r.s.sndCwnd < r.s.sndSsthresh {
+ packetsAcked = r.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ }
+ r.updateCongestionAvoidance(packetsAcked)
+}
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (r *renoState) HandleNDupAcks() {
+ // A retransmit was triggered due to nDupAckThreshold
+ // being hit. Reduce our slow start threshold.
+ r.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionControl.HandleRTOExpired.
+func (r *renoState) HandleRTOExpired() {
+ // We lost a packet, so reduce ssthresh.
+ r.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ r.s.sndCwnd = 1
+}
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (r *renoState) PostRecovery() {
+ // noop.
+}
diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go
new file mode 100644
index 000000000..6a013d99b
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack.go
@@ -0,0 +1,99 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // MaxSACKBlocks is the maximum number of SACK blocks stored
+ // at receiver side.
+ MaxSACKBlocks = 6
+)
+
+// UpdateSACKBlocks updates the list of SACK blocks to include the segment
+// specified by segStart->segEnd. If the segment happens to be an out of order
+// delivery then the first block in the sack.blocks always includes the
+// segment identified by segStart->segEnd.
+func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
+ newSB := header.SACKBlock{Start: segStart, End: segEnd}
+ if sack.NumBlocks == 0 {
+ sack.Blocks[0] = newSB
+ sack.NumBlocks = 1
+ return
+ }
+ var n = 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ start, end := sack.Blocks[i].Start, sack.Blocks[i].End
+ if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
+ // Discard any invalid blocks where end is before start
+ // and discard any sack blocks that are before rcvNxt as
+ // those have already been acked.
+ continue
+ }
+ if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) {
+ // Merge this SACK block into newSB and discard this SACK
+ // block.
+ if start.LessThan(newSB.Start) {
+ newSB.Start = start
+ }
+ if newSB.End.LessThan(end) {
+ newSB.End = end
+ }
+ } else {
+ // Save this block.
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ }
+ if rcvNxt.LessThan(newSB.Start) {
+ // If this was an out of order segment then make sure that the
+ // first SACK block is the one that includes the segment.
+ //
+ // See the first bullet point in
+ // https://tools.ietf.org/html/rfc2018#section-4
+ if n == MaxSACKBlocks {
+ // If the number of SACK blocks is equal to
+ // MaxSACKBlocks then discard the last SACK block.
+ n--
+ }
+ for i := n - 1; i >= 0; i-- {
+ sack.Blocks[i+1] = sack.Blocks[i]
+ }
+ sack.Blocks[0] = newSB
+ n++
+ }
+ sack.NumBlocks = n
+}
+
+// TrimSACKBlockList updates the sack block list by removing/modifying any block
+// where start is < rcvNxt.
+func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) {
+ n := 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ if sack.Blocks[i].End.LessThanEq(rcvNxt) {
+ continue
+ }
+ if sack.Blocks[i].Start.LessThan(rcvNxt) {
+ // Shrink this SACK block.
+ sack.Blocks[i].Start = rcvNxt
+ }
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ sack.NumBlocks = n
+}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
new file mode 100644
index 000000000..1c5766a42
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/google/btree"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // maxSACKBlocks is the maximum number of distinct SACKBlocks the
+ // scoreboard will track. Once there are 100 distinct blocks, new
+ // insertions will fail.
+ maxSACKBlocks = 100
+
+ // defaultBtreeDegree is set to 2 as btree.New(2) results in a 2-3-4
+ // tree.
+ defaultBtreeDegree = 2
+)
+
+// SACKScoreboard stores a set of disjoint SACK ranges.
+//
+// +stateify savable
+type SACKScoreboard struct {
+ // smss is defined in RFC5681 as following:
+ //
+ // The SMSS is the size of the largest segment that the sender can
+ // transmit. This value can be based on the maximum transmission unit
+ // of the network, the path MTU discovery [RFC1191, RFC4821] algorithm,
+ // RMSS (see next item), or other factors. The size does not include
+ // the TCP/IP headers and options.
+ smss uint16
+ maxSACKED seqnum.Value
+ sacked seqnum.Size `state:"nosave"`
+ ranges *btree.BTree `state:"nosave"`
+}
+
+// NewSACKScoreboard returns a new SACK Scoreboard.
+func NewSACKScoreboard(smss uint16, iss seqnum.Value) *SACKScoreboard {
+ return &SACKScoreboard{
+ smss: smss,
+ ranges: btree.New(defaultBtreeDegree),
+ maxSACKED: iss,
+ }
+}
+
+// Reset erases all known range information from the SACK scoreboard.
+func (s *SACKScoreboard) Reset() {
+ s.ranges = btree.New(defaultBtreeDegree)
+ s.sacked = 0
+}
+
+// Insert inserts/merges the provided SACKBlock into the scoreboard.
+func (s *SACKScoreboard) Insert(r header.SACKBlock) {
+ if s.ranges.Len() >= maxSACKBlocks {
+ return
+ }
+
+ // Check if we can merge the new range with a range before or after it.
+ var toDelete []btree.Item
+ if s.maxSACKED.LessThan(r.End - 1) {
+ s.maxSACKED = r.End - 1
+ }
+ s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sacked := i.(header.SACKBlock)
+ // There is a hole between these two SACK blocks, so we can't
+ // merge anymore.
+ if r.End.LessThan(sacked.Start) {
+ return false
+ }
+ // There is some overlap at this point, merge the blocks and
+ // delete the other one.
+ //
+ // ----sS--------sE
+ // r.S---------------rE
+ // -------sE
+ if sacked.End.LessThan(r.End) {
+ // sacked is contained in the newly inserted range.
+ // Delete this block.
+ toDelete = append(toDelete, i)
+ return true
+ }
+ // sacked covers a range past end of the newly inserted
+ // block.
+ r.End = sacked.End
+ toDelete = append(toDelete, i)
+ return true
+ })
+
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sacked := i.(header.SACKBlock)
+ // sA------sE
+ // rA----rE
+ if sacked.End.LessThan(r.Start) {
+ return false
+ }
+ // The previous range extends into the current block. Merge it
+ // into the newly inserted range and delete the other one.
+ //
+ // <-rA---rE----<---rE--->
+ // sA--------------sE
+ r.Start = sacked.Start
+ // Extend r to cover sacked if sacked extends past r.
+ if r.End.LessThan(sacked.End) {
+ r.End = sacked.End
+ }
+ toDelete = append(toDelete, i)
+ return true
+ })
+ for _, i := range toDelete {
+ if sb := s.ranges.Delete(i); sb != nil {
+ sb := i.(header.SACKBlock)
+ s.sacked -= sb.Start.Size(sb.End)
+ }
+ }
+
+ replaced := s.ranges.ReplaceOrInsert(r)
+ if replaced == nil {
+ s.sacked += r.Start.Size(r.End)
+ }
+}
+
+// IsSACKED returns true if the a given range of sequence numbers denoted by r
+// are already covered by SACK information in the scoreboard.
+func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
+
+ found := false
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.End.LessThan(r.Start) {
+ return false
+ }
+ if sacked.Contains(r) {
+ found = true
+ return false
+ }
+ return true
+ })
+ return found
+}
+
+// Dump prints the state of the scoreboard structure.
+func (s *SACKScoreboard) String() string {
+ var str strings.Builder
+ str.WriteString("SACKScoreboard: {")
+ s.ranges.Ascend(func(i btree.Item) bool {
+ str.WriteString(fmt.Sprintf("%v,", i))
+ return true
+ })
+ str.WriteString("}\n")
+ return str.String()
+}
+
+// Delete removes all SACK information prior to seq.
+func (s *SACKScoreboard) Delete(seq seqnum.Value) {
+ if s.Empty() {
+ return
+ }
+ toDelete := []btree.Item{}
+ toInsert := []btree.Item{}
+ r := header.SACKBlock{seq, seq.Add(1)}
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sb := i.(header.SACKBlock)
+ toDelete = append(toDelete, i)
+ if sb.End.LessThanEq(seq) {
+ s.sacked -= sb.Start.Size(sb.End)
+ } else {
+ newSB := header.SACKBlock{seq, sb.End}
+ toInsert = append(toInsert, newSB)
+ s.sacked -= sb.Start.Size(seq)
+ }
+ return true
+ })
+ for _, sb := range toDelete {
+ s.ranges.Delete(sb)
+ }
+ for _, sb := range toInsert {
+ s.ranges.ReplaceOrInsert(sb)
+ }
+}
+
+// Copy provides a copy of the SACK scoreboard.
+func (s *SACKScoreboard) Copy() (sackBlocks []header.SACKBlock, maxSACKED seqnum.Value) {
+ s.ranges.Ascend(func(i btree.Item) bool {
+ sackBlocks = append(sackBlocks, i.(header.SACKBlock))
+ return true
+ })
+ return sackBlocks, s.maxSACKED
+}
+
+// IsRangeLost implements the IsLost(SeqNum) operation defined in RFC 6675
+// section 4 but operates on a range of sequence numbers and returns true if
+// there are at least nDupAckThreshold SACK blocks greater than the range being
+// checked or if at least (nDupAckThreshold-1)*s.smss bytes have been SACKED
+// with sequence numbers greater than the block being checked.
+func (s *SACKScoreboard) IsRangeLost(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
+ nDupSACK := 0
+ nDupSACKBytes := seqnum.Size(0)
+ isLost := false
+
+ // We need to check if the immediate lower (if any) sacked
+ // range contains or partially overlaps with r.
+ searchMore := true
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.Contains(r) {
+ searchMore = false
+ return false
+ }
+ if sacked.End.LessThanEq(r.Start) {
+ // all sequence numbers covered by sacked are below
+ // r so we continue searching.
+ return false
+ }
+ // There is a partial overlap. In this case we r.Start is
+ // between sacked.Start & sacked.End and r.End extends beyond
+ // sacked.End.
+ // Move r.Start to sacked.End and continuing searching blocks
+ // above r.Start.
+ r.Start = sacked.End
+ return false
+ })
+
+ if !searchMore {
+ return isLost
+ }
+
+ s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.Contains(r) {
+ return false
+ }
+ nDupSACKBytes += sacked.Start.Size(sacked.End)
+ nDupSACK++
+ if nDupSACK >= nDupAckThreshold || nDupSACKBytes >= seqnum.Size((nDupAckThreshold-1)*s.smss) {
+ isLost = true
+ return false
+ }
+ return true
+ })
+ return isLost
+}
+
+// IsLost implements the IsLost(SeqNum) operation defined in RFC3517 section
+// 4.
+//
+// This routine returns whether the given sequence number is considered to be
+// lost. The routine returns true when either nDupAckThreshold discontiguous
+// SACKed sequences have arrived above 'SeqNum' or (nDupAckThreshold * SMSS)
+// bytes with sequence numbers greater than 'SeqNum' have been SACKed.
+// Otherwise, the routine returns false.
+func (s *SACKScoreboard) IsLost(seq seqnum.Value) bool {
+ return s.IsRangeLost(header.SACKBlock{seq, seq.Add(1)})
+}
+
+// Empty returns true if the SACK scoreboard has no entries, false otherwise.
+func (s *SACKScoreboard) Empty() bool {
+ return s.ranges.Len() == 0
+}
+
+// Sacked returns the current number of bytes held in the SACK scoreboard.
+func (s *SACKScoreboard) Sacked() seqnum.Size {
+ return s.sacked
+}
+
+// MaxSACKED returns the highest sequence number ever inserted in the SACK
+// scoreboard.
+func (s *SACKScoreboard) MaxSACKED() seqnum.Value {
+ return s.maxSACKED
+}
+
+// SMSS returns the sender's MSS as held by the SACK scoreboard.
+func (s *SACKScoreboard) SMSS() uint16 {
+ return s.smss
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
new file mode 100644
index 000000000..450d9fbc1
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -0,0 +1,186 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// segment represents a TCP segment. It holds the payload and parsed TCP segment
+// information, and can be added to intrusive lists.
+// segment is mostly immutable, the only field allowed to change is viewToDeliver.
+//
+// +stateify savable
+type segment struct {
+ segmentEntry
+ refCnt int32
+ id stack.TransportEndpointID `state:"manual"`
+ route stack.Route `state:"manual"`
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // views is used as buffer for data when its length is large
+ // enough to store a VectorisedView.
+ views [8]buffer.View `state:"nosave"`
+ // viewToDeliver keeps track of the next View that should be
+ // delivered by the Read endpoint.
+ viewToDeliver int
+ sequenceNumber seqnum.Value
+ ackNumber seqnum.Value
+ flags uint8
+ window seqnum.Size
+ // csum is only populated for received segments.
+ csum uint16
+ // csumValid is true if the csum in the received segment is valid.
+ csumValid bool
+
+ // parsedOptions stores the parsed values from the options in the segment.
+ parsedOptions header.TCPOptions
+ options []byte `state:".([]byte)"`
+ hasNewSACKInfo bool
+ rcvdTime time.Time `state:".(unixTime)"`
+ // xmitTime is the last transmit time of this segment. A zero value
+ // indicates that the segment has yet to be transmitted.
+ xmitTime time.Time `state:".(unixTime)"`
+}
+
+func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.data = vv.Clone(s.views[:])
+ s.rcvdTime = time.Now()
+ return s
+}
+
+func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.views[0] = v
+ s.data = buffer.NewVectorisedView(len(v), s.views[:1])
+ s.rcvdTime = time.Now()
+ return s
+}
+
+func (s *segment) clone() *segment {
+ t := &segment{
+ refCnt: 1,
+ id: s.id,
+ sequenceNumber: s.sequenceNumber,
+ ackNumber: s.ackNumber,
+ flags: s.flags,
+ window: s.window,
+ route: s.route.Clone(),
+ viewToDeliver: s.viewToDeliver,
+ rcvdTime: s.rcvdTime,
+ }
+ t.data = s.data.Clone(t.views[:])
+ return t
+}
+
+func (s *segment) flagIsSet(flag uint8) bool {
+ return (s.flags & flag) != 0
+}
+
+func (s *segment) decRef() {
+ if atomic.AddInt32(&s.refCnt, -1) == 0 {
+ s.route.Release()
+ }
+}
+
+func (s *segment) incRef() {
+ atomic.AddInt32(&s.refCnt, 1)
+}
+
+// logicalLen is the segment length in the sequence number space. It's defined
+// as the data length plus one for each of the SYN and FIN bits set.
+func (s *segment) logicalLen() seqnum.Size {
+ l := seqnum.Size(s.data.Size())
+ if s.flagIsSet(header.TCPFlagSyn) {
+ l++
+ }
+ if s.flagIsSet(header.TCPFlagFin) {
+ l++
+ }
+ return l
+}
+
+// parse populates the sequence & ack numbers, flags, and window fields of the
+// segment from the TCP header stored in the data. It then updates the view to
+// skip the header.
+//
+// Returns boolean indicating if the parsing was successful.
+//
+// If checksum verification is not offloaded then parse also verifies the
+// TCP checksum and stores the checksum and result of checksum verification in
+// the csum and csumValid fields of the segment.
+func (s *segment) parse() bool {
+ h := header.TCP(s.data.First())
+
+ // h is the header followed by the payload. We check that the offset to
+ // the data respects the following constraints:
+ // 1. That it's at least the minimum header size; if we don't do this
+ // then part of the header would be delivered to user.
+ // 2. That the header fits within the buffer; if we don't do this, we
+ // would panic when we tried to access data beyond the buffer.
+ //
+ // N.B. The segment has already been validated as having at least the
+ // minimum TCP size before reaching here, so it's safe to read the
+ // fields.
+ offset := int(h.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(h) {
+ return false
+ }
+
+ s.options = []byte(h[header.TCPMinimumSize:offset])
+ s.parsedOptions = header.ParseTCPOptions(s.options)
+
+ // Query the link capabilities to decide if checksum validation is
+ // required.
+ verifyChecksum := true
+ if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
+ s.csumValid = true
+ verifyChecksum = false
+ s.data.TrimFront(offset)
+ }
+ if verifyChecksum {
+ s.csum = h.Checksum()
+ xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
+ xsum = h.CalculateChecksum(xsum)
+ s.data.TrimFront(offset)
+ xsum = header.ChecksumVV(s.data, xsum)
+ s.csumValid = xsum == 0xffff
+ }
+
+ s.sequenceNumber = seqnum.Value(h.SequenceNumber())
+ s.ackNumber = seqnum.Value(h.AckNumber())
+ s.flags = h.Flags()
+ s.window = seqnum.Size(h.WindowSize())
+ return true
+}
+
+// sackBlock returns a header.SACKBlock that represents this segment.
+func (s *segment) sackBlock() header.SACKBlock {
+ return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())}
+}
diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go
new file mode 100644
index 000000000..9fd061d7d
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_heap.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+type segmentHeap []*segment
+
+// Len returns the length of h.
+func (h segmentHeap) Len() int {
+ return len(h)
+}
+
+// Less determines whether the i-th element of h is less than the j-th element.
+func (h segmentHeap) Less(i, j int) bool {
+ return h[i].sequenceNumber.LessThan(h[j].sequenceNumber)
+}
+
+// Swap swaps the i-th and j-th elements of h.
+func (h segmentHeap) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+// Push adds x as the last element of h.
+func (h *segmentHeap) Push(x interface{}) {
+ *h = append(*h, x.(*segment))
+}
+
+// Pop removes the last element of h and returns it.
+func (h *segmentHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[:n-1]
+ return x
+}
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
new file mode 100644
index 000000000..e0759225e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync"
+)
+
+// segmentQueue is a bounded, thread-safe queue of TCP segments.
+//
+// +stateify savable
+type segmentQueue struct {
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
+ limit int
+ used int
+}
+
+// empty determines if the queue is empty.
+func (q *segmentQueue) empty() bool {
+ q.mu.Lock()
+ r := q.used == 0
+ q.mu.Unlock()
+
+ return r
+}
+
+// setLimit updates the limit. No segments are immediately dropped in case the
+// queue becomes full due to the new limit.
+func (q *segmentQueue) setLimit(limit int) {
+ q.mu.Lock()
+ q.limit = limit
+ q.mu.Unlock()
+}
+
+// enqueue adds the given segment to the queue.
+//
+// Returns true when the segment is successfully added to the queue, in which
+// case ownership of the reference is transferred to the queue. And returns
+// false if the queue is full, in which case ownership is retained by the
+// caller.
+func (q *segmentQueue) enqueue(s *segment) bool {
+ q.mu.Lock()
+ r := q.used < q.limit
+ if r {
+ q.list.PushBack(s)
+ q.used++
+ }
+ q.mu.Unlock()
+
+ return r
+}
+
+// dequeue removes and returns the next segment from queue, if one exists.
+// Ownership is transferred to the caller, who is responsible for decrementing
+// the ref count when done.
+func (q *segmentQueue) dequeue() *segment {
+ q.mu.Lock()
+ s := q.list.Front()
+ if s != nil {
+ q.list.Remove(s)
+ q.used--
+ }
+ q.mu.Unlock()
+
+ return s
+}
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
new file mode 100644
index 000000000..dd7e14aa6
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// saveData is invoked by stateify.
+func (s *segment) saveData() buffer.VectorisedView {
+ // We cannot save s.data directly as s.data.views may alias to s.views,
+ // which is not allowed by state framework (in-struct pointer).
+ v := make([]buffer.View, len(s.data.Views()))
+ // For views already delivered, we cannot save them directly as they may
+ // have already been sliced and saved elsewhere (e.g., readViews).
+ for i := 0; i < s.viewToDeliver; i++ {
+ v[i] = append([]byte(nil), s.data.Views()[i]...)
+ }
+ for i := s.viewToDeliver; i < len(v); i++ {
+ v[i] = s.data.Views()[i]
+ }
+ return buffer.NewVectorisedView(s.data.Size(), v)
+}
+
+// loadData is invoked by stateify.
+func (s *segment) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the s.data = data.Clone(s.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing s.views for data.views.
+ s.data = data
+}
+
+// saveOptions is invoked by stateify.
+func (s *segment) saveOptions() []byte {
+ // We cannot save s.options directly as it may point to s.data's trimmed
+ // tail, which is not allowed by state framework (in-struct pointer).
+ b := make([]byte, 0, cap(s.options))
+ return append(b, s.options...)
+}
+
+// loadOptions is invoked by stateify.
+func (s *segment) loadOptions(options []byte) {
+ // NOTE: We cannot point s.options back into s.data's trimmed tail. But
+ // it is OK as they do not need to aliased. Plus, options is already
+ // allocated so there is no cost here.
+ s.options = options
+}
+
+// saveRcvdTime is invoked by stateify.
+func (s *segment) saveRcvdTime() unixTime {
+ return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
+}
+
+// loadRcvdTime is invoked by stateify.
+func (s *segment) loadRcvdTime(unix unixTime) {
+ s.rcvdTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveXmitTime is invoked by stateify.
+func (s *segment) saveXmitTime() unixTime {
+ return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (s *segment) loadXmitTime(unix unixTime) {
+ s.rcvdTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
new file mode 100644
index 000000000..afc1d0a55
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -0,0 +1,1180 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // minRTO is the minimum allowed value for the retransmit timeout.
+ minRTO = 200 * time.Millisecond
+
+ // InitialCwnd is the initial congestion window.
+ InitialCwnd = 10
+
+ // nDupAckThreshold is the number of duplicate ACK's required
+ // before fast-retransmit is entered.
+ nDupAckThreshold = 3
+)
+
+// congestionControl is an interface that must be implemented by any supported
+// congestion control algorithm.
+type congestionControl interface {
+ // HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold
+ // just before entering fast retransmit.
+ HandleNDupAcks()
+
+ // HandleRTOExpired is invoked when the retransmit timer expires.
+ HandleRTOExpired()
+
+ // Update is invoked when processing inbound acks. It's passed the
+ // number of packet's that were acked by the most recent cumulative
+ // acknowledgement.
+ Update(packetsAcked int)
+
+ // PostRecovery is invoked when the sender is exiting a fast retransmit/
+ // recovery phase. This provides congestion control algorithms a way
+ // to adjust their state when exiting recovery.
+ PostRecovery()
+}
+
+// sender holds the state necessary to send TCP segments.
+//
+// +stateify savable
+type sender struct {
+ ep *endpoint
+
+ // lastSendTime is the timestamp when the last packet was sent.
+ lastSendTime time.Time `state:".(unixTime)"`
+
+ // dupAckCount is the number of duplicated acks received. It is used for
+ // fast retransmit.
+ dupAckCount int
+
+ // fr holds state related to fast recovery.
+ fr fastRecovery
+
+ // sndCwnd is the congestion window, in packets.
+ sndCwnd int
+
+ // sndSsthresh is the threshold between slow start and congestion
+ // avoidance.
+ sndSsthresh int
+
+ // sndCAAckCount is the number of packets acknowledged during congestion
+ // avoidance. When enough packets have been ack'd (typically cwnd
+ // packets), the congestion window is incremented by one.
+ sndCAAckCount int
+
+ // outstanding is the number of outstanding packets, that is, packets
+ // that have been sent but not yet acknowledged.
+ outstanding int
+
+ // sndWnd is the send window size.
+ sndWnd seqnum.Size
+
+ // sndUna is the next unacknowledged sequence number.
+ sndUna seqnum.Value
+
+ // sndNxt is the sequence number of the next segment to be sent.
+ sndNxt seqnum.Value
+
+ // sndNxtList is the sequence number of the next segment to be added to
+ // the send list.
+ sndNxtList seqnum.Value
+
+ // rttMeasureSeqNum is the sequence number being used for the latest RTT
+ // measurement.
+ rttMeasureSeqNum seqnum.Value
+
+ // rttMeasureTime is the time when the rttMeasureSeqNum was sent.
+ rttMeasureTime time.Time `state:".(unixTime)"`
+
+ closed bool
+ writeNext *segment
+ writeList segmentList
+ resendTimer timer `state:"nosave"`
+ resendWaker sleep.Waker `state:"nosave"`
+
+ // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
+ // "round-trip time variation" and "retransmit timeout", as defined in
+ // section 2 of RFC 6298.
+ rtt rtt
+ rto time.Duration
+ srttInited bool
+
+ // maxPayloadSize is the maximum size of the payload of a given segment.
+ // It is initialized on demand.
+ maxPayloadSize int
+
+ // gso is set if generic segmentation offload is enabled.
+ gso bool
+
+ // sndWndScale is the number of bits to shift left when reading the send
+ // window size from a segment.
+ sndWndScale uint8
+
+ // maxSentAck is the maxium acknowledgement actually sent.
+ maxSentAck seqnum.Value
+
+ // cc is the congestion control algorithm in use for this sender.
+ cc congestionControl
+}
+
+// rtt is a synchronization wrapper used to appease stateify. See the comment
+// in sender, where it is used.
+//
+// +stateify savable
+type rtt struct {
+ sync.Mutex `state:"nosave"`
+
+ srtt time.Duration
+ rttvar time.Duration
+}
+
+// fastRecovery holds information related to fast recovery from a packet loss.
+//
+// +stateify savable
+type fastRecovery struct {
+ // active whether the endpoint is in fast recovery. The following fields
+ // are only meaningful when active is true.
+ active bool
+
+ // first and last represent the inclusive sequence number range being
+ // recovered.
+ first seqnum.Value
+ last seqnum.Value
+
+ // maxCwnd is the maximum value the congestion window may be inflated to
+ // due to duplicate acks. This exists to avoid attacks where the
+ // receiver intentionally sends duplicate acks to artificially inflate
+ // the sender's cwnd.
+ maxCwnd int
+
+ // highRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase.
+ // See: RFC 6675 Section 2 for details.
+ highRxt seqnum.Value
+
+ // rescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission.
+ // See: RFC 6675 Section 2 for details.
+ rescueRxt seqnum.Value
+}
+
+func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
+ // The sender MUST reduce the TCP data length to account for any IP or
+ // TCP options that it is including in the packets that it sends.
+ // See: https://tools.ietf.org/html/rfc6691#section-2
+ maxPayloadSize := int(mss) - ep.maxOptionSize()
+
+ s := &sender{
+ ep: ep,
+ sndCwnd: InitialCwnd,
+ sndSsthresh: math.MaxInt64,
+ sndWnd: sndWnd,
+ sndUna: iss + 1,
+ sndNxt: iss + 1,
+ sndNxtList: iss + 1,
+ rto: 1 * time.Second,
+ rttMeasureSeqNum: iss + 1,
+ lastSendTime: time.Now(),
+ maxPayloadSize: maxPayloadSize,
+ maxSentAck: irs + 1,
+ fr: fastRecovery{
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
+ last: iss,
+ highRxt: iss,
+ rescueRxt: iss,
+ },
+ gso: ep.gso != nil,
+ }
+
+ if s.gso {
+ s.ep.gso.MSS = uint16(maxPayloadSize)
+ }
+
+ s.cc = s.initCongestionControl(ep.cc)
+
+ // A negative sndWndScale means that no scaling is in use, otherwise we
+ // store the scaling value.
+ if sndWndScale > 0 {
+ s.sndWndScale = uint8(sndWndScale)
+ }
+
+ s.resendTimer.init(&s.resendWaker)
+
+ s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
+
+ // Initialize SACK Scoreboard after updating max payload size as we use
+ // the maxPayloadSize as the smss when determining if a segment is lost
+ // etc.
+ s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+
+ return s
+}
+
+func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+ switch congestionControlName {
+ case ccCubic:
+ return newCubicCC(s)
+ case ccReno:
+ fallthrough
+ default:
+ return newRenoCC(s)
+ }
+}
+
+// updateMaxPayloadSize updates the maximum payload size based on the given
+// MTU. If this is in response to "packet too big" control packets (indicated
+// by the count argument), it also reduces the number of outstanding packets and
+// attempts to retransmit the first packet above the MTU size.
+func (s *sender) updateMaxPayloadSize(mtu, count int) {
+ m := mtu - header.TCPMinimumSize
+
+ m -= s.ep.maxOptionSize()
+
+ // We don't adjust up for now.
+ if m >= s.maxPayloadSize {
+ return
+ }
+
+ // Make sure we can transmit at least one byte.
+ if m <= 0 {
+ m = 1
+ }
+
+ s.maxPayloadSize = m
+ if s.gso {
+ s.ep.gso.MSS = uint16(m)
+ }
+
+ if count == 0 {
+ // updateMaxPayloadSize is also called when the sender is created.
+ // and there is no data to send in such cases. Return immediately.
+ return
+ }
+
+ // Update the scoreboard's smss to reflect the new lowered
+ // maxPayloadSize.
+ s.ep.scoreboard.smss = uint16(m)
+
+ s.outstanding -= count
+ if s.outstanding < 0 {
+ s.outstanding = 0
+ }
+
+ // Rewind writeNext to the first segment exceeding the MTU. Do nothing
+ // if it is already before such a packet.
+ for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+ if seg == s.writeNext {
+ // We got to writeNext before we could find a segment
+ // exceeding the MTU.
+ break
+ }
+
+ if seg.data.Size() > m {
+ // We found a segment exceeding the MTU. Rewind
+ // writeNext and try to retransmit it.
+ s.writeNext = seg
+ break
+ }
+ }
+
+ // Since we likely reduced the number of outstanding packets, we may be
+ // ready to send some more.
+ s.sendData()
+}
+
+// sendAck sends an ACK segment.
+func (s *sender) sendAck() {
+ s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt)
+}
+
+// updateRTO updates the retransmit timeout when a new roud-trip time is
+// available. This is done in accordance with section 2 of RFC 6298.
+func (s *sender) updateRTO(rtt time.Duration) {
+ s.rtt.Lock()
+ if !s.srttInited {
+ s.rtt.rttvar = rtt / 2
+ s.rtt.srtt = rtt
+ s.srttInited = true
+ } else {
+ diff := s.rtt.srtt - rtt
+ if diff < 0 {
+ diff = -diff
+ }
+ // Use RFC6298 standard algorithm to update rttvar and srtt when
+ // no timestamps are available.
+ if !s.ep.sendTSOk {
+ s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4
+ s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8
+ } else {
+ // When we are taking RTT measurements of every ACK then
+ // we need to use a modified method as specified in
+ // https://tools.ietf.org/html/rfc7323#appendix-G
+ if s.outstanding == 0 {
+ s.rtt.Unlock()
+ return
+ }
+ // Netstack measures congestion window/inflight all in
+ // terms of packets and not bytes. This is similar to
+ // how linux also does cwnd and inflight. In practice
+ // this approximation works as expected.
+ expectedSamples := math.Ceil(float64(s.outstanding) / 2)
+
+ // alpha & beta values are the original values as recommended in
+ // https://tools.ietf.org/html/rfc6298#section-2.3.
+ const alpha = 0.125
+ const beta = 0.25
+
+ alphaPrime := alpha / expectedSamples
+ betaPrime := beta / expectedSamples
+ rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds()
+ srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds()
+ s.rtt.rttvar = time.Duration(rttVar * float64(time.Second))
+ s.rtt.srtt = time.Duration(srtt * float64(time.Second))
+ }
+ }
+
+ s.rto = s.rtt.srtt + 4*s.rtt.rttvar
+ s.rtt.Unlock()
+ if s.rto < minRTO {
+ s.rto = minRTO
+ }
+}
+
+// resendSegment resends the first unacknowledged segment.
+func (s *sender) resendSegment() {
+ // Don't use any segments we already sent to measure RTT as they may
+ // have been affected by packets being lost.
+ s.rttMeasureSeqNum = s.sndNxt
+
+ // Resend the segment.
+ if seg := s.writeList.Front(); seg != nil {
+ if seg.data.Size() > s.maxPayloadSize {
+ s.splitSeg(seg, s.maxPayloadSize)
+ }
+
+ // See: RFC 6675 section 5 Step 4.3
+ //
+ // To prevent retransmission, set both the HighRXT and RescueRXT
+ // to the highest sequence number in the retransmitted segment.
+ s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.sendSegment(seg)
+ s.ep.stack.Stats().TCP.FastRetransmit.Increment()
+
+ // Run SetPipe() as per RFC 6675 section 5 Step 4.4
+ s.SetPipe()
+ }
+}
+
+// retransmitTimerExpired is called when the retransmit timer expires, and
+// unacknowledged segments are assumed lost, and thus need to be resent.
+// Returns true if the connection is still usable, or false if the connection
+// is deemed lost.
+func (s *sender) retransmitTimerExpired() bool {
+ // Check if the timer actually expired or if it's a spurious wake due
+ // to a previously orphaned runtime timer.
+ if !s.resendTimer.checkExpiration() {
+ return true
+ }
+
+ s.ep.stack.Stats().TCP.Timeouts.Increment()
+
+ // Give up if we've waited more than a minute since the last resend.
+ if s.rto >= 60*time.Second {
+ return false
+ }
+
+ // Set new timeout. The timer will be restarted by the call to sendData
+ // below.
+ s.rto *= 2
+
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
+ //
+ // Retransmit timeouts:
+ // After a retransmit timeout, record the highest sequence number
+ // transmitted in the variable recover, and exit the fast recovery
+ // procedure if applicable.
+ s.fr.last = s.sndNxt - 1
+
+ if s.fr.active {
+ // We were attempting fast recovery but were not successful.
+ // Leave the state. We don't need to update ssthresh because it
+ // has already been updated when entered fast-recovery.
+ s.leaveFastRecovery()
+ }
+
+ s.cc.HandleRTOExpired()
+
+ // Mark the next segment to be sent as the first unacknowledged one and
+ // start sending again. Set the number of outstanding packets to 0 so
+ // that we'll be able to retransmit.
+ //
+ // We'll keep on transmitting (or retransmitting) as we get acks for
+ // the data we transmit.
+ s.outstanding = 0
+
+ // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1
+ //
+ // In order to avoid memory deadlocks, the TCP receiver is allowed to
+ // discard data that has already been selectively acknowledged. As a
+ // result, [RFC2018] suggests that a TCP sender SHOULD expunge the SACK
+ // information gathered from a receiver upon a retransmission timeout
+ // (RTO) "since the timeout might indicate that the data receiver has
+ // reneged." Additionally, a TCP sender MUST "ignore prior SACK
+ // information in determining which data to retransmit."
+ //
+ // NOTE: We take the stricter interpretation and just expunge all
+ // information as we lack more rigorous checks to validate if the SACK
+ // information is usable after an RTO.
+ s.ep.scoreboard.Reset()
+ s.writeNext = s.writeList.Front()
+ s.sendData()
+
+ return true
+}
+
+// pCount returns the number of packets in the segment. Due to GSO, a segment
+// can be composed of multiple packets.
+func (s *sender) pCount(seg *segment) int {
+ size := seg.data.Size()
+ if size == 0 {
+ return 1
+ }
+
+ return (size-1)/s.maxPayloadSize + 1
+}
+
+// splitSeg splits a given segment at the size specified and inserts the
+// remainder as a new segment after the current one in the write list.
+func (s *sender) splitSeg(seg *segment, size int) {
+ if seg.data.Size() <= size {
+ return
+ }
+ // Split this segment up.
+ nSeg := seg.clone()
+ nSeg.data.TrimFront(size)
+ nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
+ s.writeList.InsertAfter(seg, nSeg)
+ seg.data.CapLength(size)
+}
+
+// NextSeg implements the RFC6675 NextSeg() operation. It returns segments that
+// match rule 1, 3 and 4 of the NextSeg() operation defined in RFC6675. Rule 2
+// is handled by the normal send logic.
+func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) {
+ var s3 *segment
+ var s4 *segment
+ smss := s.ep.scoreboard.SMSS()
+ // Step 1.
+ for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+ if !s.isAssignedSequenceNumber(seg) {
+ break
+ }
+ segSeq := seg.sequenceNumber
+ if seg.data.Size() > int(smss) {
+ s.splitSeg(seg, int(smss))
+ }
+ // See RFC 6675 Section 4
+ //
+ // 1. If there exists a smallest unSACKED sequence number
+ // 'S2' that meets the following 3 criteria for determinig
+ // loss, the sequence range of one segment of up to SMSS
+ // octects starting with S2 MUST be returned.
+ if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) {
+ // NextSeg():
+ //
+ // (1.a) S2 is greater than HighRxt
+ // (1.b) S2 is less than highest octect covered by
+ // any received SACK.
+ if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
+ // NextSeg():
+ // (1.c) IsLost(S2) returns true.
+ if s.ep.scoreboard.IsLost(segSeq) {
+ return seg, s3, s4
+ }
+ // NextSeg():
+ //
+ // (3): If the conditions for rules (1) and (2)
+ // fail, but there exists an unSACKed sequence
+ // number S3 that meets the criteria for
+ // detecting loss given in steps 1.a and 1.b
+ // above (specifically excluding (1.c)) then one
+ // segment of upto SMSS octets starting with S3
+ // SHOULD be returned.
+ if s3 == nil {
+ s3 = seg
+ }
+ }
+ // NextSeg():
+ //
+ // (4) If the conditions for (1), (2) and (3) fail,
+ // but there exists outstanding unSACKED data, we
+ // provide the opportunity for a single "rescue"
+ // retransmission per entry into loss recovery. If
+ // HighACK is greater than RescueRxt, the one
+ // segment of upto SMSS octects that MUST include
+ // the highest outstanding unSACKed sequence number
+ // SHOULD be returned.
+ if s.fr.rescueRxt.LessThan(s.sndUna - 1) {
+ if s4 != nil {
+ if s4.sequenceNumber.LessThan(segSeq) {
+ s4 = seg
+ }
+ } else {
+ s4 = seg
+ }
+ s.fr.rescueRxt = s.fr.last
+ }
+ }
+ }
+
+ return nil, s3, s4
+}
+
+// maybeSendSegment tries to send the specified segment and either coalesces
+// other segments into this one or splits the specified segment based on the
+// lower of the specified limit value or the receivers window size specified by
+// end.
+func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (sent bool) {
+ // We abuse the flags field to determine if we have already
+ // assigned a sequence number to this segment.
+ if !s.isAssignedSequenceNumber(seg) {
+ // Merge segments if allowed.
+ if seg.data.Size() != 0 {
+ available := int(seg.sequenceNumber.Size(end))
+ if available > limit {
+ available = limit
+ }
+
+ // nextTooBig indicates that the next segment was too
+ // large to entirely fit in the current segment. It
+ // would be possible to split the next segment and merge
+ // the portion that fits, but unexpectedly splitting
+ // segments can have user visible side-effects which can
+ // break applications. For example, RFC 7766 section 8
+ // says that the length and data of a DNS response
+ // should be sent in the same TCP segment to avoid
+ // triggering bugs in poorly written DNS
+ // implementations.
+ var nextTooBig bool
+ for seg.Next() != nil && seg.Next().data.Size() != 0 {
+ if seg.data.Size()+seg.Next().data.Size() > available {
+ nextTooBig = true
+ break
+ }
+ seg.data.Append(seg.Next().data)
+
+ // Consume the segment that we just merged in.
+ s.writeList.Remove(seg.Next())
+ }
+ if !nextTooBig && seg.data.Size() < available {
+ // Segment is not full.
+ if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+ // Nagle's algorithm. From Wikipedia:
+ // Nagle's algorithm works by
+ // combining a number of small
+ // outgoing messages and sending them
+ // all at once. Specifically, as long
+ // as there is a sent packet for which
+ // the sender has received no
+ // acknowledgment, the sender should
+ // keep buffering its output until it
+ // has a full packet's worth of
+ // output, thus allowing output to be
+ // sent all at once.
+ return false
+ }
+ if atomic.LoadUint32(&s.ep.cork) != 0 {
+ // Hold back the segment until full.
+ return false
+ }
+ }
+ }
+
+ // Assign flags. We don't do it above so that we can merge
+ // additional data if Nagle holds the segment.
+ seg.sequenceNumber = s.sndNxt
+ seg.flags = header.TCPFlagAck | header.TCPFlagPsh
+ }
+
+ var segEnd seqnum.Value
+ if seg.data.Size() == 0 {
+ if s.writeList.Back() != seg {
+ panic("FIN segments must be the final segment in the write list.")
+ }
+ seg.flags = header.TCPFlagAck | header.TCPFlagFin
+ segEnd = seg.sequenceNumber.Add(1)
+ } else {
+ // We're sending a non-FIN segment.
+ if seg.flags&header.TCPFlagFin != 0 {
+ panic("Netstack queues FIN segments without data.")
+ }
+
+ if !seg.sequenceNumber.LessThan(end) {
+ return false
+ }
+
+ available := int(seg.sequenceNumber.Size(end))
+ if available == 0 {
+ return false
+ }
+ if available > limit {
+ available = limit
+ }
+
+ if seg.data.Size() > available {
+ s.splitSeg(seg, available)
+ }
+
+ segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ }
+
+ s.sendSegment(seg)
+
+ // Update sndNxt if we actually sent new data (as opposed to
+ // retransmitting some previously sent data).
+ if s.sndNxt.LessThan(segEnd) {
+ s.sndNxt = segEnd
+ }
+
+ return true
+}
+
+// handleSACKRecovery implements the loss recovery phase as described in RFC6675
+// section 5, step C.
+func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
+ s.SetPipe()
+ for s.outstanding < s.sndCwnd {
+ nextSeg, s3, s4 := s.NextSeg()
+ if nextSeg == nil {
+ // NextSeg():
+ //
+ // Step (2): "If no sequence number 'S2' per rule (1)
+ // exists but there exists available unsent data and the
+ // receiver's advertised window allows, the sequence
+ // range of one segment of up to SMSS octets of
+ // previously unsent data starting with sequence number
+ // HighData+1 MUST be returned."
+ for seg := s.writeNext; seg != nil; seg = seg.Next() {
+ if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
+ continue
+ }
+ // Step C.3 described below is handled by
+ // maybeSendSegment which increments sndNxt when
+ // a segment is transmitted.
+ //
+ // Step C.3 "If any of the data octets sent in
+ // (C.1) are above HighData, HighData must be
+ // updated to reflect the transmission of
+ // previously unsent data."
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = seg.Next()
+ nextSeg = seg
+ break
+ }
+ if nextSeg != nil {
+ continue
+ }
+ }
+ rescueRtx := false
+ if nextSeg == nil && s3 != nil {
+ nextSeg = s3
+ }
+ if nextSeg == nil && s4 != nil {
+ nextSeg = s4
+ rescueRtx = true
+ }
+ if nextSeg == nil {
+ break
+ }
+ segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
+ if !rescueRtx && nextSeg.sequenceNumber.LessThan(s.sndNxt) {
+ // RFC 6675, Step C.2
+ //
+ // "If any of the data octets sent in (C.1) are below
+ // HighData, HighRxt MUST be set to the highest sequence
+ // number of the retransmitted segment unless NextSeg ()
+ // rule (4) was invoked for this retransmission."
+ s.fr.highRxt = segEnd - 1
+ }
+
+ // RFC 6675, Step C.4.
+ //
+ // "The estimate of the amount of data outstanding in the network
+ // must be updated by incrementing pipe by the number of octets
+ // transmitted in (C.1)."
+ s.outstanding++
+ dataSent = true
+ s.sendSegment(nextSeg)
+ }
+ return dataSent
+}
+
+// sendData sends new data segments. It is called when data becomes available or
+// when the send window opens up.
+func (s *sender) sendData() {
+ limit := s.maxPayloadSize
+ if s.gso {
+ limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize)
+ }
+ end := s.sndUna.Add(s.sndWnd)
+
+ // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
+ // "A TCP SHOULD set cwnd to no more than RW before beginning
+ // transmission if the TCP has not sent data in the interval exceeding
+ // the retrasmission timeout."
+ if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
+ if s.sndCwnd > InitialCwnd {
+ s.sndCwnd = InitialCwnd
+ }
+ }
+
+ var dataSent bool
+
+ // RFC 6675 recovery algorithm step C 1-5.
+ if s.fr.active && s.ep.sackPermitted {
+ dataSent = s.handleSACKRecovery(s.maxPayloadSize, end)
+ } else {
+ for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+ cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ if cwndLimit < limit {
+ limit = cwndLimit
+ }
+ if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ continue
+ }
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = seg.Next()
+ }
+ }
+
+ if dataSent {
+ // We sent data, so we should stop the keepalive timer to ensure
+ // that no keepalives are sent while there is pending data.
+ s.ep.disableKeepaliveTimer()
+ }
+
+ // Enable the timer if we have pending data and it's not enabled yet.
+ if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
+ s.resendTimer.enable(s.rto)
+ }
+ // If we have no more pending data, start the keepalive timer.
+ if s.sndUna == s.sndNxt {
+ s.ep.resetKeepaliveTimer(false)
+ }
+}
+
+func (s *sender) enterFastRecovery() {
+ s.fr.active = true
+ // Save state to reflect we're now in fast recovery.
+ //
+ // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
+ // We inflate the cwnd by 3 to account for the 3 packets which triggered
+ // the 3 duplicate ACKs and are now not in flight.
+ s.sndCwnd = s.sndSsthresh + 3
+ s.fr.first = s.sndUna
+ s.fr.last = s.sndNxt - 1
+ s.fr.maxCwnd = s.sndCwnd + s.outstanding
+ if s.ep.sackPermitted {
+ s.ep.stack.Stats().TCP.SACKRecovery.Increment()
+ return
+ }
+ s.ep.stack.Stats().TCP.FastRecovery.Increment()
+}
+
+func (s *sender) leaveFastRecovery() {
+ s.fr.active = false
+ s.fr.maxCwnd = 0
+ s.dupAckCount = 0
+
+ // Deflate cwnd. It had been artificially inflated when new dups arrived.
+ s.sndCwnd = s.sndSsthresh
+
+ s.cc.PostRecovery()
+}
+
+func (s *sender) handleFastRecovery(seg *segment) (rtx bool) {
+ ack := seg.ackNumber
+ // We are in fast recovery mode. Ignore the ack if it's out of
+ // range.
+ if !ack.InRange(s.sndUna, s.sndNxt+1) {
+ return false
+ }
+
+ // Leave fast recovery if it acknowledges all the data covered by
+ // this fast recovery session.
+ if s.fr.last.LessThan(ack) {
+ s.leaveFastRecovery()
+ return false
+ }
+
+ if s.ep.sackPermitted {
+ // When SACK is enabled we let retransmission be governed by
+ // the SACK logic.
+ return false
+ }
+
+ // Don't count this as a duplicate if it is carrying data or
+ // updating the window.
+ if seg.logicalLen() != 0 || s.sndWnd != seg.window {
+ return false
+ }
+
+ // Inflate the congestion window if we're getting duplicate acks
+ // for the packet we retransmitted.
+ if ack == s.fr.first {
+ // We received a dup, inflate the congestion window by 1 packet
+ // if we're not at the max yet. Only inflate the window if
+ // regular FastRecovery is in use, RFC6675 does not require
+ // inflating cwnd on duplicate ACKs.
+ if s.sndCwnd < s.fr.maxCwnd {
+ s.sndCwnd++
+ }
+ return false
+ }
+
+ // A partial ack was received. Retransmit this packet and
+ // remember it so that we don't retransmit it again. We don't
+ // inflate the window because we're putting the same packet back
+ // onto the wire.
+ //
+ // N.B. The retransmit timer will be reset by the caller.
+ s.fr.first = ack
+ s.dupAckCount = 0
+ return true
+}
+
+// isAssignedSequenceNumber relies on the fact that we only set flags once a
+// sequencenumber is assigned and that is only done right before we send the
+// segment. As a result any segment that has a non-zero flag has a valid
+// sequence number assigned to it.
+func (s *sender) isAssignedSequenceNumber(seg *segment) bool {
+ return seg.flags != 0
+}
+
+// SetPipe implements the SetPipe() function described in RFC6675. Netstack
+// maintains the congestion window in number of packets and not bytes, so
+// SetPipe() here measures number of outstanding packets rather than actual
+// outstanding bytes in the network.
+func (s *sender) SetPipe() {
+ // If SACK isn't permitted or it is permitted but recovery is not active
+ // then ignore pipe calculations.
+ if !s.ep.sackPermitted || !s.fr.active {
+ return
+ }
+ pipe := 0
+ smss := seqnum.Size(s.ep.scoreboard.SMSS())
+ for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.Next() {
+ // With GSO each segment can be much larger than SMSS. So check the segment
+ // in SMSS sized ranges.
+ segEnd := s1.sequenceNumber.Add(seqnum.Size(s1.data.Size()))
+ for startSeq := s1.sequenceNumber; startSeq.LessThan(segEnd); startSeq = startSeq.Add(smss) {
+ endSeq := startSeq.Add(smss)
+ if segEnd.LessThan(endSeq) {
+ endSeq = segEnd
+ }
+ sb := header.SACKBlock{startSeq, endSeq}
+ // SetPipe():
+ //
+ // After initializing pipe to zero, the following steps are
+ // taken for each octet 'S1' in the sequence space between
+ // HighACK and HighData that has not been SACKed:
+ if !s1.sequenceNumber.LessThan(s.sndNxt) {
+ break
+ }
+ if s.ep.scoreboard.IsSACKED(sb) {
+ continue
+ }
+
+ // SetPipe():
+ //
+ // (a) If IsLost(S1) returns false, Pipe is incremened by 1.
+ //
+ // NOTE: here we mark the whole segment as lost. We do not try
+ // and test every byte in our write buffer as we maintain our
+ // pipe in terms of oustanding packets and not bytes.
+ if !s.ep.scoreboard.IsRangeLost(sb) {
+ pipe++
+ }
+ // SetPipe():
+ // (b) If S1 <= HighRxt, Pipe is incremented by 1.
+ if s1.sequenceNumber.LessThanEq(s.fr.highRxt) {
+ pipe++
+ }
+ }
+ }
+ s.outstanding = pipe
+}
+
+// checkDuplicateAck is called when an ack is received. It manages the state
+// related to duplicate acks and determines if a retransmit is needed according
+// to the rules in RFC 6582 (NewReno).
+func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
+ ack := seg.ackNumber
+ if s.fr.active {
+ return s.handleFastRecovery(seg)
+ }
+
+ // We're not in fast recovery yet. A segment is considered a duplicate
+ // only if it doesn't carry any data and doesn't update the send window,
+ // because if it does, it wasn't sent in response to an out-of-order
+ // segment. If SACK is enabled then we have an additional check to see
+ // if the segment carries new SACK information. If it does then it is
+ // considered a duplicate ACK as per RFC6675.
+ if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
+ if !s.ep.sackPermitted || !seg.hasNewSACKInfo {
+ s.dupAckCount = 0
+ return false
+ }
+ }
+
+ s.dupAckCount++
+
+ // Do not enter fast recovery until we reach nDupAckThreshold or the
+ // first unacknowledged byte is considered lost as per SACK scoreboard.
+ if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) {
+ // RFC 6675 Step 3.
+ s.fr.highRxt = s.sndUna - 1
+ // Do run SetPipe() to calculate the outstanding segments.
+ s.SetPipe()
+ return false
+ }
+
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
+ //
+ // We only do the check here, the incrementing of last to the highest
+ // sequence number transmitted till now is done when enterFastRecovery
+ // is invoked.
+ if !s.fr.last.LessThan(seg.ackNumber) {
+ s.dupAckCount = 0
+ return false
+ }
+ s.cc.HandleNDupAcks()
+ s.enterFastRecovery()
+ s.dupAckCount = 0
+ return true
+}
+
+// handleRcvdSegment is called when a segment is received; it is responsible for
+// updating the send-related state.
+func (s *sender) handleRcvdSegment(seg *segment) {
+ // Check if we can extract an RTT measurement from this ack.
+ if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ s.updateRTO(time.Now().Sub(s.rttMeasureTime))
+ s.rttMeasureSeqNum = s.sndNxt
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if s.ep.sendTSOk && seg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+ }
+
+ // Insert SACKBlock information into our scoreboard.
+ if s.ep.sackPermitted {
+ for _, sb := range seg.parsedOptions.SACKBlocks {
+ // Only insert the SACK block if the following holds
+ // true:
+ // * SACK block acks data after the ack number in the
+ // current segment.
+ // * SACK block represents a sequence
+ // between sndUna and sndNxt (i.e. data that is
+ // currently unacked and in-flight).
+ // * SACK block that has not been SACKed already.
+ //
+ // NOTE: This check specifically excludes DSACK blocks
+ // which have start/end before sndUna and are used to
+ // indicate spurious retransmissions.
+ if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ s.ep.scoreboard.Insert(sb)
+ seg.hasNewSACKInfo = true
+ }
+ }
+ s.SetPipe()
+ }
+
+ // Count the duplicates and do the fast retransmit if needed.
+ rtx := s.checkDuplicateAck(seg)
+
+ // Stash away the current window size.
+ s.sndWnd = seg.window
+
+ // Ignore ack if it doesn't acknowledge any new data.
+ ack := seg.ackNumber
+ if (ack - 1).InRange(s.sndUna, s.sndNxt) {
+ s.dupAckCount = 0
+
+ // See : https://tools.ietf.org/html/rfc1323#section-3.3.
+ // Specifically we should only update the RTO using TSEcr if the
+ // following condition holds:
+ //
+ // A TSecr value received in a segment is used to update the
+ // averaged RTT measurement only if the segment acknowledges
+ // some new data, i.e., only if it advances the left edge of
+ // the send window.
+ if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ // TSVal/Ecr values sent by Netstack are at a millisecond
+ // granularity.
+ elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ s.updateRTO(elapsed)
+ }
+
+ // When an ack is received we must rearm the timer.
+ // RFC 6298 5.2
+ s.resendTimer.enable(s.rto)
+
+ // Remove all acknowledged data from the write list.
+ acked := s.sndUna.Size(ack)
+ s.sndUna = ack
+
+ ackLeft := acked
+ originalOutstanding := s.outstanding
+ for ackLeft > 0 {
+ // We use logicalLen here because we can have FIN
+ // segments (which are always at the end of list) that
+ // have no data, but do consume a sequence number.
+ seg := s.writeList.Front()
+ datalen := seg.logicalLen()
+
+ if datalen > ackLeft {
+ prevCount := s.pCount(seg)
+ seg.data.TrimFront(int(ackLeft))
+ seg.sequenceNumber.UpdateForward(ackLeft)
+ s.outstanding -= prevCount - s.pCount(seg)
+ break
+ }
+
+ if s.writeNext == seg {
+ s.writeNext = seg.Next()
+ }
+ s.writeList.Remove(seg)
+
+ // if SACK is enabled then Only reduce outstanding if
+ // the segment was not previously SACKED as these have
+ // already been accounted for in SetPipe().
+ if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ s.outstanding -= s.pCount(seg)
+ }
+ seg.decRef()
+ ackLeft -= datalen
+ }
+
+ // Update the send buffer usage and notify potential waiters.
+ s.ep.updateSndBufferUsage(int(acked))
+
+ // Clear SACK information for all acked data.
+ s.ep.scoreboard.Delete(s.sndUna)
+
+ // If we are not in fast recovery then update the congestion
+ // window based on the number of acknowledged packets.
+ if !s.fr.active {
+ s.cc.Update(originalOutstanding - s.outstanding)
+ }
+
+ // It is possible for s.outstanding to drop below zero if we get
+ // a retransmit timeout, reset outstanding to zero but later
+ // get an ack that cover previously sent data.
+ if s.outstanding < 0 {
+ s.outstanding = 0
+ }
+
+ s.SetPipe()
+
+ // If all outstanding data was acknowledged the disable the timer.
+ // RFC 6298 Rule 5.3
+ if s.sndUna == s.sndNxt {
+ s.outstanding = 0
+ s.resendTimer.disable()
+ }
+ }
+ // Now that we've popped all acknowledged data from the retransmit
+ // queue, retransmit if needed.
+ if rtx {
+ s.resendSegment()
+ }
+
+ // Send more data now that some of the pending data has been ack'd, or
+ // that the window opened up, or the congestion window was inflated due
+ // to a duplicate ack during fast recovery. This will also re-enable
+ // the retransmit timer if needed.
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ s.sendData()
+ }
+}
+
+// sendSegment sends the specified segment.
+func (s *sender) sendSegment(seg *segment) *tcpip.Error {
+ if !seg.xmitTime.IsZero() {
+ s.ep.stack.Stats().TCP.Retransmits.Increment()
+ if s.sndCwnd < s.sndSsthresh {
+ s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
+ }
+ }
+ seg.xmitTime = time.Now()
+ return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
+}
+
+// sendSegmentFromView sends a new segment containing the given payload, flags
+// and sequence number.
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+ s.lastSendTime = time.Now()
+ if seq == s.rttMeasureSeqNum {
+ s.rttMeasureTime = s.lastSendTime
+ }
+
+ rcvNxt, rcvWnd := s.ep.rcv.getSendParams()
+
+ // Remember the max sent ack.
+ s.maxSentAck = rcvNxt
+
+ // Every time a packet containing data is sent (including a
+ // retransmission), if SACK is enabled then use the conservative timer
+ // described in RFC6675 Section 4.0, otherwise follow the standard time
+ // described in RFC6298 Section 5.2.
+ if data.Size() != 0 {
+ if s.ep.sackPermitted {
+ s.resendTimer.enable(s.rto)
+ } else {
+ if !s.resendTimer.enabled() {
+ s.resendTimer.enable(s.rto)
+ }
+ }
+ }
+
+ return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
+}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
new file mode 100644
index 000000000..12eff8afc
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -0,0 +1,50 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// +stateify savable
+type unixTime struct {
+ second int64
+ nano int64
+}
+
+// saveLastSendTime is invoked by stateify.
+func (s *sender) saveLastSendTime() unixTime {
+ return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()}
+}
+
+// loadLastSendTime is invoked by stateify.
+func (s *sender) loadLastSendTime(unix unixTime) {
+ s.lastSendTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRttMeasureTime is invoked by stateify.
+func (s *sender) saveRttMeasureTime() unixTime {
+ return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()}
+}
+
+// loadRttMeasureTime is invoked by stateify.
+func (s *sender) loadRttMeasureTime(unix unixTime) {
+ s.rttMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// afterLoad is invoked by stateify.
+func (s *sender) afterLoad() {
+ s.resendTimer.init(&s.resendWaker)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go
new file mode 100755
index 000000000..029f98a11
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go
@@ -0,0 +1,173 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type segmentElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type segmentList struct {
+ head *segment
+ tail *segment
+}
+
+// Reset resets list l to the empty state.
+func (l *segmentList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *segmentList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *segmentList) Front() *segment {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *segmentList) Back() *segment {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *segmentList) PushFront(e *segment) {
+ segmentElementMapper{}.linkerFor(e).SetNext(l.head)
+ segmentElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ segmentElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *segmentList) PushBack(e *segment) {
+ segmentElementMapper{}.linkerFor(e).SetNext(nil)
+ segmentElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *segmentList) PushBackList(m *segmentList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *segmentList) InsertAfter(b, e *segment) {
+ a := segmentElementMapper{}.linkerFor(b).Next()
+ segmentElementMapper{}.linkerFor(e).SetNext(a)
+ segmentElementMapper{}.linkerFor(e).SetPrev(b)
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *segmentList) InsertBefore(a, e *segment) {
+ b := segmentElementMapper{}.linkerFor(a).Prev()
+ segmentElementMapper{}.linkerFor(e).SetNext(a)
+ segmentElementMapper{}.linkerFor(e).SetPrev(b)
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *segmentList) Remove(e *segment) {
+ prev := segmentElementMapper{}.linkerFor(e).Prev()
+ next := segmentElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ segmentElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ segmentElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type segmentEntry struct {
+ next *segment
+ prev *segment
+}
+
+// Next returns the entry that follows e in the list.
+func (e *segmentEntry) Next() *segment {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *segmentEntry) Prev() *segment {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *segmentEntry) SetNext(elem *segment) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *segmentEntry) SetPrev(elem *segment) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
new file mode 100755
index 000000000..9049a99b2
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -0,0 +1,400 @@
+// automatically generated by stateify.
+
+package tcp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *SACKInfo) beforeSave() {}
+func (x *SACKInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Blocks", &x.Blocks)
+ m.Save("NumBlocks", &x.NumBlocks)
+}
+
+func (x *SACKInfo) afterLoad() {}
+func (x *SACKInfo) load(m state.Map) {
+ m.Load("Blocks", &x.Blocks)
+ m.Load("NumBlocks", &x.NumBlocks)
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var lastError string = x.saveLastError()
+ m.SaveValue("lastError", lastError)
+ var state endpointState = x.saveState()
+ m.SaveValue("state", state)
+ var hardError string = x.saveHardError()
+ m.SaveValue("hardError", hardError)
+ var acceptedChan []*endpoint = x.saveAcceptedChan()
+ m.SaveValue("acceptedChan", acceptedChan)
+ m.Save("netProto", &x.netProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvBufUsed", &x.rcvBufUsed)
+ m.Save("id", &x.id)
+ m.Save("isRegistered", &x.isRegistered)
+ m.Save("v6only", &x.v6only)
+ m.Save("isConnectNotified", &x.isConnectNotified)
+ m.Save("broadcast", &x.broadcast)
+ m.Save("workerRunning", &x.workerRunning)
+ m.Save("workerCleanup", &x.workerCleanup)
+ m.Save("sendTSOk", &x.sendTSOk)
+ m.Save("recentTS", &x.recentTS)
+ m.Save("tsOffset", &x.tsOffset)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("sackPermitted", &x.sackPermitted)
+ m.Save("sack", &x.sack)
+ m.Save("reusePort", &x.reusePort)
+ m.Save("delay", &x.delay)
+ m.Save("cork", &x.cork)
+ m.Save("scoreboard", &x.scoreboard)
+ m.Save("reuseAddr", &x.reuseAddr)
+ m.Save("slowAck", &x.slowAck)
+ m.Save("segmentQueue", &x.segmentQueue)
+ m.Save("synRcvdCount", &x.synRcvdCount)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("sndBufUsed", &x.sndBufUsed)
+ m.Save("sndClosed", &x.sndClosed)
+ m.Save("sndBufInQueue", &x.sndBufInQueue)
+ m.Save("sndQueue", &x.sndQueue)
+ m.Save("cc", &x.cc)
+ m.Save("packetTooBigCount", &x.packetTooBigCount)
+ m.Save("sndMTU", &x.sndMTU)
+ m.Save("keepalive", &x.keepalive)
+ m.Save("rcv", &x.rcv)
+ m.Save("snd", &x.snd)
+ m.Save("bindAddress", &x.bindAddress)
+ m.Save("connectingAddress", &x.connectingAddress)
+ m.Save("gso", &x.gso)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.LoadWait("waiterQueue", &x.waiterQueue)
+ m.LoadWait("rcvList", &x.rcvList)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvBufUsed", &x.rcvBufUsed)
+ m.Load("id", &x.id)
+ m.Load("isRegistered", &x.isRegistered)
+ m.Load("v6only", &x.v6only)
+ m.Load("isConnectNotified", &x.isConnectNotified)
+ m.Load("broadcast", &x.broadcast)
+ m.Load("workerRunning", &x.workerRunning)
+ m.Load("workerCleanup", &x.workerCleanup)
+ m.Load("sendTSOk", &x.sendTSOk)
+ m.Load("recentTS", &x.recentTS)
+ m.Load("tsOffset", &x.tsOffset)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("sackPermitted", &x.sackPermitted)
+ m.Load("sack", &x.sack)
+ m.Load("reusePort", &x.reusePort)
+ m.Load("delay", &x.delay)
+ m.Load("cork", &x.cork)
+ m.Load("scoreboard", &x.scoreboard)
+ m.Load("reuseAddr", &x.reuseAddr)
+ m.Load("slowAck", &x.slowAck)
+ m.LoadWait("segmentQueue", &x.segmentQueue)
+ m.Load("synRcvdCount", &x.synRcvdCount)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("sndBufUsed", &x.sndBufUsed)
+ m.Load("sndClosed", &x.sndClosed)
+ m.Load("sndBufInQueue", &x.sndBufInQueue)
+ m.LoadWait("sndQueue", &x.sndQueue)
+ m.Load("cc", &x.cc)
+ m.Load("packetTooBigCount", &x.packetTooBigCount)
+ m.Load("sndMTU", &x.sndMTU)
+ m.Load("keepalive", &x.keepalive)
+ m.LoadWait("rcv", &x.rcv)
+ m.LoadWait("snd", &x.snd)
+ m.Load("bindAddress", &x.bindAddress)
+ m.Load("connectingAddress", &x.connectingAddress)
+ m.Load("gso", &x.gso)
+ m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) })
+ m.LoadValue("state", new(endpointState), func(y interface{}) { x.loadState(y.(endpointState)) })
+ m.LoadValue("hardError", new(string), func(y interface{}) { x.loadHardError(y.(string)) })
+ m.LoadValue("acceptedChan", new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *keepalive) beforeSave() {}
+func (x *keepalive) save(m state.Map) {
+ x.beforeSave()
+ m.Save("enabled", &x.enabled)
+ m.Save("idle", &x.idle)
+ m.Save("interval", &x.interval)
+ m.Save("count", &x.count)
+ m.Save("unacked", &x.unacked)
+}
+
+func (x *keepalive) afterLoad() {}
+func (x *keepalive) load(m state.Map) {
+ m.Load("enabled", &x.enabled)
+ m.Load("idle", &x.idle)
+ m.Load("interval", &x.interval)
+ m.Load("count", &x.count)
+ m.Load("unacked", &x.unacked)
+}
+
+func (x *receiver) beforeSave() {}
+func (x *receiver) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ep", &x.ep)
+ m.Save("rcvNxt", &x.rcvNxt)
+ m.Save("rcvAcc", &x.rcvAcc)
+ m.Save("rcvWndScale", &x.rcvWndScale)
+ m.Save("closed", &x.closed)
+ m.Save("pendingRcvdSegments", &x.pendingRcvdSegments)
+ m.Save("pendingBufUsed", &x.pendingBufUsed)
+ m.Save("pendingBufSize", &x.pendingBufSize)
+}
+
+func (x *receiver) afterLoad() {}
+func (x *receiver) load(m state.Map) {
+ m.Load("ep", &x.ep)
+ m.Load("rcvNxt", &x.rcvNxt)
+ m.Load("rcvAcc", &x.rcvAcc)
+ m.Load("rcvWndScale", &x.rcvWndScale)
+ m.Load("closed", &x.closed)
+ m.Load("pendingRcvdSegments", &x.pendingRcvdSegments)
+ m.Load("pendingBufUsed", &x.pendingBufUsed)
+ m.Load("pendingBufSize", &x.pendingBufSize)
+}
+
+func (x *renoState) beforeSave() {}
+func (x *renoState) save(m state.Map) {
+ x.beforeSave()
+ m.Save("s", &x.s)
+}
+
+func (x *renoState) afterLoad() {}
+func (x *renoState) load(m state.Map) {
+ m.Load("s", &x.s)
+}
+
+func (x *SACKScoreboard) beforeSave() {}
+func (x *SACKScoreboard) save(m state.Map) {
+ x.beforeSave()
+ m.Save("smss", &x.smss)
+ m.Save("maxSACKED", &x.maxSACKED)
+}
+
+func (x *SACKScoreboard) afterLoad() {}
+func (x *SACKScoreboard) load(m state.Map) {
+ m.Load("smss", &x.smss)
+ m.Load("maxSACKED", &x.maxSACKED)
+}
+
+func (x *segment) beforeSave() {}
+func (x *segment) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ var options []byte = x.saveOptions()
+ m.SaveValue("options", options)
+ var rcvdTime unixTime = x.saveRcvdTime()
+ m.SaveValue("rcvdTime", rcvdTime)
+ var xmitTime unixTime = x.saveXmitTime()
+ m.SaveValue("xmitTime", xmitTime)
+ m.Save("segmentEntry", &x.segmentEntry)
+ m.Save("refCnt", &x.refCnt)
+ m.Save("viewToDeliver", &x.viewToDeliver)
+ m.Save("sequenceNumber", &x.sequenceNumber)
+ m.Save("ackNumber", &x.ackNumber)
+ m.Save("flags", &x.flags)
+ m.Save("window", &x.window)
+ m.Save("csum", &x.csum)
+ m.Save("csumValid", &x.csumValid)
+ m.Save("parsedOptions", &x.parsedOptions)
+ m.Save("hasNewSACKInfo", &x.hasNewSACKInfo)
+}
+
+func (x *segment) afterLoad() {}
+func (x *segment) load(m state.Map) {
+ m.Load("segmentEntry", &x.segmentEntry)
+ m.Load("refCnt", &x.refCnt)
+ m.Load("viewToDeliver", &x.viewToDeliver)
+ m.Load("sequenceNumber", &x.sequenceNumber)
+ m.Load("ackNumber", &x.ackNumber)
+ m.Load("flags", &x.flags)
+ m.Load("window", &x.window)
+ m.Load("csum", &x.csum)
+ m.Load("csumValid", &x.csumValid)
+ m.Load("parsedOptions", &x.parsedOptions)
+ m.Load("hasNewSACKInfo", &x.hasNewSACKInfo)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+ m.LoadValue("options", new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) })
+ m.LoadValue("rcvdTime", new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) })
+ m.LoadValue("xmitTime", new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) })
+}
+
+func (x *segmentQueue) beforeSave() {}
+func (x *segmentQueue) save(m state.Map) {
+ x.beforeSave()
+ m.Save("list", &x.list)
+ m.Save("limit", &x.limit)
+ m.Save("used", &x.used)
+}
+
+func (x *segmentQueue) afterLoad() {}
+func (x *segmentQueue) load(m state.Map) {
+ m.LoadWait("list", &x.list)
+ m.Load("limit", &x.limit)
+ m.Load("used", &x.used)
+}
+
+func (x *sender) beforeSave() {}
+func (x *sender) save(m state.Map) {
+ x.beforeSave()
+ var lastSendTime unixTime = x.saveLastSendTime()
+ m.SaveValue("lastSendTime", lastSendTime)
+ var rttMeasureTime unixTime = x.saveRttMeasureTime()
+ m.SaveValue("rttMeasureTime", rttMeasureTime)
+ m.Save("ep", &x.ep)
+ m.Save("dupAckCount", &x.dupAckCount)
+ m.Save("fr", &x.fr)
+ m.Save("sndCwnd", &x.sndCwnd)
+ m.Save("sndSsthresh", &x.sndSsthresh)
+ m.Save("sndCAAckCount", &x.sndCAAckCount)
+ m.Save("outstanding", &x.outstanding)
+ m.Save("sndWnd", &x.sndWnd)
+ m.Save("sndUna", &x.sndUna)
+ m.Save("sndNxt", &x.sndNxt)
+ m.Save("sndNxtList", &x.sndNxtList)
+ m.Save("rttMeasureSeqNum", &x.rttMeasureSeqNum)
+ m.Save("closed", &x.closed)
+ m.Save("writeNext", &x.writeNext)
+ m.Save("writeList", &x.writeList)
+ m.Save("rtt", &x.rtt)
+ m.Save("rto", &x.rto)
+ m.Save("srttInited", &x.srttInited)
+ m.Save("maxPayloadSize", &x.maxPayloadSize)
+ m.Save("gso", &x.gso)
+ m.Save("sndWndScale", &x.sndWndScale)
+ m.Save("maxSentAck", &x.maxSentAck)
+ m.Save("cc", &x.cc)
+}
+
+func (x *sender) load(m state.Map) {
+ m.Load("ep", &x.ep)
+ m.Load("dupAckCount", &x.dupAckCount)
+ m.Load("fr", &x.fr)
+ m.Load("sndCwnd", &x.sndCwnd)
+ m.Load("sndSsthresh", &x.sndSsthresh)
+ m.Load("sndCAAckCount", &x.sndCAAckCount)
+ m.Load("outstanding", &x.outstanding)
+ m.Load("sndWnd", &x.sndWnd)
+ m.Load("sndUna", &x.sndUna)
+ m.Load("sndNxt", &x.sndNxt)
+ m.Load("sndNxtList", &x.sndNxtList)
+ m.Load("rttMeasureSeqNum", &x.rttMeasureSeqNum)
+ m.Load("closed", &x.closed)
+ m.Load("writeNext", &x.writeNext)
+ m.Load("writeList", &x.writeList)
+ m.Load("rtt", &x.rtt)
+ m.Load("rto", &x.rto)
+ m.Load("srttInited", &x.srttInited)
+ m.Load("maxPayloadSize", &x.maxPayloadSize)
+ m.Load("gso", &x.gso)
+ m.Load("sndWndScale", &x.sndWndScale)
+ m.Load("maxSentAck", &x.maxSentAck)
+ m.Load("cc", &x.cc)
+ m.LoadValue("lastSendTime", new(unixTime), func(y interface{}) { x.loadLastSendTime(y.(unixTime)) })
+ m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *rtt) beforeSave() {}
+func (x *rtt) save(m state.Map) {
+ x.beforeSave()
+ m.Save("srtt", &x.srtt)
+ m.Save("rttvar", &x.rttvar)
+}
+
+func (x *rtt) afterLoad() {}
+func (x *rtt) load(m state.Map) {
+ m.Load("srtt", &x.srtt)
+ m.Load("rttvar", &x.rttvar)
+}
+
+func (x *fastRecovery) beforeSave() {}
+func (x *fastRecovery) save(m state.Map) {
+ x.beforeSave()
+ m.Save("active", &x.active)
+ m.Save("first", &x.first)
+ m.Save("last", &x.last)
+ m.Save("maxCwnd", &x.maxCwnd)
+ m.Save("highRxt", &x.highRxt)
+ m.Save("rescueRxt", &x.rescueRxt)
+}
+
+func (x *fastRecovery) afterLoad() {}
+func (x *fastRecovery) load(m state.Map) {
+ m.Load("active", &x.active)
+ m.Load("first", &x.first)
+ m.Load("last", &x.last)
+ m.Load("maxCwnd", &x.maxCwnd)
+ m.Load("highRxt", &x.highRxt)
+ m.Load("rescueRxt", &x.rescueRxt)
+}
+
+func (x *unixTime) beforeSave() {}
+func (x *unixTime) save(m state.Map) {
+ x.beforeSave()
+ m.Save("second", &x.second)
+ m.Save("nano", &x.nano)
+}
+
+func (x *unixTime) afterLoad() {}
+func (x *unixTime) load(m state.Map) {
+ m.Load("second", &x.second)
+ m.Load("nano", &x.nano)
+}
+
+func (x *segmentList) beforeSave() {}
+func (x *segmentList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *segmentList) afterLoad() {}
+func (x *segmentList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *segmentEntry) beforeSave() {}
+func (x *segmentEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *segmentEntry) afterLoad() {}
+func (x *segmentEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("tcp.SACKInfo", (*SACKInfo)(nil), state.Fns{Save: (*SACKInfo).save, Load: (*SACKInfo).load})
+ state.Register("tcp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("tcp.keepalive", (*keepalive)(nil), state.Fns{Save: (*keepalive).save, Load: (*keepalive).load})
+ state.Register("tcp.receiver", (*receiver)(nil), state.Fns{Save: (*receiver).save, Load: (*receiver).load})
+ state.Register("tcp.renoState", (*renoState)(nil), state.Fns{Save: (*renoState).save, Load: (*renoState).load})
+ state.Register("tcp.SACKScoreboard", (*SACKScoreboard)(nil), state.Fns{Save: (*SACKScoreboard).save, Load: (*SACKScoreboard).load})
+ state.Register("tcp.segment", (*segment)(nil), state.Fns{Save: (*segment).save, Load: (*segment).load})
+ state.Register("tcp.segmentQueue", (*segmentQueue)(nil), state.Fns{Save: (*segmentQueue).save, Load: (*segmentQueue).load})
+ state.Register("tcp.sender", (*sender)(nil), state.Fns{Save: (*sender).save, Load: (*sender).load})
+ state.Register("tcp.rtt", (*rtt)(nil), state.Fns{Save: (*rtt).save, Load: (*rtt).load})
+ state.Register("tcp.fastRecovery", (*fastRecovery)(nil), state.Fns{Save: (*fastRecovery).save, Load: (*fastRecovery).load})
+ state.Register("tcp.unixTime", (*unixTime)(nil), state.Fns{Save: (*unixTime).save, Load: (*unixTime).load})
+ state.Register("tcp.segmentList", (*segmentList)(nil), state.Fns{Save: (*segmentList).save, Load: (*segmentList).load})
+ state.Register("tcp.segmentEntry", (*segmentEntry)(nil), state.Fns{Save: (*segmentEntry).save, Load: (*segmentEntry).load})
+}
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
new file mode 100644
index 000000000..fc1c7cbd2
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -0,0 +1,141 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+)
+
+type timerState int
+
+const (
+ timerStateDisabled timerState = iota
+ timerStateEnabled
+ timerStateOrphaned
+)
+
+// timer is a timer implementation that reduces the interactions with the
+// runtime timer infrastructure by letting timers run (and potentially
+// eventually expire) even if they are stopped. It makes it cheaper to
+// disable/reenable timers at the expense of spurious wakes. This is useful for
+// cases when the same timer is disabled/reenabled repeatedly with relatively
+// long timeouts farther into the future.
+//
+// TCP retransmit timers benefit from this because they the timeouts are long
+// (currently at least 200ms), and get disabled when acks are received, and
+// reenabled when new pending segments are sent.
+//
+// It is advantageous to avoid interacting with the runtime because it acquires
+// a global mutex and performs O(log n) operations, where n is the global number
+// of timers, whenever a timer is enabled or disabled, and may make a syscall.
+//
+// This struct is thread-compatible.
+type timer struct {
+ // state is the current state of the timer, it can be one of the
+ // following values:
+ // disabled - the timer is disabled.
+ // orphaned - the timer is disabled, but the runtime timer is
+ // enabled, which means that it will evetually cause a
+ // spurious wake (unless it gets enabled again before
+ // then).
+ // enabled - the timer is enabled, but the runtime timer may be set
+ // to an earlier expiration time due to a previous
+ // orphaned state.
+ state timerState
+
+ // target is the expiration time of the current timer. It is only
+ // meaningful in the enabled state.
+ target time.Time
+
+ // runtimeTarget is the expiration time of the runtime timer. It is
+ // meaningful in the enabled and orphaned states.
+ runtimeTarget time.Time
+
+ // timer is the runtime timer used to wait on.
+ timer *time.Timer
+}
+
+// init initializes the timer. Once it expires, it the given waker will be
+// asserted.
+func (t *timer) init(w *sleep.Waker) {
+ t.state = timerStateDisabled
+
+ // Initialize a runtime timer that will assert the waker, then
+ // immediately stop it.
+ t.timer = time.AfterFunc(time.Hour, func() {
+ w.Assert()
+ })
+ t.timer.Stop()
+}
+
+// cleanup frees all resources associated with the timer.
+func (t *timer) cleanup() {
+ t.timer.Stop()
+}
+
+// checkExpiration checks if the given timer has actually expired, it should be
+// called whenever a sleeper wakes up due to the waker being asserted, and is
+// used to check if it's a supurious wake (due to a previously orphaned timer)
+// or a legitimate one.
+func (t *timer) checkExpiration() bool {
+ // Transition to fully disabled state if we're just consuming an
+ // orphaned timer.
+ if t.state == timerStateOrphaned {
+ t.state = timerStateDisabled
+ return false
+ }
+
+ // The timer is enabled, but it may have expired early. Check if that's
+ // the case, and if so, reset the runtime timer to the correct time.
+ now := time.Now()
+ if now.Before(t.target) {
+ t.runtimeTarget = t.target
+ t.timer.Reset(t.target.Sub(now))
+ return false
+ }
+
+ // The timer has actually expired, disable it for now and inform the
+ // caller.
+ t.state = timerStateDisabled
+ return true
+}
+
+// disable disables the timer, leaving it in an orphaned state if it wasn't
+// already disabled.
+func (t *timer) disable() {
+ if t.state != timerStateDisabled {
+ t.state = timerStateOrphaned
+ }
+}
+
+// enabled returns true if the timer is currently enabled, false otherwise.
+func (t *timer) enabled() bool {
+ return t.state == timerStateEnabled
+}
+
+// enable enables the timer, programming the runtime timer if necessary.
+func (t *timer) enable(d time.Duration) {
+ t.target = time.Now().Add(d)
+
+ // Check if we need to set the runtime timer.
+ if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
+ t.runtimeTarget = t.target
+ t.timer.Reset(d)
+ }
+
+ t.state = timerStateEnabled
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
new file mode 100644
index 000000000..3d52a4f31
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -0,0 +1,1002 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "math"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type udpPacket struct {
+ udpPacketEntry
+ senderAddress tcpip.FullAddress
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
+ // views is used as buffer for data when its length is large
+ // enough to store a VectorisedView.
+ views [8]buffer.View `state:"nosave"`
+}
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateConnected
+ stateClosed
+)
+
+// endpoint represents a UDP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+//
+// It implements tcpip.Endpoint.
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList udpPacketList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+ dstPort uint16
+ v6only bool
+ multicastTTL uint8
+ multicastAddr tcpip.Address
+ multicastNICID tcpip.NICID
+ multicastLoop bool
+ reusePort bool
+ broadcast bool
+
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+ // multicastMemberships that need to be remvoed when the endpoint is
+ // closed. Protected by the mu mutex.
+ multicastMemberships []multicastMembership
+
+ // effectiveNetProtos contains the network protocols actually in use. In
+ // most cases it will only contain "netProto", but in cases like IPv6
+ // endpoints with v6only set to false, this could include multiple
+ // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+ // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+ // address).
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
+}
+
+// +stateify savable
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ return &endpoint{
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ // RFC 1075 section 5.4 recommends a TTL of 1 for membership
+ // requests.
+ //
+ // RFC 5135 4.2.1 appears to assume that IGMP messages have a
+ // TTL of 1.
+ //
+ // RFC 5135 Appendix A defines TTL=1: A multicast source that
+ // wants its traffic to not traverse a router (e.g., leave a
+ // home network) may find it useful to send traffic with IP
+ // TTL=1.
+ //
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
+ multicastLoop: true,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
+
+ switch e.state {
+ case stateBound, stateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ }
+
+ for _, mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = stateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case stateInitial:
+ case stateConnected:
+ return false, nil
+
+ case stateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // The state changed when we released the shared locked and re-acquired
+ // it in exclusive mode. Try again.
+ if e.state != stateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+// connectRoute establishes a route to the specified interface or the
+// configured multicast interface if no interface is specified and the
+// specified address is a multicast address.
+func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return stack.Route{}, 0, 0, err
+ }
+
+ localAddr := e.id.LocalAddress
+ if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+ if nicid == 0 {
+ nicid = e.multicastNICID
+ }
+ if localAddr == "" {
+ localAddr = e.multicastAddr
+ }
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
+ if err != nil {
+ return stack.Route{}, 0, 0, err
+ }
+ return r, nicid, netProto, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ if p.Size() > math.MaxUint16 {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ var dstPort uint16
+ if to == nil {
+ route = &e.route
+ dstPort = e.dstPort
+
+ if route.IsResolutionRequired() {
+ // Promote lock to exclusive if using a shared route, given that it may need to
+ // change in Route.Resolve() call below.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicid := to.NIC
+ if e.bindNICID != 0 {
+ if nicid != 0 && nicid != e.bindNICID {
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ nicid = e.bindNICID
+ }
+
+ if to.Addr == header.IPv4Broadcast && !e.broadcast {
+ return 0, nil, tcpip.ErrBroadcastDisabled
+ }
+
+ r, _, _, err := e.connectRoute(nicid, *to)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer r.Release()
+
+ route = &r
+ dstPort = to.Port
+ }
+
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ v, err := p.Get(p.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ ttl := route.DefaultTTL()
+ if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
+ ttl = e.multicastTTL
+ }
+
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(len(v)), nil, nil
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v != 0
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+ netProto, err := e.checkV4Mapped(&fa, false)
+ if err != nil {
+ return err
+ }
+ nic := v.NIC
+ addr := fa.Addr
+
+ if nic == 0 && addr == "" {
+ e.multicastAddr = ""
+ e.multicastNICID = 0
+ break
+ }
+
+ if nic != 0 {
+ if !e.stack.CheckNIC(nic) {
+ return tcpip.ErrBadLocalAddress
+ }
+ } else {
+ nic = e.stack.CheckLocalAddress(0, netProto, addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ if e.bindNICID != 0 && e.bindNICID != nic {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.multicastNICID = nic
+ e.multicastAddr = addr
+
+ case tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr == header.IPv4Any {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrUnknownDevice
+ }
+
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for _, mem := range e.multicastMemberships {
+ if mem == memToInsert {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships = append(e.multicastMemberships, memToInsert)
+
+ case tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr == header.IPv4Any {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrUnknownDevice
+ }
+
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+ memToRemoveIndex := -1
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for i, mem := range e.multicastMemberships {
+ if mem == memToRemove {
+ memToRemoveIndex = i
+ break
+ }
+ }
+ if memToRemoveIndex == -1 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+
+ case tcpip.MulticastLoopOption:
+ e.mu.Lock()
+ e.multicastLoop = bool(v)
+ e.mu.Unlock()
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v != 0
+ e.mu.Unlock()
+
+ return nil
+ }
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.rcvMu.Lock()
+ if e.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := e.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastTTLOption(e.multicastTTL)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastInterfaceOption{
+ e.multicastNICID,
+ e.multicastAddr,
+ }
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.MulticastLoopOption:
+ e.mu.RLock()
+ v := e.multicastLoop
+ e.mu.RUnlock()
+
+ *o = tcpip.MulticastLoopOption(v)
+ return nil
+
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ case *tcpip.BroadcastOption:
+ e.mu.RLock()
+ v := e.broadcast
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// sendUDP sends a UDP segment via the provided network endpoint and under the
+// provided identity.
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
+ // Allocate a buffer for the UDP header.
+ hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+
+ // Initialize the header.
+ udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+
+ length := uint16(hdr.UsedLength() + data.Size())
+ udp.Encode(&header.UDPFields{
+ SrcPort: localPort,
+ DstPort: remotePort,
+ Length: length,
+ })
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ for _, v := range data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udp.SetChecksum(^udp.CalculateChecksum(xsum))
+ }
+
+ // Track count of packets sent.
+ r.Stats().UDP.PacketsSent.Increment()
+
+ return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl)
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ // Fail if using a v4 mapped address on a v6only endpoint.
+ if e.v6only {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == "\x00\x00\x00\x00" {
+ addr.Addr = ""
+ }
+
+ // Fail if we are bound to an IPv6 address.
+ if !allowMismatch && len(e.id.LocalAddress) == 16 {
+ return 0, tcpip.ErrNetworkUnreachable
+ }
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ if addr.Port == 0 {
+ // We don't support connecting to port zero.
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicid := addr.NIC
+ var localPort uint16
+ switch e.state {
+ case stateInitial:
+ case stateBound, stateConnected:
+ localPort = e.id.LocalPort
+ if e.bindNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.bindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicid = e.bindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ r, nicid, netProto, err := e.connectRoute(nicid, addr)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ LocalPort: localPort,
+ RemotePort: addr.Port,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
+ }
+
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ // Remove the old registration.
+ if e.id.LocalPort != 0 {
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ }
+
+ e.id = id
+ e.route = r.Clone()
+ e.dstPort = addr.Port
+ e.regNICID = nicid
+ e.effectiveNetProtos = netProtos
+
+ e.state = stateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // A socket in the bound state can still receive multicast messages,
+ // so we need to notify waiters on shutdown.
+ if e.state != stateBound && e.state != stateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ e.shutdownFlags |= flags
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by UDP, it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by UDP, it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if e.id.LocalPort == 0 {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ if err != nil {
+ return id, err
+ }
+ id.LocalPort = port
+ }
+
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ if err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ }
+ return id, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, true)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
+ }
+
+ nicid := addr.NIC
+ if len(addr.Addr) != 0 {
+ // A local address was specified, verify that it's valid.
+ nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nicid == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.regNICID = nicid
+ e.effectiveNetProtos = netProtos
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ // Save the effective NICID generated by bindLocked.
+ e.bindNICID = e.regNICID
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(vv.First())
+ if int(hdr.Length()) > vv.Size() {
+ // Malformed packet.
+ e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ vv.TrimFront(header.UDPMinimumSize)
+
+ e.rcvMu.Lock()
+ e.stack.Stats().UDP.PacketsReceived.Increment()
+
+ // Drop the packet if our buffer is currently full.
+ if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ pkt := &udpPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ Port: hdr.SourcePort(),
+ },
+ }
+ pkt.data = vv.Clone(pkt.views[:])
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += vv.Size()
+
+ pkt.timestamp = e.stack.NowNanoseconds()
+
+ e.rcvMu.Unlock()
+
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
new file mode 100644
index 000000000..74e8e9fd5
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -0,0 +1,112 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves udpPacket.data field.
+func (u *udpPacket) saveData() buffer.VectorisedView {
+ // We cannot save u.data directly as u.data.views may alias to u.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return u.data.Clone(nil)
+}
+
+// loadData loads udpPacket.data field.
+func (u *udpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing u.views for data.views.
+ u.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after savercvBufSizeMax(), which would have
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ e.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ e.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+
+ for _, m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(err)
+ }
+ }
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ netProto := e.effectiveNetProtos[0]
+ // Connect() and bindLocked() both assert
+ //
+ // netProto == header.IPv6ProtocolNumber
+ //
+ // before creating a multi-entry effectiveNetProtos.
+ if len(e.effectiveNetProtos) > 1 {
+ netProto = header.IPv6ProtocolNumber
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
+ if err != nil {
+ panic(*err)
+ }
+
+ e.id.LocalAddress = e.route.LocalAddress
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ id := e.id
+ e.id.LocalPort = 0
+ e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ if err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
new file mode 100644
index 000000000..25bdd2929
--- /dev/null
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -0,0 +1,96 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Forwarder is a session request forwarder, which allows clients to decide
+// what to do with a session request, for example: ignore it, or process it.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ handler func(*ForwarderRequest)
+
+ stack *stack.Stack
+}
+
+// NewForwarder allocates and initializes a new forwarder.
+func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
+ return &Forwarder{
+ stack: s,
+ handler: handler,
+ }
+}
+
+// HandlePacket handles all packets.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ f.handler(&ForwarderRequest{
+ stack: f.stack,
+ route: r,
+ id: id,
+ vv: vv,
+ })
+
+ return true
+}
+
+// ForwarderRequest represents a session request received by the forwarder and
+// passed to the client. Clients may optionally create an endpoint to represent
+// it via CreateEndpoint.
+type ForwarderRequest struct {
+ stack *stack.Stack
+ route *stack.Route
+ id stack.TransportEndpointID
+ vv buffer.VectorisedView
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the session request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+ return r.id
+}
+
+// CreateEndpoint creates a connected UDP endpoint for the session request.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := newEndpoint(r.stack, r.route.NetProto, queue)
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ ep.id = r.id
+ ep.route = r.route.Clone()
+ ep.dstPort = r.id.RemotePort
+ ep.regNICID = r.route.NICID()
+
+ ep.state = stateConnected
+
+ ep.rcvMu.Lock()
+ ep.rcvReady = true
+ ep.rcvMu.Unlock()
+
+ ep.HandlePacket(r.route, r.id, r.vv)
+
+ return ep, nil
+}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
new file mode 100644
index 000000000..3d31dfbf1
--- /dev/null
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package udp contains the implementation of the UDP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing udp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName is the string representation of the udp protocol name.
+ ProtocolName = "udp"
+
+ // ProtocolNumber is the udp protocol number.
+ ProtocolNumber = header.UDPProtocolNumber
+)
+
+type protocol struct{}
+
+// Number returns the udp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+ return ProtocolNumber
+}
+
+// NewEndpoint creates a new udp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw UDP endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid udp packet size.
+func (*protocol) MinimumPacketSize() int {
+ return header.UDPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given udp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ h := header.UDP(v)
+ return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go
new file mode 100755
index 000000000..673a9373b
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_packet_list.go
@@ -0,0 +1,173 @@
+package udp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type udpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type udpPacketList struct {
+ head *udpPacket
+ tail *udpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *udpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *udpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *udpPacketList) Front() *udpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *udpPacketList) Back() *udpPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *udpPacketList) PushFront(e *udpPacket) {
+ udpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *udpPacketList) PushBack(e *udpPacket) {
+ udpPacketElementMapper{}.linkerFor(e).SetNext(nil)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *udpPacketList) PushBackList(m *udpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *udpPacketList) InsertAfter(b, e *udpPacket) {
+ a := udpPacketElementMapper{}.linkerFor(b).Next()
+ udpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *udpPacketList) InsertBefore(a, e *udpPacket) {
+ b := udpPacketElementMapper{}.linkerFor(a).Prev()
+ udpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *udpPacketList) Remove(e *udpPacket) {
+ prev := udpPacketElementMapper{}.linkerFor(e).Prev()
+ next := udpPacketElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ udpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ udpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type udpPacketEntry struct {
+ next *udpPacket
+ prev *udpPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *udpPacketEntry) Next() *udpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *udpPacketEntry) Prev() *udpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *udpPacketEntry) SetNext(elem *udpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *udpPacketEntry) SetPrev(elem *udpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
new file mode 100755
index 000000000..711e2feeb
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -0,0 +1,128 @@
+// automatically generated by stateify.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *udpPacket) beforeSave() {}
+func (x *udpPacket) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("udpPacketEntry", &x.udpPacketEntry)
+ m.Save("senderAddress", &x.senderAddress)
+ m.Save("timestamp", &x.timestamp)
+}
+
+func (x *udpPacket) afterLoad() {}
+func (x *udpPacket) load(m state.Map) {
+ m.Load("udpPacketEntry", &x.udpPacketEntry)
+ m.Load("senderAddress", &x.senderAddress)
+ m.Load("timestamp", &x.timestamp)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("netProto", &x.netProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvReady", &x.rcvReady)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("id", &x.id)
+ m.Save("state", &x.state)
+ m.Save("bindNICID", &x.bindNICID)
+ m.Save("regNICID", &x.regNICID)
+ m.Save("dstPort", &x.dstPort)
+ m.Save("v6only", &x.v6only)
+ m.Save("multicastTTL", &x.multicastTTL)
+ m.Save("multicastAddr", &x.multicastAddr)
+ m.Save("multicastNICID", &x.multicastNICID)
+ m.Save("multicastLoop", &x.multicastLoop)
+ m.Save("reusePort", &x.reusePort)
+ m.Save("broadcast", &x.broadcast)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("multicastMemberships", &x.multicastMemberships)
+ m.Save("effectiveNetProtos", &x.effectiveNetProtos)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("rcvReady", &x.rcvReady)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("id", &x.id)
+ m.Load("state", &x.state)
+ m.Load("bindNICID", &x.bindNICID)
+ m.Load("regNICID", &x.regNICID)
+ m.Load("dstPort", &x.dstPort)
+ m.Load("v6only", &x.v6only)
+ m.Load("multicastTTL", &x.multicastTTL)
+ m.Load("multicastAddr", &x.multicastAddr)
+ m.Load("multicastNICID", &x.multicastNICID)
+ m.Load("multicastLoop", &x.multicastLoop)
+ m.Load("reusePort", &x.reusePort)
+ m.Load("broadcast", &x.broadcast)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("multicastMemberships", &x.multicastMemberships)
+ m.Load("effectiveNetProtos", &x.effectiveNetProtos)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *multicastMembership) beforeSave() {}
+func (x *multicastMembership) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nicID", &x.nicID)
+ m.Save("multicastAddr", &x.multicastAddr)
+}
+
+func (x *multicastMembership) afterLoad() {}
+func (x *multicastMembership) load(m state.Map) {
+ m.Load("nicID", &x.nicID)
+ m.Load("multicastAddr", &x.multicastAddr)
+}
+
+func (x *udpPacketList) beforeSave() {}
+func (x *udpPacketList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *udpPacketList) afterLoad() {}
+func (x *udpPacketList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *udpPacketEntry) beforeSave() {}
+func (x *udpPacketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *udpPacketEntry) afterLoad() {}
+func (x *udpPacketEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("udp.udpPacket", (*udpPacket)(nil), state.Fns{Save: (*udpPacket).save, Load: (*udpPacket).load})
+ state.Register("udp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("udp.multicastMembership", (*multicastMembership)(nil), state.Fns{Save: (*multicastMembership).save, Load: (*multicastMembership).load})
+ state.Register("udp.udpPacketList", (*udpPacketList)(nil), state.Fns{Save: (*udpPacketList).save, Load: (*udpPacketList).load})
+ state.Register("udp.udpPacketEntry", (*udpPacketEntry)(nil), state.Fns{Save: (*udpPacketEntry).save, Load: (*udpPacketEntry).load})
+}
diff --git a/pkg/tmutex/tmutex.go b/pkg/tmutex/tmutex.go
new file mode 100644
index 000000000..c4685020d
--- /dev/null
+++ b/pkg/tmutex/tmutex.go
@@ -0,0 +1,81 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tmutex provides the implementation of a mutex that implements an
+// efficient TryLock function in addition to Lock and Unlock.
+package tmutex
+
+import (
+ "sync/atomic"
+)
+
+// Mutex is a mutual exclusion primitive that implements TryLock in addition
+// to Lock and Unlock.
+type Mutex struct {
+ v int32
+ ch chan struct{}
+}
+
+// Init initializes the mutex.
+func (m *Mutex) Init() {
+ m.v = 1
+ m.ch = make(chan struct{}, 1)
+}
+
+// Lock acquires the mutex. If it is currently held by another goroutine, Lock
+// will wait until it has a chance to acquire it.
+func (m *Mutex) Lock() {
+ // Uncontended case.
+ if atomic.AddInt32(&m.v, -1) == 0 {
+ return
+ }
+
+ for {
+ // Try to acquire the mutex again, at the same time making sure
+ // that m.v is negative, which indicates to the owner of the
+ // lock that it is contended, which will force it to try to wake
+ // someone up when it releases the mutex.
+ if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 {
+ return
+ }
+
+ // Wait for the mutex to be released before trying again.
+ <-m.ch
+ }
+}
+
+// TryLock attempts to acquire the mutex without blocking. If the mutex is
+// currently held by another goroutine, it fails to acquire it and returns
+// false.
+func (m *Mutex) TryLock() bool {
+ v := atomic.LoadInt32(&m.v)
+ if v <= 0 {
+ return false
+ }
+ return atomic.CompareAndSwapInt32(&m.v, 1, 0)
+}
+
+// Unlock releases the mutex.
+func (m *Mutex) Unlock() {
+ if atomic.SwapInt32(&m.v, 1) == 0 {
+ // There were no pending waiters.
+ return
+ }
+
+ // Wake some waiter up.
+ select {
+ case m.ch <- struct{}{}:
+ default:
+ }
+}
diff --git a/pkg/tmutex/tmutex_state_autogen.go b/pkg/tmutex/tmutex_state_autogen.go
new file mode 100755
index 000000000..2b2bb599e
--- /dev/null
+++ b/pkg/tmutex/tmutex_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package tmutex
+
diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go
new file mode 100644
index 000000000..2aa1af4ff
--- /dev/null
+++ b/pkg/unet/unet.go
@@ -0,0 +1,569 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package unet provides a minimal net package based on Unix Domain Sockets.
+//
+// This does no pooling, and should only be used for a limited number of
+// connections in a Go process. Don't use this package for arbitrary servers.
+package unet
+
+import (
+ "errors"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/gate"
+)
+
+// backlog is used for the listen request.
+const backlog = 16
+
+// errClosing is returned by wait if the Socket is in the process of closing.
+var errClosing = errors.New("Socket is closing")
+
+// errMessageTruncated indicates that data was lost because the provided buffer
+// was too small.
+var errMessageTruncated = errors.New("message truncated")
+
+// socketType returns the appropriate type.
+func socketType(packet bool) int {
+ if packet {
+ return syscall.SOCK_SEQPACKET
+ }
+ return syscall.SOCK_STREAM
+}
+
+// socket creates a new host socket.
+func socket(packet bool) (int, error) {
+ // Make a new socket.
+ fd, err := syscall.Socket(syscall.AF_UNIX, socketType(packet), 0)
+ if err != nil {
+ return 0, err
+ }
+
+ return fd, nil
+}
+
+// eventFD returns a new event FD with initial value 0.
+func eventFD() (int, error) {
+ f, _, e := syscall.Syscall(syscall.SYS_EVENTFD2, 0, 0, 0)
+ if e != 0 {
+ return -1, e
+ }
+ return int(f), nil
+}
+
+// Socket is a connected unix domain socket.
+type Socket struct {
+ // gate protects use of fd.
+ gate gate.Gate
+
+ // fd is the bound socket.
+ //
+ // fd must be read atomically, and only remains valid if read while
+ // within gate.
+ fd int32
+
+ // efd is an event FD that is signaled when the socket is closing.
+ //
+ // efd is immutable and remains valid until Close/Release.
+ efd int
+
+ // race is an atomic variable used to avoid triggering the race
+ // detector. See comment in SocketPair below.
+ race *int32
+}
+
+// NewSocket returns a socket from an existing FD.
+//
+// NewSocket takes ownership of fd.
+func NewSocket(fd int) (*Socket, error) {
+ // fd must be non-blocking for non-blocking syscall.Accept in
+ // ServerSocket.Accept.
+ if err := syscall.SetNonblock(fd, true); err != nil {
+ return nil, err
+ }
+
+ efd, err := eventFD()
+ if err != nil {
+ return nil, err
+ }
+
+ return &Socket{
+ fd: int32(fd),
+ efd: efd,
+ }, nil
+}
+
+// finish completes use of s.fd by evicting any waiters, closing the gate, and
+// closing the event FD.
+func (s *Socket) finish() error {
+ // Signal any blocked or future polls.
+ //
+ // N.B. eventfd writes must be 8 bytes.
+ if _, err := syscall.Write(s.efd, []byte{1, 0, 0, 0, 0, 0, 0, 0}); err != nil {
+ return err
+ }
+
+ // Close the gate, blocking until all FD users leave.
+ s.gate.Close()
+
+ return syscall.Close(s.efd)
+}
+
+// Close closes the socket.
+func (s *Socket) Close() error {
+ // Set the FD in the socket to -1, to ensure that all future calls to
+ // FD/Release get nothing and Close calls return immediately.
+ fd := int(atomic.SwapInt32(&s.fd, -1))
+ if fd < 0 {
+ // Already closed or closing.
+ return syscall.EBADF
+ }
+
+ // Shutdown the socket to cancel any pending accepts.
+ s.shutdown(fd)
+
+ if err := s.finish(); err != nil {
+ return err
+ }
+
+ return syscall.Close(fd)
+}
+
+// Release releases ownership of the socket FD.
+//
+// The returned FD is non-blocking.
+//
+// Any concurrent or future callers of Socket methods will receive EBADF.
+func (s *Socket) Release() (int, error) {
+ // Set the FD in the socket to -1, to ensure that all future calls to
+ // FD/Release get nothing and Close calls return immediately.
+ fd := int(atomic.SwapInt32(&s.fd, -1))
+ if fd < 0 {
+ // Already closed or closing.
+ return -1, syscall.EBADF
+ }
+
+ if err := s.finish(); err != nil {
+ return -1, err
+ }
+
+ return fd, nil
+}
+
+// FD returns the FD for this Socket.
+//
+// The FD is non-blocking and must not be made blocking.
+//
+// N.B. os.File.Fd makes the FD blocking. Use of Release instead of FD is
+// strongly preferred.
+//
+// The returned FD cannot be used safely if there may be concurrent callers to
+// Close or Release.
+//
+// Use Release to take ownership of the FD.
+func (s *Socket) FD() int {
+ return int(atomic.LoadInt32(&s.fd))
+}
+
+// enterFD enters the FD gate and returns the FD value.
+//
+// If enterFD returns ok, s.gate.Leave must be called when done with the FD.
+// Callers may only block while within the gate using s.wait.
+//
+// The returned FD is guaranteed to remain valid until s.gate.Leave.
+func (s *Socket) enterFD() (int, bool) {
+ if !s.gate.Enter() {
+ return -1, false
+ }
+
+ fd := int(atomic.LoadInt32(&s.fd))
+ if fd < 0 {
+ s.gate.Leave()
+ return -1, false
+ }
+
+ return fd, true
+}
+
+// SocketPair creates a pair of connected sockets.
+func SocketPair(packet bool) (*Socket, *Socket, error) {
+ // Make a new pair.
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, socketType(packet)|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // race is an atomic variable used to avoid triggering the race
+ // detector. We have to fool TSAN into thinking there is a race
+ // variable between our two sockets. We only use SocketPair in tests
+ // anyway.
+ //
+ // NOTE(b/27107811): This is purely due to the fact that the raw
+ // syscall does not serve as a boundary for the sanitizer.
+ var race int32
+ a, err := NewSocket(fds[0])
+ if err != nil {
+ syscall.Close(fds[0])
+ syscall.Close(fds[1])
+ return nil, nil, err
+ }
+ a.race = &race
+ b, err := NewSocket(fds[1])
+ if err != nil {
+ a.Close()
+ syscall.Close(fds[1])
+ return nil, nil, err
+ }
+ b.race = &race
+ return a, b, nil
+}
+
+// Connect connects to a server.
+func Connect(addr string, packet bool) (*Socket, error) {
+ fd, err := socket(packet)
+ if err != nil {
+ return nil, err
+ }
+
+ // Connect the socket.
+ usa := &syscall.SockaddrUnix{Name: addr}
+ if err := syscall.Connect(fd, usa); err != nil {
+ syscall.Close(fd)
+ return nil, err
+ }
+
+ return NewSocket(fd)
+}
+
+// ControlMessage wraps around a byte array and provides functions for parsing
+// as a Unix Domain Socket control message.
+type ControlMessage []byte
+
+// EnableFDs enables receiving FDs via control message.
+//
+// This guarantees only a MINIMUM number of FDs received. You may receive MORE
+// than this due to the way FDs are packed. To be specific, the number of
+// receivable buffers will be rounded up to the nearest even number.
+//
+// This must be called prior to ReadVec if you want to receive FDs.
+func (c *ControlMessage) EnableFDs(count int) {
+ *c = make([]byte, syscall.CmsgSpace(count*4))
+}
+
+// ExtractFDs returns the list of FDs in the control message.
+//
+// Either this or CloseFDs should be used after EnableFDs.
+func (c *ControlMessage) ExtractFDs() ([]int, error) {
+ msgs, err := syscall.ParseSocketControlMessage(*c)
+ if err != nil {
+ return nil, err
+ }
+ var fds []int
+ for _, msg := range msgs {
+ thisFds, err := syscall.ParseUnixRights(&msg)
+ if err != nil {
+ // Different control message.
+ return nil, err
+ }
+ for _, fd := range thisFds {
+ if fd >= 0 {
+ fds = append(fds, fd)
+ }
+ }
+ }
+ return fds, nil
+}
+
+// CloseFDs closes the list of FDs in the control message.
+//
+// Either this or ExtractFDs should be used after EnableFDs.
+func (c *ControlMessage) CloseFDs() {
+ fds, _ := c.ExtractFDs()
+ for _, fd := range fds {
+ if fd >= 0 {
+ syscall.Close(fd)
+ }
+ }
+}
+
+// PackFDs packs the given list of FDs in the control message.
+//
+// This must be used prior to WriteVec.
+func (c *ControlMessage) PackFDs(fds ...int) {
+ *c = ControlMessage(syscall.UnixRights(fds...))
+}
+
+// UnpackFDs clears the control message.
+func (c *ControlMessage) UnpackFDs() {
+ *c = nil
+}
+
+// SocketWriter wraps an individual send operation.
+//
+// The normal entrypoint is WriteVec.
+type SocketWriter struct {
+ socket *Socket
+ to []byte
+ blocking bool
+ race *int32
+
+ ControlMessage
+}
+
+// Writer returns a writer for this socket.
+func (s *Socket) Writer(blocking bool) SocketWriter {
+ return SocketWriter{socket: s, blocking: blocking, race: s.race}
+}
+
+// Write implements io.Writer.Write.
+func (s *Socket) Write(p []byte) (int, error) {
+ r := s.Writer(true)
+ return r.WriteVec([][]byte{p})
+}
+
+// GetSockOpt gets the given socket option.
+func (s *Socket) GetSockOpt(level int, name int, b []byte) (uint32, error) {
+ fd, ok := s.enterFD()
+ if !ok {
+ return 0, syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ return getsockopt(fd, level, name, b)
+}
+
+// SetSockOpt sets the given socket option.
+func (s *Socket) SetSockOpt(level, name int, b []byte) error {
+ fd, ok := s.enterFD()
+ if !ok {
+ return syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ return setsockopt(fd, level, name, b)
+}
+
+// GetSockName returns the socket name.
+func (s *Socket) GetSockName() ([]byte, error) {
+ fd, ok := s.enterFD()
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ var buf []byte
+ l := syscall.SizeofSockaddrAny
+
+ for {
+ // If the buffer is not large enough, allocate a new one with the hint.
+ buf = make([]byte, l)
+ l, err := getsockname(fd, buf)
+ if err != nil {
+ return nil, err
+ }
+
+ if l <= uint32(len(buf)) {
+ return buf[:l], nil
+ }
+ }
+}
+
+// GetPeerName returns the peer name.
+func (s *Socket) GetPeerName() ([]byte, error) {
+ fd, ok := s.enterFD()
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ var buf []byte
+ l := syscall.SizeofSockaddrAny
+
+ for {
+ // See above.
+ buf = make([]byte, l)
+ l, err := getpeername(fd, buf)
+ if err != nil {
+ return nil, err
+ }
+
+ if l <= uint32(len(buf)) {
+ return buf[:l], nil
+ }
+ }
+}
+
+// GetPeerCred returns the peer's unix credentials.
+func (s *Socket) GetPeerCred() (*syscall.Ucred, error) {
+ fd, ok := s.enterFD()
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ return syscall.GetsockoptUcred(fd, syscall.SOL_SOCKET, syscall.SO_PEERCRED)
+}
+
+// SocketReader wraps an individual receive operation.
+//
+// This may be used for doing vectorized reads and/or sending additional
+// control messages (e.g. FDs). The normal entrypoint is ReadVec.
+//
+// One of ExtractFDs or DisposeFDs must be called if EnableFDs is used.
+type SocketReader struct {
+ socket *Socket
+ source []byte
+ blocking bool
+ race *int32
+
+ ControlMessage
+}
+
+// Reader returns a reader for this socket.
+func (s *Socket) Reader(blocking bool) SocketReader {
+ return SocketReader{socket: s, blocking: blocking, race: s.race}
+}
+
+// Read implements io.Reader.Read.
+func (s *Socket) Read(p []byte) (int, error) {
+ r := s.Reader(true)
+ return r.ReadVec([][]byte{p})
+}
+
+func (s *Socket) shutdown(fd int) error {
+ // Shutdown the socket to cancel any pending accepts.
+ return syscall.Shutdown(fd, syscall.SHUT_RDWR)
+}
+
+// Shutdown closes the socket for read and write.
+func (s *Socket) Shutdown() error {
+ fd, ok := s.enterFD()
+ if !ok {
+ return syscall.EBADF
+ }
+ defer s.gate.Leave()
+
+ return s.shutdown(fd)
+}
+
+// ServerSocket is a bound unix domain socket.
+type ServerSocket struct {
+ socket *Socket
+}
+
+// NewServerSocket returns a socket from an existing FD.
+func NewServerSocket(fd int) (*ServerSocket, error) {
+ s, err := NewSocket(fd)
+ if err != nil {
+ return nil, err
+ }
+ return &ServerSocket{socket: s}, nil
+}
+
+// Bind creates and binds a new socket.
+func Bind(addr string, packet bool) (*ServerSocket, error) {
+ fd, err := socket(packet)
+ if err != nil {
+ return nil, err
+ }
+
+ // Do the bind.
+ usa := &syscall.SockaddrUnix{Name: addr}
+ if err := syscall.Bind(fd, usa); err != nil {
+ syscall.Close(fd)
+ return nil, err
+ }
+
+ return NewServerSocket(fd)
+}
+
+// BindAndListen creates, binds and listens on a new socket.
+func BindAndListen(addr string, packet bool) (*ServerSocket, error) {
+ s, err := Bind(addr, packet)
+ if err != nil {
+ return nil, err
+ }
+
+ // Start listening.
+ if err := s.Listen(); err != nil {
+ s.Close()
+ return nil, err
+ }
+
+ return s, nil
+}
+
+// Listen starts listening on the socket.
+func (s *ServerSocket) Listen() error {
+ fd, ok := s.socket.enterFD()
+ if !ok {
+ return syscall.EBADF
+ }
+ defer s.socket.gate.Leave()
+
+ return syscall.Listen(fd, backlog)
+}
+
+// Accept accepts a new connection.
+//
+// This is always blocking.
+//
+// Preconditions:
+// * ServerSocket is listening (Listen called).
+func (s *ServerSocket) Accept() (*Socket, error) {
+ fd, ok := s.socket.enterFD()
+ if !ok {
+ return nil, syscall.EBADF
+ }
+ defer s.socket.gate.Leave()
+
+ for {
+ nfd, _, err := syscall.Accept(fd)
+ switch err {
+ case nil:
+ return NewSocket(nfd)
+ case syscall.EAGAIN:
+ err = s.socket.wait(false)
+ if err == errClosing {
+ err = syscall.EBADF
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ }
+}
+
+// Close closes the server socket.
+//
+// This must only be called once.
+func (s *ServerSocket) Close() error {
+ return s.socket.Close()
+}
+
+// FD returns the socket's file descriptor.
+//
+// See Socket.FD.
+func (s *ServerSocket) FD() int {
+ return s.socket.FD()
+}
+
+// Release releases ownership of the socket's file descriptor.
+//
+// See Socket.Release.
+func (s *ServerSocket) Release() (int, error) {
+ return s.socket.Release()
+}
diff --git a/pkg/unet/unet_state_autogen.go b/pkg/unet/unet_state_autogen.go
new file mode 100755
index 000000000..1f7c7fa59
--- /dev/null
+++ b/pkg/unet/unet_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package unet
+
diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go
new file mode 100644
index 000000000..fa0916439
--- /dev/null
+++ b/pkg/unet/unet_unsafe.go
@@ -0,0 +1,289 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unet
+
+import (
+ "io"
+ "math"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+)
+
+// wait blocks until the socket FD is ready for reading or writing, depending
+// on the value of write.
+//
+// Returns errClosing if the Socket is in the process of closing.
+func (s *Socket) wait(write bool) error {
+ for {
+ // Checking the FD on each loop is not strictly necessary, it
+ // just avoids an extra poll call.
+ fd := atomic.LoadInt32(&s.fd)
+ if fd < 0 {
+ return errClosing
+ }
+
+ events := []linux.PollFD{
+ {
+ // The actual socket FD.
+ FD: fd,
+ Events: linux.POLLIN,
+ },
+ {
+ // The eventfd, signaled when we are closing.
+ FD: int32(s.efd),
+ Events: linux.POLLIN,
+ },
+ }
+ if write {
+ events[0].Events = linux.POLLOUT
+ }
+
+ _, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(&events[0])), 2, uintptr(math.MaxUint64))
+ if e == syscall.EINTR {
+ continue
+ }
+ if e != 0 {
+ return e
+ }
+
+ if events[1].REvents&linux.POLLIN == linux.POLLIN {
+ // eventfd signaled, we're closing.
+ return errClosing
+ }
+
+ return nil
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// iovecs is used as an initial slice, to avoid excessive allocations.
+func buildIovec(bufs [][]byte, iovecs []syscall.Iovec) ([]syscall.Iovec, int) {
+ var length int
+ for i := range bufs {
+ if l := len(bufs[i]); l > 0 {
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(l),
+ })
+ length += l
+ }
+ }
+ return iovecs, length
+}
+
+// ReadVec reads into the pre-allocated bufs. Returns bytes read.
+//
+// The pre-allocatted space used by ReadVec is based upon slice lengths.
+//
+// This function is not guaranteed to read all available data, it
+// returns as soon as a single recvmsg call succeeds.
+func (r *SocketReader) ReadVec(bufs [][]byte) (int, error) {
+ iovecs, length := buildIovec(bufs, make([]syscall.Iovec, 0, 2))
+
+ var msg syscall.Msghdr
+ if len(r.source) != 0 {
+ msg.Name = &r.source[0]
+ msg.Namelen = uint32(len(r.source))
+ }
+
+ if len(r.ControlMessage) != 0 {
+ msg.Control = &r.ControlMessage[0]
+ msg.Controllen = uint64(len(r.ControlMessage))
+ }
+
+ if len(iovecs) != 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ // n is the bytes received.
+ var n uintptr
+
+ fd, ok := r.socket.enterFD()
+ if !ok {
+ return 0, syscall.EBADF
+ }
+ // Leave on returns below.
+ for {
+ var e syscall.Errno
+
+ // Try a non-blocking recv first, so we don't give up the go runtime M.
+ n, _, e = syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_TRUNC)
+ if e == 0 {
+ break
+ }
+ if e == syscall.EINTR {
+ continue
+ }
+ if !r.blocking {
+ r.socket.gate.Leave()
+ return 0, e
+ }
+ if e != syscall.EAGAIN && e != syscall.EWOULDBLOCK {
+ r.socket.gate.Leave()
+ return 0, e
+ }
+
+ // Wait for the socket to become readable.
+ err := r.socket.wait(false)
+ if err == errClosing {
+ err = syscall.EBADF
+ }
+ if err != nil {
+ r.socket.gate.Leave()
+ return 0, err
+ }
+ }
+
+ r.socket.gate.Leave()
+
+ if msg.Controllen < uint64(len(r.ControlMessage)) {
+ r.ControlMessage = r.ControlMessage[:msg.Controllen]
+ }
+
+ if msg.Namelen < uint32(len(r.source)) {
+ r.source = r.source[:msg.Namelen]
+ }
+
+ // All unet sockets are SOCK_STREAM or SOCK_SEQPACKET, both of which
+ // indicate that the other end is closed by returning a 0 length read
+ // with no error.
+ if n == 0 {
+ return 0, io.EOF
+ }
+
+ if r.race != nil {
+ // See comments on Socket.race.
+ atomic.AddInt32(r.race, 1)
+ }
+
+ if int(n) > length {
+ return length, errMessageTruncated
+ }
+
+ return int(n), nil
+}
+
+// WriteVec writes the bufs to the socket. Returns bytes written.
+//
+// This function is not guaranteed to send all data, it returns
+// as soon as a single sendmsg call succeeds.
+func (w *SocketWriter) WriteVec(bufs [][]byte) (int, error) {
+ iovecs, _ := buildIovec(bufs, make([]syscall.Iovec, 0, 2))
+
+ if w.race != nil {
+ // See comments on Socket.race.
+ atomic.AddInt32(w.race, 1)
+ }
+
+ var msg syscall.Msghdr
+ if len(w.to) != 0 {
+ msg.Name = &w.to[0]
+ msg.Namelen = uint32(len(w.to))
+ }
+
+ if len(w.ControlMessage) != 0 {
+ msg.Control = &w.ControlMessage[0]
+ msg.Controllen = uint64(len(w.ControlMessage))
+ }
+
+ if len(iovecs) > 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ fd, ok := w.socket.enterFD()
+ if !ok {
+ return 0, syscall.EBADF
+ }
+ // Leave on returns below.
+ for {
+ // Try a non-blocking send first, so we don't give up the go runtime M.
+ n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
+ if e == 0 {
+ w.socket.gate.Leave()
+ return int(n), nil
+ }
+ if e == syscall.EINTR {
+ continue
+ }
+ if !w.blocking {
+ w.socket.gate.Leave()
+ return 0, e
+ }
+ if e != syscall.EAGAIN && e != syscall.EWOULDBLOCK {
+ w.socket.gate.Leave()
+ return 0, e
+ }
+
+ // Wait for the socket to become writeable.
+ err := w.socket.wait(true)
+ if err == errClosing {
+ err = syscall.EBADF
+ }
+ if err != nil {
+ w.socket.gate.Leave()
+ return 0, err
+ }
+ }
+ // Unreachable, no s.gate.Leave needed.
+}
+
+// getsockopt issues a getsockopt syscall.
+func getsockopt(fd int, level int, optname int, buf []byte) (uint32, error) {
+ l := uint32(len(buf))
+ _, _, e := syscall.RawSyscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), uintptr(level), uintptr(optname), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&l)), 0)
+ if e != 0 {
+ return 0, e
+ }
+
+ return l, nil
+}
+
+// setsockopt issues a setsockopt syscall.
+func setsockopt(fd int, level int, optname int, buf []byte) error {
+ _, _, e := syscall.RawSyscall6(syscall.SYS_SETSOCKOPT, uintptr(fd), uintptr(level), uintptr(optname), uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)), 0)
+ if e != 0 {
+ return e
+ }
+
+ return nil
+}
+
+// getsockname issues a getsockname syscall.
+func getsockname(fd int, buf []byte) (uint32, error) {
+ l := uint32(len(buf))
+ _, _, e := syscall.RawSyscall(syscall.SYS_GETSOCKNAME, uintptr(fd), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&l)))
+ if e != 0 {
+ return 0, e
+ }
+
+ return l, nil
+}
+
+// getpeername issues a getpeername syscall.
+func getpeername(fd int, buf []byte) (uint32, error) {
+ l := uint32(len(buf))
+ _, _, e := syscall.RawSyscall(syscall.SYS_GETPEERNAME, uintptr(fd), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&l)))
+ if e != 0 {
+ return 0, e
+ }
+
+ return l, nil
+}
diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go
new file mode 100644
index 000000000..0f155ec74
--- /dev/null
+++ b/pkg/urpc/urpc.go
@@ -0,0 +1,636 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package urpc provides a minimal RPC package based on unet.
+//
+// RPC requests are _not_ concurrent and methods must be explicitly
+// registered. However, files may be send as part of the payload.
+package urpc
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "reflect"
+ "runtime"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+// maxFiles determines the maximum file payload.
+const maxFiles = 16
+
+// ErrTooManyFiles is returned when too many file descriptors are mapped.
+var ErrTooManyFiles = errors.New("too many files")
+
+// ErrUnknownMethod is returned when a method is not known.
+var ErrUnknownMethod = errors.New("unknown method")
+
+// errStopped is an internal error indicating the server has been stopped.
+var errStopped = errors.New("stopped")
+
+// RemoteError is an error returned by the remote invocation.
+//
+// This indicates that the RPC transport was correct, but that the called
+// function itself returned an error.
+type RemoteError struct {
+ // Message is the result of calling Error() on the remote error.
+ Message string
+}
+
+// Error returns the remote error string.
+func (r RemoteError) Error() string {
+ return r.Message
+}
+
+// FilePayload may be _embedded_ in another type in order to send or receive a
+// file as a result of an RPC. These are not actually serialized, rather they
+// are sent via an accompanying SCM_RIGHTS message (plumbed through the unet
+// package).
+//
+// When embedding a FilePayload in an argument struct, the argument type _must_
+// be a pointer to the struct rather than the struct type itself. This is
+// because the urpc package defines pointer methods on FilePayload.
+type FilePayload struct {
+ Files []*os.File `json:"-"`
+}
+
+// ReleaseFD releases the indexth FD.
+func (f *FilePayload) ReleaseFD(index int) (*fd.FD, error) {
+ return fd.NewFromFile(f.Files[index])
+}
+
+// filePayload returns the file. It may be nil.
+func (f *FilePayload) filePayload() []*os.File {
+ return f.Files
+}
+
+// setFilePayload sets the payload.
+func (f *FilePayload) setFilePayload(fs []*os.File) {
+ f.Files = fs
+}
+
+// closeAll closes a slice of files.
+func closeAll(files []*os.File) {
+ for _, f := range files {
+ f.Close()
+ }
+}
+
+// filePayloader is implemented only by FilePayload and will be implicitly
+// implemented by types that have the FilePayload embedded. Note that there is
+// no way to implement these methods other than by embedding FilePayload, due
+// to the way unexported method names are mangled.
+type filePayloader interface {
+ filePayload() []*os.File
+ setFilePayload([]*os.File)
+}
+
+// clientCall is the client=>server method call on the client side.
+type clientCall struct {
+ Method string `json:"method"`
+ Arg interface{} `json:"arg"`
+}
+
+// serverCall is the client=>server method call on the server side.
+type serverCall struct {
+ Method string `json:"method"`
+ Arg json.RawMessage `json:"arg"`
+}
+
+// callResult is the server=>client method call result.
+type callResult struct {
+ Success bool `json:"success"`
+ Err string `json:"err"`
+ Result interface{} `json:"result"`
+}
+
+// registeredMethod is method registered with the server.
+type registeredMethod struct {
+ // fn is the underlying function.
+ fn reflect.Value
+
+ // rcvr is the receiver value.
+ rcvr reflect.Value
+
+ // argType is a typed argument.
+ argType reflect.Type
+
+ // resultType is also a type result.
+ resultType reflect.Type
+}
+
+// clientState is client metadata.
+//
+// The following are valid states:
+//
+// idle - not processing any requests, no close request.
+// processing - actively processing, no close request.
+// closeRequested - actively processing, pending close.
+// closed - client connection has been closed.
+//
+// The following transitions are possible:
+//
+// idle -> processing, closed
+// processing -> idle, closeRequested
+// closeRequested -> closed
+//
+type clientState int
+
+// See clientState.
+const (
+ idle clientState = iota
+ processing
+ closeRequested
+ closed
+)
+
+// Server is an RPC server.
+type Server struct {
+ // mu protects all fields, except wg.
+ mu sync.Mutex
+
+ // methods is the set of server methods.
+ methods map[string]registeredMethod
+
+ // clients is a map of clients.
+ clients map[*unet.Socket]clientState
+
+ // wg is a wait group for all outstanding clients.
+ wg sync.WaitGroup
+
+ // afterRPCCallback is called after each RPC is successfully completed.
+ afterRPCCallback func()
+}
+
+// NewServer returns a new server.
+func NewServer() *Server {
+ return NewServerWithCallback(nil)
+}
+
+// NewServerWithCallback returns a new server, who upon completion of each RPC
+// calls the given function.
+func NewServerWithCallback(afterRPCCallback func()) *Server {
+ return &Server{
+ methods: make(map[string]registeredMethod),
+ clients: make(map[*unet.Socket]clientState),
+ afterRPCCallback: afterRPCCallback,
+ }
+}
+
+// Register registers the given object as an RPC receiver.
+//
+// This functions is the same way as the built-in RPC package, but it does not
+// tolerate any object with non-conforming methods. Any non-confirming methods
+// will lead to an immediate panic, instead of being skipped or an error.
+// Panics will also be generated by anonymous objects and duplicate entries.
+func (s *Server) Register(obj interface{}) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ typ := reflect.TypeOf(obj)
+
+ // If we got a pointer, deref it to the underlying object. We need this to
+ // obtain the name of the underlying type.
+ typDeref := typ
+ if typ.Kind() == reflect.Ptr {
+ typDeref = typ.Elem()
+ }
+
+ for m := 0; m < typ.NumMethod(); m++ {
+ method := typ.Method(m)
+
+ if typDeref.Name() == "" {
+ // Can't be anonymous.
+ panic("type not named.")
+ }
+
+ prettyName := typDeref.Name() + "." + method.Name
+ if _, ok := s.methods[prettyName]; ok {
+ // Duplicate entry.
+ panic(fmt.Sprintf("method %s is duplicated.", prettyName))
+ }
+
+ if method.PkgPath != "" {
+ // Must be exported.
+ panic(fmt.Sprintf("method %s is not exported.", prettyName))
+ }
+ mtype := method.Type
+ if mtype.NumIn() != 3 {
+ // Need exactly two arguments (+ receiver).
+ panic(fmt.Sprintf("method %s has wrong number of arguments.", prettyName))
+ }
+ argType := mtype.In(1)
+ if argType.Kind() != reflect.Ptr {
+ // Need arg pointer.
+ panic(fmt.Sprintf("method %s has non-pointer first argument.", prettyName))
+ }
+ resultType := mtype.In(2)
+ if resultType.Kind() != reflect.Ptr {
+ // Need result pointer.
+ panic(fmt.Sprintf("method %s has non-pointer second argument.", prettyName))
+ }
+ if mtype.NumOut() != 1 {
+ // Need single return.
+ panic(fmt.Sprintf("method %s has wrong number of returns.", prettyName))
+ }
+ if returnType := mtype.Out(0); returnType != reflect.TypeOf((*error)(nil)).Elem() {
+ // Need error return.
+ panic(fmt.Sprintf("method %s has non-error return value.", prettyName))
+ }
+
+ // Register the method.
+ s.methods[prettyName] = registeredMethod{
+ fn: method.Func,
+ rcvr: reflect.ValueOf(obj),
+ argType: argType,
+ resultType: resultType,
+ }
+ }
+}
+
+// lookup looks up the given method.
+func (s *Server) lookup(method string) (registeredMethod, bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ rm, ok := s.methods[method]
+ return rm, ok
+}
+
+// handleOne handles a single call.
+func (s *Server) handleOne(client *unet.Socket) error {
+ // Unmarshal the call.
+ var c serverCall
+ newFs, err := unmarshal(client, &c)
+ if err != nil {
+ // Client is dead.
+ return err
+ }
+
+ defer func() {
+ if s.afterRPCCallback != nil {
+ s.afterRPCCallback()
+ }
+ }()
+ // Explicitly close all these files after the call.
+ //
+ // This is also explicitly a reference to the files after the call,
+ // which means they are kept open for the duration of the call.
+ defer closeAll(newFs)
+
+ // Start the request.
+ if !s.clientBeginRequest(client) {
+ // Client is dead; don't process this call.
+ return errStopped
+ }
+ defer s.clientEndRequest(client)
+
+ // Lookup the method.
+ rm, ok := s.lookup(c.Method)
+ if !ok {
+ // Try to serialize the error.
+ return marshal(client, &callResult{Err: ErrUnknownMethod.Error()}, nil)
+ }
+
+ // Unmarshal the arguments now that we know the type.
+ na := reflect.New(rm.argType.Elem())
+ if err := json.Unmarshal(c.Arg, na.Interface()); err != nil {
+ return marshal(client, &callResult{Err: err.Error()}, nil)
+ }
+
+ // Set the file payload as an argument.
+ if fp, ok := na.Interface().(filePayloader); ok {
+ fp.setFilePayload(newFs)
+ }
+
+ // Call the method.
+ re := reflect.New(rm.resultType.Elem())
+ rValues := rm.fn.Call([]reflect.Value{rm.rcvr, na, re})
+ if errVal := rValues[0].Interface(); errVal != nil {
+ return marshal(client, &callResult{Err: errVal.(error).Error()}, nil)
+ }
+
+ // Set the resulting payload.
+ var fs []*os.File
+ if fp, ok := re.Interface().(filePayloader); ok {
+ fs = fp.filePayload()
+ if len(fs) > maxFiles {
+ // Ugh. Send an error to the client, despite success.
+ return marshal(client, &callResult{Err: ErrTooManyFiles.Error()}, nil)
+ }
+ }
+
+ // Marshal the result.
+ return marshal(client, &callResult{Success: true, Result: re.Interface()}, fs)
+}
+
+// clientBeginRequest begins a request.
+//
+// If true is returned, the request may be processed. If false is returned,
+// then the server has been stopped and the request should be skipped.
+func (s *Server) clientBeginRequest(client *unet.Socket) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ switch state := s.clients[client]; state {
+ case idle:
+ // Mark as processing.
+ s.clients[client] = processing
+ return true
+ case closed:
+ // Whoops, how did this happen? Must have closed immediately
+ // following the deserialization. Don't let the RPC actually go
+ // through, since we won't be able to serialize a proper
+ // response.
+ return false
+ default:
+ // Should not happen.
+ panic(fmt.Sprintf("expected idle or closed, got %d", state))
+ }
+}
+
+// clientEndRequest ends a request.
+func (s *Server) clientEndRequest(client *unet.Socket) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ switch state := s.clients[client]; state {
+ case processing:
+ // Return to idle.
+ s.clients[client] = idle
+ case closeRequested:
+ // Close the connection.
+ client.Close()
+ s.clients[client] = closed
+ default:
+ // Should not happen.
+ panic(fmt.Sprintf("expected processing or requestClose, got %d", state))
+ }
+}
+
+// clientRegister registers a connection.
+//
+// See Stop for more context.
+func (s *Server) clientRegister(client *unet.Socket) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.clients[client] = idle
+ s.wg.Add(1)
+}
+
+// clientUnregister unregisters and closes a connection if necessary.
+//
+// See Stop for more context.
+func (s *Server) clientUnregister(client *unet.Socket) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ switch state := s.clients[client]; state {
+ case idle:
+ // Close the connection.
+ client.Close()
+ case closed:
+ // Already done.
+ default:
+ // Should not happen.
+ panic(fmt.Sprintf("expected idle or closed, got %d", state))
+ }
+ delete(s.clients, client)
+ s.wg.Done()
+}
+
+// handleRegistered handles calls from a registered client.
+func (s *Server) handleRegistered(client *unet.Socket) error {
+ for {
+ // Handle one call.
+ if err := s.handleOne(client); err != nil {
+ // Client is dead.
+ return err
+ }
+ }
+}
+
+// Handle synchronously handles a single client over a connection.
+func (s *Server) Handle(client *unet.Socket) error {
+ s.clientRegister(client)
+ defer s.clientUnregister(client)
+ return s.handleRegistered(client)
+}
+
+// StartHandling creates a goroutine that handles a single client over a
+// connection.
+func (s *Server) StartHandling(client *unet.Socket) {
+ s.clientRegister(client)
+ go func() { // S/R-SAFE: out of scope
+ defer s.clientUnregister(client)
+ s.handleRegistered(client)
+ }()
+}
+
+// Stop safely terminates outstanding clients.
+//
+// No new requests should be initiated after calling Stop. Existing clients
+// will be closed after completing any pending RPCs. This method will block
+// until all clients have disconnected.
+func (s *Server) Stop() {
+ // Wait for all outstanding requests.
+ defer s.wg.Wait()
+
+ // Close all known clients.
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for client, state := range s.clients {
+ switch state {
+ case idle:
+ // Close connection now.
+ client.Close()
+ s.clients[client] = closed
+ case processing:
+ // Request close when done.
+ s.clients[client] = closeRequested
+ }
+ }
+}
+
+// Client is a urpc client.
+type Client struct {
+ // mu protects all members.
+ //
+ // It also enforces single-call semantics.
+ mu sync.Mutex
+
+ // Socket is the underlying socket for this client.
+ //
+ // This _must_ be provided and must be closed manually by calling
+ // Close.
+ Socket *unet.Socket
+}
+
+// NewClient returns a new client.
+func NewClient(socket *unet.Socket) *Client {
+ return &Client{
+ Socket: socket,
+ }
+}
+
+// marshal sends the given FD and json struct.
+func marshal(s *unet.Socket, v interface{}, fs []*os.File) error {
+ // Marshal to a buffer.
+ data, err := json.Marshal(v)
+ if err != nil {
+ log.Warningf("urpc: error marshalling %s: %s", fmt.Sprintf("%v", v), err.Error())
+ return err
+ }
+
+ // Write to the socket.
+ w := s.Writer(true)
+ if fs != nil {
+ var fds []int
+ for _, f := range fs {
+ fds = append(fds, int(f.Fd()))
+ }
+ w.PackFDs(fds...)
+ }
+
+ // Send.
+ for n := 0; n < len(data); {
+ cur, err := w.WriteVec([][]byte{data[n:]})
+ if n == 0 && cur < len(data) {
+ // Don't send FDs anymore. This call is only made on
+ // the first successful call to WriteVec, assuming cur
+ // is not sufficient to fill the entire buffer.
+ w.PackFDs()
+ }
+ n += cur
+ if err != nil {
+ log.Warningf("urpc: error writing %v: %s", data[n:], err.Error())
+ return err
+ }
+ }
+
+ // We're done sending the fds to the client. Explicitly prevent fs from
+ // being GCed until here. Urpc rpcs often unlink the file to send, relying
+ // on the kernel to automatically delete it once the last reference is
+ // dropped. Until we successfully call sendmsg(2), fs may contain the last
+ // references to these files. Without this explicit reference to fs here,
+ // the go runtime is free to assume we're done with fs after the fd
+ // collection loop above, since it just sees us copying ints.
+ runtime.KeepAlive(fs)
+
+ log.Debugf("urpc: successfully marshalled %d bytes.", len(data))
+ return nil
+}
+
+// unmarhsal receives an FD (optional) and unmarshals the given struct.
+func unmarshal(s *unet.Socket, v interface{}) ([]*os.File, error) {
+ // Receive a single byte.
+ r := s.Reader(true)
+ r.EnableFDs(maxFiles)
+ firstByte := make([]byte, 1)
+
+ // Extract any FDs that may be there.
+ if _, err := r.ReadVec([][]byte{firstByte}); err != nil {
+ return nil, err
+ }
+ fds, err := r.ExtractFDs()
+ if err != nil {
+ log.Warningf("urpc: error extracting fds: %s", err.Error())
+ return nil, err
+ }
+ var fs []*os.File
+ for _, fd := range fds {
+ fs = append(fs, os.NewFile(uintptr(fd), "urpc"))
+ }
+
+ // Read the rest.
+ d := json.NewDecoder(io.MultiReader(bytes.NewBuffer(firstByte), s))
+ // urpc internally decodes / re-encodes the data with interface{} as the
+ // intermediate type. We have to unmarshal integers to json.Number type
+ // instead of the default float type for those intermediate values, such
+ // that when they get re-encoded, their values are not printed out in
+ // floating-point formats such as 1e9, which could not be decoded to
+ // explicitly typed intergers later.
+ d.UseNumber()
+ if err := d.Decode(v); err != nil {
+ log.Warningf("urpc: error decoding: %s", err.Error())
+ for _, f := range fs {
+ f.Close()
+ }
+ return nil, err
+ }
+
+ // All set.
+ log.Debugf("urpc: unmarshal success.")
+ return fs, nil
+}
+
+// Call calls a function.
+func (c *Client) Call(method string, arg interface{}, result interface{}) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // If arg is a FilePayload, not a *FilePayload, files won't actually be
+ // sent, so error out.
+ if _, ok := arg.(FilePayload); ok {
+ return fmt.Errorf("argument is a FilePayload, but should be a *FilePayload")
+ }
+
+ // Are there files to send?
+ var fs []*os.File
+ if fp, ok := arg.(filePayloader); ok {
+ fs = fp.filePayload()
+ if len(fs) > maxFiles {
+ return ErrTooManyFiles
+ }
+ }
+
+ // Marshal the data.
+ if err := marshal(c.Socket, &clientCall{Method: method, Arg: arg}, fs); err != nil {
+ return err
+ }
+
+ // Wait for the response.
+ callR := callResult{Result: result}
+ newFs, err := unmarshal(c.Socket, &callR)
+ if err != nil {
+ return fmt.Errorf("urpc method %q failed: %v", method, err)
+ }
+
+ // Set the file payload.
+ if fp, ok := result.(filePayloader); ok {
+ fp.setFilePayload(newFs)
+ } else {
+ closeAll(newFs)
+ }
+
+ // Did an error occur?
+ if !callR.Success {
+ return RemoteError{Message: callR.Err}
+ }
+
+ // All set.
+ return nil
+}
+
+// Close closes the underlying socket.
+//
+// Further calls to the client may result in undefined behavior.
+func (c *Client) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.Socket.Close()
+}
diff --git a/pkg/urpc/urpc_state_autogen.go b/pkg/urpc/urpc_state_autogen.go
new file mode 100755
index 000000000..01bf2172f
--- /dev/null
+++ b/pkg/urpc/urpc_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package urpc
+
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
new file mode 100644
index 000000000..8a65ed164
--- /dev/null
+++ b/pkg/waiter/waiter.go
@@ -0,0 +1,250 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package waiter provides the implementation of a wait queue, where waiters can
+// be enqueued to be notified when an event of interest happens.
+//
+// Becoming readable and/or writable are examples of events. Waiters are
+// expected to use a pattern similar to this to make a blocking function out of
+// a non-blocking one:
+//
+// func (o *object) blockingRead(...) error {
+// err := o.nonBlockingRead(...)
+// if err != ErrAgain {
+// // Completed with no need to wait!
+// return err
+// }
+//
+// e := createOrGetWaiterEntry(...)
+// o.EventRegister(&e, waiter.EventIn)
+// defer o.EventUnregister(&e)
+//
+// // We need to try to read again after registration because the
+// // object may have become readable between the last attempt to
+// // read and read registration.
+// err = o.nonBlockingRead(...)
+// for err == ErrAgain {
+// wait()
+// err = o.nonBlockingRead(...)
+// }
+//
+// return err
+// }
+//
+// Another goroutine needs to notify waiters when events happen. For example:
+//
+// func (o *object) Write(...) ... {
+// // Do write work.
+// [...]
+//
+// if oldDataAvailableSize == 0 && dataAvailableSize > 0 {
+// // If no data was available and now some data is
+// // available, the object became readable, so notify
+// // potential waiters about this.
+// o.Notify(waiter.EventIn)
+// }
+// }
+package waiter
+
+import (
+ "sync"
+)
+
+// EventMask represents io events as used in the poll() syscall.
+type EventMask uint16
+
+// Events that waiters can wait on. The meaning is the same as those in the
+// poll() syscall.
+const (
+ EventIn EventMask = 0x01 // POLLIN
+ EventPri EventMask = 0x02 // POLLPRI
+ EventOut EventMask = 0x04 // POLLOUT
+ EventErr EventMask = 0x08 // POLLERR
+ EventHUp EventMask = 0x10 // POLLHUP
+
+ allEvents EventMask = 0x1f
+)
+
+// EventMaskFromLinux returns an EventMask representing the supported events
+// from the Linux events e, which is in the format used by poll(2).
+func EventMaskFromLinux(e uint32) EventMask {
+ // Our flag definitions are currently identical to Linux.
+ return EventMask(e) & allEvents
+}
+
+// ToLinux returns e in the format used by Linux poll(2).
+func (e EventMask) ToLinux() uint32 {
+ // Our flag definitions are currently identical to Linux.
+ return uint32(e)
+}
+
+// Waitable contains the methods that need to be implemented by waitable
+// objects.
+type Waitable interface {
+ // Readiness returns what the object is currently ready for. If it's
+ // not ready for a desired purpose, the caller may use EventRegister and
+ // EventUnregister to get notifications once the object becomes ready.
+ //
+ // Implementations should allow for events like EventHUp and EventErr
+ // to be returned regardless of whether they are in the input EventMask.
+ Readiness(mask EventMask) EventMask
+
+ // EventRegister registers the given waiter entry to receive
+ // notifications when an event occurs that makes the object ready for
+ // at least one of the events in mask.
+ EventRegister(e *Entry, mask EventMask)
+
+ // EventUnregister unregisters a waiter entry previously registered with
+ // EventRegister().
+ EventUnregister(e *Entry)
+}
+
+// EntryCallback provides a notify callback.
+type EntryCallback interface {
+ // Callback is the function to be called when the waiter entry is
+ // notified. It is responsible for doing whatever is needed to wake up
+ // the waiter.
+ //
+ // The callback is supposed to perform minimal work, and cannot call
+ // any method on the queue itself because it will be locked while the
+ // callback is running.
+ Callback(e *Entry)
+}
+
+// Entry represents a waiter that can be add to the a wait queue. It can
+// only be in one queue at a time, and is added "intrusively" to the queue with
+// no extra memory allocations.
+//
+// +stateify savable
+type Entry struct {
+ // Context stores any state the waiter may wish to store in the entry
+ // itself, which may be used at wake up time.
+ //
+ // Note that use of this field is optional and state may alternatively be
+ // stored in the callback itself.
+ Context interface{}
+
+ Callback EntryCallback
+
+ // The following fields are protected by the queue lock.
+ mask EventMask
+ waiterEntry
+}
+
+type channelCallback struct{}
+
+// Callback implements EntryCallback.Callback.
+func (*channelCallback) Callback(e *Entry) {
+ ch := e.Context.(chan struct{})
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+}
+
+// NewChannelEntry initializes a new Entry that does a non-blocking write to a
+// struct{} channel when the callback is called. It returns the new Entry
+// instance and the channel being used.
+//
+// If a channel isn't specified (i.e., if "c" is nil), then NewChannelEntry
+// allocates a new channel.
+func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) {
+ if c == nil {
+ c = make(chan struct{}, 1)
+ }
+
+ return Entry{Context: c, Callback: &channelCallback{}}, c
+}
+
+// Queue represents the wait queue where waiters can be added and
+// notifiers can notify them when events happen.
+//
+// The zero value for waiter.Queue is an empty queue ready for use.
+//
+// +stateify savable
+type Queue struct {
+ list waiterList `state:"zerovalue"`
+ mu sync.RWMutex `state:"nosave"`
+}
+
+// EventRegister adds a waiter to the wait queue; the waiter will be notified
+// when at least one of the events specified in mask happens.
+func (q *Queue) EventRegister(e *Entry, mask EventMask) {
+ q.mu.Lock()
+ e.mask = mask
+ q.list.PushBack(e)
+ q.mu.Unlock()
+}
+
+// EventUnregister removes the given waiter entry from the wait queue.
+func (q *Queue) EventUnregister(e *Entry) {
+ q.mu.Lock()
+ q.list.Remove(e)
+ q.mu.Unlock()
+}
+
+// Notify notifies all waiters in the queue whose masks have at least one bit
+// in common with the notification mask.
+func (q *Queue) Notify(mask EventMask) {
+ q.mu.RLock()
+ for e := q.list.Front(); e != nil; e = e.Next() {
+ if mask&e.mask != 0 {
+ e.Callback.Callback(e)
+ }
+ }
+ q.mu.RUnlock()
+}
+
+// Events returns the set of events being waited on. It is the union of the
+// masks of all registered entries.
+func (q *Queue) Events() EventMask {
+ ret := EventMask(0)
+
+ q.mu.RLock()
+ for e := q.list.Front(); e != nil; e = e.Next() {
+ ret |= e.mask
+ }
+ q.mu.RUnlock()
+
+ return ret
+}
+
+// IsEmpty returns if the wait queue is empty or not.
+func (q *Queue) IsEmpty() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.list.Front() == nil
+}
+
+// AlwaysReady implements the Waitable interface but is always ready. Embedding
+// this struct into another struct makes it implement the boilerplate empty
+// functions automatically.
+type AlwaysReady struct {
+}
+
+// Readiness always returns the input mask because this object is always ready.
+func (*AlwaysReady) Readiness(mask EventMask) EventMask {
+ return mask
+}
+
+// EventRegister doesn't do anything because this object doesn't need to issue
+// notifications because its readiness never changes.
+func (*AlwaysReady) EventRegister(*Entry, EventMask) {
+}
+
+// EventUnregister doesn't do anything because this object doesn't need to issue
+// notifications because its readiness never changes.
+func (*AlwaysReady) EventUnregister(e *Entry) {
+}
diff --git a/pkg/waiter/waiter_list.go b/pkg/waiter/waiter_list.go
new file mode 100755
index 000000000..00b304a31
--- /dev/null
+++ b/pkg/waiter/waiter_list.go
@@ -0,0 +1,173 @@
+package waiter
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type waiterElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (waiterElementMapper) linkerFor(elem *Entry) *Entry { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type waiterList struct {
+ head *Entry
+ tail *Entry
+}
+
+// Reset resets list l to the empty state.
+func (l *waiterList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *waiterList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *waiterList) Front() *Entry {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *waiterList) Back() *Entry {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *waiterList) PushFront(e *Entry) {
+ waiterElementMapper{}.linkerFor(e).SetNext(l.head)
+ waiterElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ waiterElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *waiterList) PushBack(e *Entry) {
+ waiterElementMapper{}.linkerFor(e).SetNext(nil)
+ waiterElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ waiterElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *waiterList) PushBackList(m *waiterList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *waiterList) InsertAfter(b, e *Entry) {
+ a := waiterElementMapper{}.linkerFor(b).Next()
+ waiterElementMapper{}.linkerFor(e).SetNext(a)
+ waiterElementMapper{}.linkerFor(e).SetPrev(b)
+ waiterElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ waiterElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *waiterList) InsertBefore(a, e *Entry) {
+ b := waiterElementMapper{}.linkerFor(a).Prev()
+ waiterElementMapper{}.linkerFor(e).SetNext(a)
+ waiterElementMapper{}.linkerFor(e).SetPrev(b)
+ waiterElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ waiterElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *waiterList) Remove(e *Entry) {
+ prev := waiterElementMapper{}.linkerFor(e).Prev()
+ next := waiterElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ waiterElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ waiterElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type waiterEntry struct {
+ next *Entry
+ prev *Entry
+}
+
+// Next returns the entry that follows e in the list.
+func (e *waiterEntry) Next() *Entry {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *waiterEntry) Prev() *Entry {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *waiterEntry) SetNext(elem *Entry) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *waiterEntry) SetPrev(elem *Entry) {
+ e.prev = elem
+}
diff --git a/pkg/waiter/waiter_state_autogen.go b/pkg/waiter/waiter_state_autogen.go
new file mode 100755
index 000000000..b9d3e2798
--- /dev/null
+++ b/pkg/waiter/waiter_state_autogen.go
@@ -0,0 +1,67 @@
+// automatically generated by stateify.
+
+package waiter
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *Entry) beforeSave() {}
+func (x *Entry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Context", &x.Context)
+ m.Save("Callback", &x.Callback)
+ m.Save("mask", &x.mask)
+ m.Save("waiterEntry", &x.waiterEntry)
+}
+
+func (x *Entry) afterLoad() {}
+func (x *Entry) load(m state.Map) {
+ m.Load("Context", &x.Context)
+ m.Load("Callback", &x.Callback)
+ m.Load("mask", &x.mask)
+ m.Load("waiterEntry", &x.waiterEntry)
+}
+
+func (x *Queue) beforeSave() {}
+func (x *Queue) save(m state.Map) {
+ x.beforeSave()
+ if !state.IsZeroValue(x.list) { m.Failf("list is %v, expected zero", x.list) }
+}
+
+func (x *Queue) afterLoad() {}
+func (x *Queue) load(m state.Map) {
+}
+
+func (x *waiterList) beforeSave() {}
+func (x *waiterList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *waiterList) afterLoad() {}
+func (x *waiterList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *waiterEntry) beforeSave() {}
+func (x *waiterEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *waiterEntry) afterLoad() {}
+func (x *waiterEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("waiter.Entry", (*Entry)(nil), state.Fns{Save: (*Entry).save, Load: (*Entry).load})
+ state.Register("waiter.Queue", (*Queue)(nil), state.Fns{Save: (*Queue).save, Load: (*Queue).load})
+ state.Register("waiter.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load})
+ state.Register("waiter.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load})
+}